Skip to content

Commit

Permalink
Format Code with Black
Browse files Browse the repository at this point in the history
  • Loading branch information
Zingzy committed Apr 13, 2024
1 parent 9a8000b commit c3e5dbe
Show file tree
Hide file tree
Showing 5 changed files with 377 additions and 78 deletions.
205 changes: 172 additions & 33 deletions cogs/imagine_cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from utils import *
from constants import *


class ImagineButtonView(discord.ui.View):
def __init__(self, link: str = None):
super().__init__(timeout=None)
Expand All @@ -14,15 +15,35 @@ def __init__(self, link: str = None):
if link is not None:
self.add_item(discord.ui.Button(label="Link", url=self.link))

@discord.ui.button(style=discord.ButtonStyle.secondary, custom_id="regenerate-button", emoji="<:redo:1187101382101180456>")
async def regenerate(self, interaction: discord.Interaction, button: discord.ui.Button):
@discord.ui.button(
style=discord.ButtonStyle.secondary,
custom_id="regenerate-button",
emoji="<:redo:1187101382101180456>",
)
async def regenerate(
self, interaction: discord.Interaction, button: discord.ui.Button
):
message_id = interaction.message.id
await interaction.response.send_message(embed=discord.Embed(title="Regenerating Your Image", description="Please wait while we generate your image", color=discord.Color.blurple()), ephemeral=True)
await interaction.response.send_message(
embed=discord.Embed(
title="Regenerating Your Image",
description="Please wait while we generate your image",
color=discord.Color.blurple(),
),
ephemeral=True,
)

message_data = get_prompt_data(message_id)

if not message_data:
await interaction.followup.send(embed=discord.Embed(title="Error", description="Message not found", color=discord.Color.red()), ephemeral=True)
await interaction.followup.send(
embed=discord.Embed(
title="Error",
description="Message not found",
color=discord.Color.red(),
),
ephemeral=True,
)
return

