Files
aikaterna-cogs/dalle/dalle.py
2022-06-16 12:17:12 -07:00

110 lines
4.0 KiB
Python

import base64
import io
import json
from typing import List, Optional

import aiohttp
import discord
from discord.http import Route
from redbot.core import commands
class DallE(commands.Cog):
    """Dall-E mini image generation"""

    # Upper bound on generation requests made for a single command invocation.
    _MAX_ATTEMPTS = 100
    # Only the first few returned images are posted, one embed per image.
    _MAX_IMAGES = 4

    def __init__(self, bot):
        self.bot = bot

    async def red_delete_data_for_user(self, **kwargs):
        """Nothing to delete."""
        return

    @commands.max_concurrency(3, commands.BucketType.default)
    @commands.command()
    @commands.guild_only()
    async def generate(self, ctx: commands.Context, *, prompt: str):
        """
        Generate images through Dall-E mini.
        https://huggingface.co/spaces/dalle-mini/dalle-mini
        """
        # NOTE: the docstring above is presumably what the bot shows as this
        # command's help text, so implementation notes live in comments.
        #
        # Flow: check permissions, poll the generation endpoint up to
        # _MAX_ATTEMPTS times while keeping a status message updated, then
        # post up to _MAX_IMAGES results as embeds in one message.
        if not ctx.channel.permissions_for(ctx.guild.me).embed_links:
            return await ctx.send("I need the `Embed Links` permission here before you can use this command.")

        status_msg = await ctx.send("Image generator starting up, please be patient. This will take a very long time.")

        images = None
        async with ctx.typing():
            # Bounded retry: a plain `while not images` loop never terminates
            # once the attempt budget is used up without a successful response.
            for attempt in range(1, self._MAX_ATTEMPTS + 1):
                if attempt % 3 == 0:
                    # Refresh the progress counter on every third attempt.
                    status = (
                        "This will take a very long time. Once a response is acquired, "
                        "this counter will pause while processing.\n"
                        f"[attempt `{attempt}/{self._MAX_ATTEMPTS}`]"
                    )
                    try:
                        await status_msg.edit(content=status)
                    except discord.NotFound:
                        # The status message was deleted; post a fresh one.
                        status_msg = await ctx.send(status)

                images = await self.generate_images(prompt)
                if images:
                    break

        if not images:
            # Every attempt failed or the service returned no images.
            return await ctx.send(f"I didn't find anything for `{prompt}`.")

        file_images = [
            discord.File(image, filename=f"{index}.png")
            for index, image in enumerate(images[: self._MAX_IMAGES])
        ]

        base_embed = discord.Embed(
            colour=await ctx.embed_color(),
            title="Dall-E Mini results",
            url="https://huggingface.co/spaces/dalle-mini/dalle-mini",
        )
        footer_text = (
            f"Results for: {prompt}, requested by {ctx.author}\n"
            "View this output on a desktop client for best results."
        )

        # One embed per image, all copied from the same base (same title/url).
        # NOTE(review): presumably the shared url lets desktop clients group
        # the embeds into a single gallery -- confirm.
        embeds = []
        for file in file_images:
            em = base_embed.copy()
            em.set_image(url=f"attachment://{file.filename}")
            em.set_footer(text=footer_text)
            embeds.append(em)

        # Multipart form: the JSON payload first, then one part per attachment.
        payload = {"embeds": [e.to_dict() for e in embeds]}
        form = [{"name": "payload_json", "value": discord.utils.to_json(payload)}]
        single_file = len(file_images) == 1
        for index, file in enumerate(file_images):
            form.append(
                {
                    # A lone attachment is named "file"; several are "file0", "file1", ...
                    "name": "file" if single_file else f"file{index}",
                    "value": file.fp,
                    "filename": file.filename,
                    "content_type": "application/octet-stream",
                }
            )

        try:
            await status_msg.delete()
        except discord.NotFound:
            # Already gone; nothing left to clean up.
            pass

        # The message is posted through the library's HTTP client directly.
        # NOTE(review): presumably because the targeted discord.py release
        # cannot send several embeds through ctx.send() -- confirm before
        # replacing this private-API call.
        route = Route("POST", "/channels/{channel_id}/messages", channel_id=ctx.channel.id)
        try:
            await ctx.guild._state.http.request(route, form=form, files=file_images)
        finally:
            for file in file_images:
                file.close()

    @staticmethod
    async def generate_images(prompt: str) -> Optional[List[io.BytesIO]]:
        """
        Request one batch of images for ``prompt`` from the generation endpoint.

        Returns a list of in-memory buffers, one per base64-encoded entry in
        the response's ``"images"`` field, or ``None`` when the endpoint
        answers with any status other than 200.
        """
        async with aiohttp.ClientSession() as session:
            async with session.post("https://bf.dallemini.ai/generate", json={"prompt": prompt}) as response:
                if response.status != 200:
                    return None
                response_data = await response.json()

        return [io.BytesIO(base64.decodebytes(image.encode("utf-8"))) for image in response_data["images"]]