From 5d44f55908fb3068b5315c908be2afb17d2eeed3 Mon Sep 17 00:00:00 2001 From: Kowlin Date: Sat, 3 Apr 2021 02:32:46 +0200 Subject: [PATCH] Refactor code to be more managable (#220) And add in serverchart --- chatchart/chatchart.py | 255 ++++++++++++++++++++++++++++------------- 1 file changed, 178 insertions(+), 77 deletions(-) diff --git a/chatchart/chatchart.py b/chatchart/chatchart.py index 44d3357..3decdec 100644 --- a/chatchart/chatchart.py +++ b/chatchart/chatchart.py @@ -8,7 +8,7 @@ import asyncio import discord import heapq from io import BytesIO -from typing import Optional +from typing import List, Optional, Tuple, Union import matplotlib @@ -38,18 +38,62 @@ class Chatchart(commands.Cog): self.config.register_global(**default_global) @staticmethod - async def create_chart(top, others, channel): + def calculate_member_perc(history: List[discord.Message]) -> dict: + """Calculate the member count from the message history""" + msg_data = {"total_count": 0, "users": {}} + for msg in history: + # Name formatting + if len(msg.author.display_name) >= 20: + short_name = "{}...".format(msg.author.display_name[:20]).replace("$", "\\$") + else: + short_name = msg.author.display_name.replace("$", "\\$").replace("_", "\\_ ").replace("*", "\\*") + whole_name = "{}#{}".format(short_name, msg.author.discriminator) + if msg.author.bot: + pass + elif whole_name in msg_data["users"]: + msg_data["users"][whole_name]["msgcount"] += 1 + msg_data["total_count"] += 1 + else: + msg_data["users"][whole_name] = {} + msg_data["users"][whole_name]["msgcount"] = 1 + msg_data["total_count"] += 1 + return msg_data + + @staticmethod + def calculate_top(msg_data: dict) -> Tuple[list, int]: + """Calculate the top 20 from the message data package""" + for usr in msg_data["users"]: + pd = float(msg_data["users"][usr]["msgcount"]) / float(msg_data["total_count"]) + msg_data["users"][usr]["percent"] = round(pd * 100, 1) + top_twenty = heapq.nlargest( + 20, + [ + (x, msg_data["users"][x][y]) + for x in msg_data["users"] + for y in msg_data["users"][x] + if (y == "percent" and msg_data["users"][x][y] > 0) + ], + key=lambda x: x[1], + ) + others = 100 - sum(x[1] for x in top_twenty) + return top_twenty, others + + @staticmethod + async def create_chart(top, others, channel_or_guild: Union[discord.Guild, discord.TextChannel]): plt.clf() sizes = [x[1] for x in top] labels = ["{} {:g}%".format(x[0], x[1]) for x in top] if len(top) >= 20: sizes = sizes + [others] labels = labels + ["Others {:g}%".format(others)] - if len(channel.name) >= 19: - channel_name = "{}...".format(channel.name[:19]) + if len(channel_or_guild.name) >= 19: + if isinstance(channel_or_guild, discord.Guild): + channel_or_guild_name = "{}...".format(channel_or_guild.name[:19]) + else: + channel_or_guild_name = "#{}...".format(channel_or_guild.name[:19]) else: - channel_name = channel.name - title = plt.title("Stats in #{}".format(channel_name), color="white") + channel_or_guild_name = channel_or_guild.name + title = plt.title("Stats in {}".format(channel_or_guild_name), color="white") title.set_va("top") title.set_ha("center") plt.gca().axis("equal") @@ -92,106 +136,163 @@ class Chatchart(commands.Cog): image_object.seek(0) return image_object + async def fetch_channel_history( + self, + channel: discord.TextChannel, + animation_message: discord.Message, + messages: int + ) -> List[discord.Message]: + """Fetch the history of a channel while displaying an status message with it""" + animation_message_deleted = False + history = [] + history_counter = 0 + async for msg in channel.history(limit=messages): + history.append(msg) + history_counter += 1 + await asyncio.sleep(0.005) + if history_counter % 250 == 0: + new_embed = discord.Embed( + title=f"Fetching messages from #{channel.name}", + description=f"This might take a while...\n{history_counter}/{messages} messages gathered", + colour=await self.bot.get_embed_colour(location=channel), + ) + await channel.trigger_typing() + if animation_message_deleted is False: + try: + await animation_message.edit(embed=new_embed) + except discord.NotFound: + animation_message_deleted = True + return history + @commands.guild_only() @commands.command() - @commands.cooldown(1, 10, commands.BucketType.channel) - @commands.max_concurrency(1, commands.BucketType.channel) + @commands.cooldown(1, 10, commands.BucketType.guild) + @commands.max_concurrency(1, commands.BucketType.guild) @commands.bot_has_permissions(attach_files=True) - async def chatchart(self, ctx, channel: Optional[discord.TextChannel] = None, messages=5000): + async def chatchart(self, ctx, channel: Optional[discord.TextChannel] = None, messages:int = 5000): """ Generates a pie chart, representing the last 5000 messages in the specified channel. """ if channel is None: channel = ctx.channel - deny = await self.config.guild(ctx.guild).channel_deny() - if channel.id in deny: + + # --- Early terminations + if channel.permissions_for(ctx.message.author).read_messages is False: + return await ctx.send("You're not allowed to access that channel.") + if channel.permissions_for(ctx.guild.me).read_messages is False: + return await ctx.send("I cannot read the history of that channel.") + blacklisted_channels = await self.config.guild(ctx.guild).channel_deny() + if channel.id in blacklisted_channels: return await ctx.send(f"I am not allowed to create a chatchart of {channel.mention}.") + if messages < 5: + return await ctx.send("Don't be silly.") message_limit = await self.config.limit() if (message_limit != 0) and (messages > message_limit): messages = message_limit - e = discord.Embed( - description="This might take a while...", colour=await self.bot.get_embed_colour(location=channel) + embed = discord.Embed( + title=f"Fetching messages from #{channel.name}", + description="This might take a while...", + colour=await self.bot.get_embed_colour(location=channel) ) - em = await ctx.send(embed=e) - - history = [] - history_counter = 0 - - if not channel.permissions_for(ctx.message.author).read_messages == True: - try: - await em.delete() - except discord.NotFound: - pass - return await ctx.send("You're not allowed to access that channel.") + loading_message = await ctx.send(embed=embed) try: - async for msg in channel.history(limit=messages): - history.append(msg) - history_counter += 1 - await asyncio.sleep(0.005) - if history_counter % 250 == 0: - new_embed = discord.Embed( - description=f"This might take a while...\n{history_counter}/{messages} messages gathered", - colour=await self.bot.get_embed_colour(location=channel), - ) - if channel.permissions_for(ctx.guild.me).send_messages: - await channel.trigger_typing() - try: - await em.edit(embed=new_embed) - except discord.NotFound: - pass # for cases where the embed was deleted preventing the edit - + history = await self.fetch_channel_history(channel, loading_message, messages) except discord.errors.Forbidden: try: - await em.delete() + await loading_message.delete() except discord.NotFound: pass return await ctx.send("No permissions to read that channel.") - msg_data = {"total count": 0, "users": {}} - for msg in history: - if len(msg.author.display_name) >= 20: - short_name = "{}...".format(msg.author.display_name[:20]).replace("$", "\\$") - else: - short_name = msg.author.display_name.replace("$", "\\$").replace("_", "\\_ ").replace("*", "\\*") - whole_name = "{}#{}".format(short_name, msg.author.discriminator) - if msg.author.bot: - pass - elif whole_name in msg_data["users"]: - msg_data["users"][whole_name]["msgcount"] += 1 - msg_data["total count"] += 1 - else: - msg_data["users"][whole_name] = {} - msg_data["users"][whole_name]["msgcount"] = 1 - msg_data["total count"] += 1 - - if msg_data["users"] == {}: + msg_data = self.calculate_member_perc(history) + # If no members are found. + if len(msg_data["users"]) == 0: try: - await em.delete() + await loading_message.delete() except discord.NotFound: pass return await ctx.send(f"Only bots have sent messages in {channel.mention} or I can't read message history.") - for usr in msg_data["users"]: - pd = float(msg_data["users"][usr]["msgcount"]) / float(msg_data["total count"]) - msg_data["users"][usr]["percent"] = round(pd * 100, 1) - - top_ten = heapq.nlargest( - 20, - [ - (x, msg_data["users"][x][y]) - for x in msg_data["users"] - for y in msg_data["users"][x] - if (y == "percent" and msg_data["users"][x][y] > 0) - ], - key=lambda x: x[1], - ) - others = 100 - sum(x[1] for x in top_ten) - chart = await self.create_chart(top_ten, others, channel) + top_twenty, others = self.calculate_top(msg_data) + chart = await self.create_chart(top_twenty, others, channel) try: - await em.delete() + await loading_message.delete() + except discord.NotFound: + pass + await ctx.send(file=discord.File(chart, "chart.png")) + + @checks.mod_or_permissions(manage_guild=True) + @commands.guild_only() + @commands.command(aliases=["guildchart"]) + @commands.cooldown(1, 30, commands.BucketType.guild) + @commands.max_concurrency(1, commands.BucketType.guild) + @commands.bot_has_permissions(attach_files=True) + async def serverchart(self, ctx: commands.Context, messages: int = 1000): + """ + Generates a pie chart, representing the last 1000 messages from every allowed channel in the server. + + As example: + For each channel that the bot is allowed to scan. It will take the last 1000 messages from each channel. + And proceed to build a chart out of that. + """ + if messages < 5: + return await ctx.send("Don't be silly.") + channel_list = [] + blacklisted_channels = await self.config.guild(ctx.guild).channel_deny() + for channel in ctx.guild.text_channels: + channel: discord.TextChannel + if channel.id in blacklisted_channels: + continue + if channel.permissions_for(ctx.message.author).read_messages is False: + continue + if channel.permissions_for(ctx.guild.me).read_messages is False: + continue + channel_list.append(channel) + + if len(channel_list) == 0: + return await ctx.send("There are no channels to read... This should theoretically never happen.") + + embed = discord.Embed( + description="Fetching messages from the entire server this **will** take a while.", + colour=await self.bot.get_embed_colour(location=ctx.channel), + ) + global_fetch_message = await ctx.send(embed=embed) + global_history = [] + + for channel in channel_list: + embed = discord.Embed( + title=f"Fetching messages from #{channel.name}", + description="This might take a while...", + colour=await self.bot.get_embed_colour(location=channel) + ) + loading_message = await ctx.send(embed=embed) + try: + history = await self.fetch_channel_history(channel, loading_message, messages) + global_history += history + await loading_message.delete() + except discord.errors.Forbidden: + try: + await loading_message.delete() + except discord.NotFound: + continue + + msg_data = self.calculate_member_perc(global_history) + # If no members are found. + if len(msg_data["users"]) == 0: + try: + await global_fetch_message.delete() + except discord.NotFound: + pass + return await ctx.send(f"Only bots have sent messages in this server... Wauw...") + + top_twenty, others = self.calculate_top(msg_data) + chart = await self.create_chart(top_twenty, others, ctx.guild) + + try: + await global_fetch_message.delete() except discord.NotFound: pass await ctx.send(file=discord.File(chart, "chart.png"))