start = datetime.datetime.now()
Expand All @@ -37,10 +58,19 @@ async def regenerate(self, interaction: discord.Interaction, button: discord.ui.
enhance = message_data["enhance"]

try:
dic, image, is_nsfw = await generate_image(prompt, width, height, model, negative, cached, nologo, enhance)
dic, image, is_nsfw = await generate_image(
prompt, width, height, model, negative, cached, nologo, enhance
)
except Exception as e:
print(e)
await interaction.followup.send(embed=discord.Embed(title="Error", description=f"Error generating image : {e}", color=discord.Color.red()), ephemeral=True)
await interaction.followup.send(
embed=discord.Embed(
title="Error",
description=f"Error generating image : {e}",
color=discord.Color.red(),
),
ephemeral=True,
)
return

image_file = discord.File(image, filename="image.png")
Expand All @@ -52,7 +82,9 @@ async def regenerate(self, interaction: discord.Interaction, button: discord.ui.

context = f"## {prompt} - {interaction.user.mention}\n### Model - `{model}` | Time Taken - `{round(time_taken.total_seconds(), 2)} s`\n### Width - `{width} px` | Height - `{height} px`\n### Enchance - `{enhance}`"

response = await interaction.channel.send(context, file=image_file, view=ImagineButtonView(link=dic["bookmark_url"]))
response = await interaction.channel.send(
context, file=image_file, view=ImagineButtonView(link=dic["bookmark_url"])
)

dic["_id"] = response.id
dic["channel_id"] = interaction.channel.id
Expand All @@ -78,7 +110,12 @@ async def regenerate(self, interaction: discord.Interaction, button: discord.ui.
update_user_data(interaction.user.id, user_data)
save_prompt_data(message_id, dic)

@discord.ui.button(label="0", style=discord.ButtonStyle.secondary, custom_id="like-button", emoji="<:like:1187101385230143580>")
@discord.ui.button(
label="0",
style=discord.ButtonStyle.secondary,
custom_id="like-button",
emoji="<:like:1187101385230143580>",
)
async def like(self, interaction: discord.Interaction, button: discord.ui.Button):
try:
id = interaction.message.id
Expand Down Expand Up @@ -118,25 +155,49 @@ async def like(self, interaction: discord.Interaction, button: discord.ui.Button
return
except Exception as e:
print(e)
interaction.response.send_message(embed=discord.Embed(title="Error Liking the Image", description=f"{e}", color=discord.Color.red()), ephemeral=True)

interaction.response.send_message(
embed=discord.Embed(
title="Error Liking the Image",
description=f"{e}",
color=discord.Color.red(),
),
ephemeral=True,
)

@discord.ui.button(label = "0", style=discord.ButtonStyle.secondary, custom_id="bookmark-button", emoji="<:save:1187101389822902344>")
async def bookmark(self, interaction: discord.Interaction, button: discord.ui.Button):
@discord.ui.button(
label="0",
style=discord.ButtonStyle.secondary,
custom_id="bookmark-button",
emoji="<:save:1187101389822902344>",
)
async def bookmark(
self, interaction: discord.Interaction, button: discord.ui.Button
):
try:
id = interaction.message.id
message_data = get_prompt_data(id)
bookmarks = message_data["bookmarks"]

if interaction.user.id in bookmarks:
await interaction.response.send_message(embed=discord.Embed(title="Error", description="You have already bookmarked this image", color=discord.Color.red()), ephemeral=True)
await interaction.response.send_message(
embed=discord.Embed(
title="Error",
description="You have already bookmarked this image",
color=discord.Color.red(),
),
ephemeral=True,
)
else:
bookmarks.append(interaction.user.id)
update_prompt_data(id, {"bookmarks": bookmarks})
button.label = f"{len(bookmarks)}"
await interaction.response.edit_message(view=self)

embed = discord.Embed(title=f"Prompt : {message_data['prompt']}", description=f"url : {message_data['bookmark_url']}", color=discord.Color.og_blurple())
embed = discord.Embed(
title=f"Prompt : {message_data['prompt']}",
description=f"url : {message_data['bookmark_url']}",
color=discord.Color.og_blurple(),
)
embed.set_image(url=message_data["bookmark_url"])

await interaction.user.send(embed=embed)
Expand All @@ -159,9 +220,20 @@ async def bookmark(self, interaction: discord.Interaction, button: discord.ui.Bu

except Exception as e:
print(e)
await interaction.response.send_message(embed=discord.Embed(title="Error Bookmarking the Image", description=f"{e}", color=discord.Color.red()), ephemeral=True)
await interaction.response.send_message(
embed=discord.Embed(
title="Error Bookmarking the Image",
description=f"{e}",
color=discord.Color.red(),
),
ephemeral=True,
)

@discord.ui.button(style=discord.ButtonStyle.red, custom_id="delete-button", emoji="<:delete:1187102382312652800>")
@discord.ui.button(
style=discord.ButtonStyle.red,
custom_id="delete-button",
emoji="<:delete:1187102382312652800>",
)
async def delete(self, interaction: discord.Interaction, button: discord.ui.Button):
try:
data = get_prompt_data(interaction.message.id)
Expand All @@ -174,7 +246,14 @@ async def delete(self, interaction: discord.Interaction, button: discord.ui.Butt
pass

if interaction.user.id != author_id:
await interaction.response.send_message(embed=discord.Embed(title="Error", description="You can only delete your own images", color=discord.Color.red()), ephemeral=True)
await interaction.response.send_message(
embed=discord.Embed(
title="Error",
description="You can only delete your own images",
color=discord.Color.red(),
),
ephemeral=True,
)
return

delete_prompt_data(interaction.message.id)
Expand Down Expand Up @@ -205,7 +284,15 @@ async def delete(self, interaction: discord.Interaction, button: discord.ui.Butt

except Exception as e:
print(e)
await interaction.response.send_message(embed=discord.Embed(title="Error Deleting the Image", description=f"{e}", color=discord.Color.red()), ephemeral=True)
await interaction.response.send_message(
embed=discord.Embed(
title="Error Deleting the Image",
description=f"{e}",
color=discord.Color.red(),
),
ephemeral=True,
)


class Imagine(commands.Cog):
def __init__(self, bot):
Expand All @@ -217,22 +304,61 @@ async def cog_load(self):

@app_commands.command(name="pollinate", description="Generate AI Images")
@app_commands.choices(
model=[
app_commands.Choice(name=choice, value=choice) for choice in MODELS
],
)
model=[app_commands.Choice(name=choice, value=choice) for choice in MODELS],
)
@app_commands.guild_only()
@app_commands.checks.cooldown(1, 15)
@app_commands.describe(prompt="Imagine a prompt", model=f"The AI model to use for generating the image. Default is {MODELS[0]}", height="Height of the image", width="Width of the image", negative="The things not to include in the image", cached="Removes the image seed", nologo="Remove the logo", enhance="Disables Prompt enhancing if set to False", private="Only you can see the generated Image if set to True")
async def imagine_command(self, interaction, prompt:str, model: app_commands.Choice[str] = MODELS[0], width:int = 1000, height:int = 1000, negative:str|None = None, cached:bool = False, nologo:bool = False, enhance:bool = True, private:bool = False):
await interaction.response.send_message(embed=discord.Embed(title="Generating Image", description="Please wait while we generate your image", color=discord.Color.blurple()), ephemeral=True)
@app_commands.describe(
prompt="Imagine a prompt",
model=f"The AI model to use for generating the image. Default is {MODELS[0]}",
height="Height of the image",
width="Width of the image",
negative="The things not to include in the image",
cached="Removes the image seed",
nologo="Remove the logo",
enhance="Disables Prompt enhancing if set to False",
private="Only you can see the generated Image if set to True",
)
async def imagine_command(
self,
interaction,
prompt: str,
model: app_commands.Choice[str] = MODELS[0],
width: int = 1000,
height: int = 1000,
negative: str | None = None,
cached: bool = False,
nologo: bool = False,
enhance: bool = True,
private: bool = False,
):
await interaction.response.send_message(
embed=discord.Embed(
title="Generating Image",
description="Please wait while we generate your image",
color=discord.Color.blurple(),
),
ephemeral=True,
)

if len(prompt) > 1500:
await interaction.channel.send(embed=discord.Embed(title="Error", description="Prompt must be less than 1500 characters", color=discord.Color.red()))
await interaction.channel.send(
embed=discord.Embed(
title="Error",
description="Prompt must be less than 1500 characters",
color=discord.Color.red(),
)
)
return

if width < 16 or height < 16:
await interaction.channel.send(embed=discord.Embed(title="Error", description="Width and Height must be greater than 16", color=discord.Color.red()))
await interaction.channel.send(
embed=discord.Embed(
title="Error",
description="Width and Height must be greater than 16",
color=discord.Color.red(),
)
)
return

try:
Expand All @@ -243,11 +369,20 @@ async def imagine_command(self, interaction, prompt:str, model: app_commands.Cho
start = datetime.datetime.now()

try:
dic, image, is_nsfw = await generate_image(prompt, width, height, model, negative, cached, nologo, enhance, private)
dic, image, is_nsfw = await generate_image(
prompt, width, height, model, negative, cached, nologo, enhance, private
)
except Exception as e:
print(e)
print("Error in Imagine Cog")
await interaction.followup.send(embed=discord.Embed(title="Error", description=f"Error generating image : {e}", color=discord.Color.red()), ephemeral=True)
await interaction.followup.send(
embed=discord.Embed(
title="Error",
description=f"Error generating image : {e}",
color=discord.Color.red(),
),
ephemeral=True,
)
return

image_file = discord.File(image, filename="image.png")
Expand All @@ -262,10 +397,14 @@ async def imagine_command(self, interaction, prompt:str, model: app_commands.Cho
context = f"## {prompt} - {interaction.user.mention}\n### Model - `{model}` | Time Taken - `{round(time_taken.total_seconds(), 2)} s`\n### Width - `{width} px` | Height - `{height} px`\n### Enchance - `{enhance}`"

if private:
response = await interaction.followup.send(context, file=image_file, ephemeral=True)
response = await interaction.followup.send(
context, file=image_file, ephemeral=True
)
return
else:
response = await interaction.channel.send(context, file=image_file, view=view)
response = await interaction.channel.send(
context, file=image_file, view=view
)

message_id = response.id
dic["_id"] = message_id
Expand Down Expand Up @@ -293,7 +432,6 @@ async def imagine_command(self, interaction, prompt:str, model: app_commands.Cho
update_user_data(interaction.user.id, user_data)
save_prompt_data(message_id, dic)


@imagine_command.error
async def imagine_command_error(
self, interaction: discord.Interaction, error: app_commands.AppCommandError
Expand All @@ -318,6 +456,7 @@ async def imagine_command_error(

await interaction.response.send_message(embed=embed, ephemeral=True)


async def setup(bot):
await bot.add_cog(Imagine(bot))
print("Imagine cog loaded")
print("Imagine cog loaded")
Loading

0 comments on commit c3e5dbe

Please sign in to comment.