[RSS] Replace scipy KDTree with Euclidean distance (#245)
* Replace scipy KDTree with Euclidean distance * fixed star expression and missing import * Version number Co-authored-by: aikaterna <20862007+aikaterna@users.noreply.github.com>
This commit is contained in:
76
rss/color.py
76
rss/color.py
@@ -1,9 +1,70 @@
|
||||
from math import sqrt
|
||||
import discord
|
||||
import re
|
||||
from scipy.spatial import KDTree
|
||||
import webcolors
|
||||
|
||||
|
||||
_DISCORD_COLOURS = {
|
||||
discord.Color.teal().to_rgb(): 'teal',
|
||||
discord.Color.dark_teal().to_rgb(): 'dark_teal',
|
||||
discord.Color.green().to_rgb(): 'green',
|
||||
discord.Color.dark_green().to_rgb(): 'dark_green',
|
||||
discord.Color.blue().to_rgb(): 'blue',
|
||||
discord.Color.dark_blue().to_rgb(): 'dark_blue',
|
||||
discord.Color.purple().to_rgb(): 'purple',
|
||||
discord.Color.dark_purple().to_rgb(): 'dark_purple',
|
||||
discord.Color.magenta().to_rgb(): 'magenta',
|
||||
discord.Color.dark_magenta().to_rgb(): 'dark_magenta',
|
||||
discord.Color.gold().to_rgb(): 'gold',
|
||||
discord.Color.dark_gold().to_rgb(): 'dark_gold',
|
||||
discord.Color.orange().to_rgb(): 'orange',
|
||||
discord.Color.dark_orange().to_rgb(): 'dark_orange',
|
||||
discord.Color.red().to_rgb(): 'red',
|
||||
discord.Color.dark_red().to_rgb(): 'dark_red',
|
||||
discord.Color.lighter_grey().to_rgb(): 'lighter_grey',
|
||||
discord.Color.light_grey().to_rgb(): 'light_grey',
|
||||
discord.Color.dark_grey().to_rgb(): 'dark_grey',
|
||||
discord.Color.darker_grey().to_rgb(): 'darker_grey',
|
||||
discord.Color.blurple().to_rgb(): 'old_blurple',
|
||||
discord.Color(0x4a90e2).to_rgb(): 'new_blurple',
|
||||
discord.Color.greyple().to_rgb(): 'greyple',
|
||||
discord.Color.dark_theme().to_rgb(): 'discord_dark_theme'
|
||||
}
|
||||
|
||||
_RGB_NAME_MAP = {webcolors.hex_to_rgb(hexcode): name for hexcode, name in webcolors.css3_hex_to_names.items()}
|
||||
_RGB_NAME_MAP.update(_DISCORD_COLOURS)
|
||||
|
||||
|
||||
def _distance(point_a: tuple, point_b: tuple):
|
||||
"""
|
||||
Euclidean distance between two points using rgb values as the metric space.
|
||||
"""
|
||||
# rgb values
|
||||
x1, y1, z1 = point_a
|
||||
x2, y2, z2 = point_b
|
||||
|
||||
# distances
|
||||
dx = x1 - x2
|
||||
dy = y1 - y2
|
||||
dz = z1 - z2
|
||||
|
||||
# final distance
|
||||
return sqrt(dx**2 + dy**2 + dz**2)
|
||||
|
||||
def _linear_nearest_neighbour(all_points: list, pivot: tuple):
|
||||
"""
|
||||
Check distance against all points from the pivot and return the distance and nearest point.
|
||||
"""
|
||||
best_dist = None
|
||||
nearest = None
|
||||
for point in all_points:
|
||||
dist = _distance(point, pivot)
|
||||
if best_dist is None or dist < best_dist:
|
||||
best_dist = dist
|
||||
nearest = point
|
||||
return best_dist, nearest
|
||||
|
||||
|
||||
class Color:
|
||||
"""Helper for color handling."""
|
||||
|
||||
@@ -43,17 +104,10 @@ class Color:
|
||||
hex_code = await self._hex_validator(hex_code)
|
||||
rgb_tuple = await self._hex_to_rgb(hex_code)
|
||||
|
||||
names = []
|
||||
positions = []
|
||||
positions = list(_RGB_NAME_MAP.keys())
|
||||
dist, nearest = _linear_nearest_neighbour(positions, rgb_tuple)
|
||||
|
||||
for hex, name in webcolors.css3_hex_to_names.items():
|
||||
names.append(name)
|
||||
positions.append(webcolors.hex_to_rgb(hex))
|
||||
|
||||
spacedb = KDTree(positions)
|
||||
dist, index = spacedb.query(rgb_tuple)
|
||||
|
||||
return names[index]
|
||||
return _RGB_NAME_MAP[nearest]
|
||||
|
||||
async def _hex_to_rgb(self, hex_code: str):
|
||||
"""
|
||||
|
||||
@@ -5,6 +5,6 @@
|
||||
"description": "Read RSS feeds",
|
||||
"tags": ["rss"],
|
||||
"permissions": ["embed_links"],
|
||||
"requirements": ["bs4", "feedparser>=6.0.0", "scipy", "webcolors==1.3"],
|
||||
"requirements": ["bs4", "feedparser>=6.0.0", "webcolors==1.3"],
|
||||
"min_bot_version" : "3.4.0"
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ from .tag_type import INTERNAL_TAGS, VALID_IMAGES, TagType
|
||||
log = logging.getLogger("red.aikaterna.rss")
|
||||
|
||||
|
||||
__version__ = "1.4.5"
|
||||
__version__ = "1.5.0"
|
||||
|
||||
|
||||
class RSS(commands.Cog):
|
||||
|
||||
Reference in New Issue
Block a user