owlrandomshitbot/tgbot/shit/prompting.py
2025-03-03 04:10:37 +07:00

214 lines
7.7 KiB
Python

import argparse
import asyncio
import json
import random
import shlex
import threading
import aiohttp
import telegram
import websockets
from telegram import Update
from telegram.ext import ContextTypes
from tgbot.config import TOKEN, IMAGE_GENERATOR_API_HOST, IMAGE_GENERATOR_TOKEN, CDN_URL
from tgbot.shit.custom_argparser import CustomArgParser
from tgbot.shit.handlers import escape_markdown
def filename2cdn(filename: str) -> str | None:
if not filename:
return None
return CDN_URL + filename
# Pending txt2img jobs: prompt id -> Telegram chat id that requested it.
# Filled by gen_image; looked up and removed by ws_task, which runs in its own thread.
scheduled_prompts: dict = dict()
async def _deliver_ws_message(bot: telegram.Bot, message) -> None:
    """Deliver one finished-job WebSocket *message* to the chat that requested it.

    The message is decoded as JSON; its ``"id"`` is looked up in ``scheduled_prompts``
    and its ``"image_filename"`` is turned into a CDN URL via ``filename2cdn``.
    """
    message_dict = json.loads(message)
    prompt_id = message_dict["id"]
    # pop() rather than get()+del: the entry is removed even if sending below fails,
    # so failed deliveries do not pile up in scheduled_prompts.
    chat_id = scheduled_prompts.pop(prompt_id, None)
    if chat_id is None:
        # Not a prompt registered by this process (e.g. enqueued before a restart, or
        # the result arrived before gen_image stored the id) - nobody to send it to.
        print(f"No chat registered for prompt {prompt_id}, skipping")
        return
    image_url = filename2cdn(message_dict.get("image_filename"))
    if image_url is None:
        # No image to attach; still tell the user what the generator reported.
        await bot.send_message(chat_id=chat_id, text=str(message))
        return
    await bot.send_photo(chat_id=chat_id, photo=image_url, caption=str(message))


async def ws_task():
    """Listen forever on the image generator's WebSocket and deliver finished images.

    Reconnects after any connection failure, waiting 1s at first and doubling the
    delay up to 60s; the delay resets once a connection succeeds. A failure while
    handling a single message is logged and does not drop the connection.
    """
    ws_url = f"wss://{IMAGE_GENERATOR_API_HOST}/ws?token={IMAGE_GENERATOR_TOKEN}"
    retry_delay = 1  # Initial delay (1s), will increase with backoff
    while True:
        try:
            print("Connecting to WebSocket...")
            async with websockets.connect(ws_url) as ws:
                retry_delay = 1  # Reset delay on successful connection
                # One Bot per connection instead of one per message: entering the Bot
                # context sets up its network resources, so reuse it across messages.
                bot = telegram.Bot(TOKEN)
                async with bot:
                    while True:
                        message = await ws.recv()
                        print(f"Received message: {message}")
                        try:
                            await _deliver_ws_message(bot, message)
                        except Exception as e:
                            # Bad JSON, unknown keys, Telegram API errors, ...: log and
                            # keep listening instead of tearing down the connection.
                            print(f"Failed to deliver message {message}: {e}")
        except websockets.ConnectionClosed as e:
            print(f"WebSocket connection closed: {e}. Reconnecting in {retry_delay}s...")
        except Exception as e:
            print(f"WebSocket error: {e}. Reconnecting in {retry_delay}s...")
        await asyncio.sleep(retry_delay)
        retry_delay = min(retry_delay * 2, 60)  # Exponential backoff, max 60s
def start_ws_task():
    """Thread entry point that runs the WebSocket listener on its own event loop.

    Meant to be used as a ``threading.Thread`` target so that ``ws_task`` never blocks
    the event loop used by FastAPI / the Telegram bot.
    """
    print("starting the websocket task")
    # asyncio.run creates a fresh loop for this thread, installs it as the current
    # loop, and closes it (finalizing async generators) if ws_task ever exits.
    asyncio.run(ws_task())
# Start the WebSocket listener as a side effect of importing this module.
# daemon=True so the listener thread never keeps the process alive at exit.
thread = threading.Thread(daemon=True, target=start_ws_task)
thread.start()
async def enqueue_prompt_t2i(args: dict):
    """Submit a txt2img job to the image generator API and return its decoded JSON reply.

    :param args: JSON request body, e.g. the dict built by ``prompting_parse_args``.
    :return: the parsed JSON response (``gen_image`` reads ``"error"`` and ``"data"``).
    :raises aiohttp.ClientError: on connection problems, or ``ContentTypeError`` when the
        reply is not served as JSON.
    """
    url = f"https://{IMAGE_GENERATOR_API_HOST}/prompt/generic/txt2img/"
    headers = {"Authorization": "Bearer " + IMAGE_GENERATOR_TOKEN}
    async with aiohttp.ClientSession() as session:
        # Using the response as a context manager releases the connection
        # deterministically, even if decoding the body raises.
        async with session.post(url=url, headers=headers, json=args) as response:
            return await response.json()
def sdxl_image_dimension(value):
    """argparse ``type=`` converter for an image width/height.

    Accepts a positive integer that is a multiple of 8 and returns it as ``int``.

    :raises argparse.ArgumentTypeError: if *value* is not an integer, is zero or
        negative, or is not divisible by 8.
    """
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"{value} is not an integer") from None
    # 0 and negative multiples of 8 would otherwise pass the divisibility check.
    if number <= 0:
        raise argparse.ArgumentTypeError(f"{number} is not a positive integer")
    if number % 8 != 0:
        raise argparse.ArgumentTypeError(f"{number} is not divisible by 8")
    return number
def denoise(value):
    """argparse ``type=`` converter for a denoise level in the closed range [0, 1].

    :raises argparse.ArgumentTypeError: if *value* is not a float or lies outside
        [0, 1] (NaN included).
    """
    try:
        level = float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"{value} is not a float") from None
    # Chained comparison so that NaN, for which every comparison is False,
    # is rejected as well.
    if not 0.0 <= level <= 1.0:
        raise argparse.ArgumentTypeError(f"{level} is not a valid denoise level")
    return level
def create_prompting_parser():
    """Build the argument parser for the ``prompting`` command.

    The returned parser yields the txt2img request fields (checkpoint, prompts,
    size, seed, steps, cfg scale, sampler, scheduler, denoise).
    """
    parser = CustomArgParser(
        description="prompting - a command for interacting with image generators",
        usage="""prompting --checkpoint=Illustrious-XL-v1.0.safetensors --prompt='fox lady' --negative_prompt='bad
and retarded' --width=1024 --height=1024 --seed=123456 --steps=25 --cfg_scale=7.5 --sampler_name='euler'
--scheduler=normal --denoise=1.0""",
    )
    parser.add_argument("--checkpoint", type=str,
                        choices=["Illustrious-XL-v1.0.safetensors", "PVCStyleModelMovable_illustriousxl10.safetensors"],
                        help="checkpoint to use", required=True)
    parser.add_argument("--prompt", type=str,
                        help="positive prompt, e.g. 'fox lady'", required=True)
    parser.add_argument("--negative_prompt", type=str,
                        help="negative prompt, e.g. 'bad, huge dick'. Could be none", default="", required=False)
    parser.add_argument("--width", type=sdxl_image_dimension,
                        help="width, default=1024", default=1024)
    parser.add_argument("--height", type=sdxl_image_dimension,
                        help="height, default=1024", default=1024)
    # The default is drawn once per call of this function, so every freshly built
    # parser gets its own random seed.
    parser.add_argument("--seed", type=int,
                        help="seed for noise generation, default is random", default=random.randint(0, 99999999))
    parser.add_argument("--steps", type=int,
                        help="denoise steps, default=20", default=20)
    parser.add_argument("--cfg_scale", "-cfg", type=float,
                        help="classifier-free guidance scale, default=8", default=8.0)
    parser.add_argument("--sampler_name", type=str,
                        choices=['euler', 'euler_cfg_pp', 'euler_ancestral', 'euler_ancestral_cfg_pp', 'heun',
                                 'heunpp2', 'dpm_2', 'dpm_2_ancestral', 'lms', 'dpm_fast', 'dpm_adaptive',
                                 'dpmpp_2s_ancestral', 'dpmpp_2s_ancestral_cfg_pp', 'dpmpp_sde', 'dpmpp_sde_gpu',
                                 'dpmpp_2m', 'dpmpp_2m_cfg_pp', 'dpmpp_2m_sde', 'dpmpp_2m_sde_gpu', 'dpmpp_3m_sde',
                                 'dpmpp_3m_sde_gpu', 'ddpm', 'lcm', 'ipndm', 'ipndm_v', 'deis', 'ddim', 'uni_pc',
                                 'uni_pc_bh2'],
                        help="sampler name, default is euler", default="euler")
    parser.add_argument("--scheduler", type=str,
                        choices=['normal', 'karras', 'exponential', 'sgm_uniform', 'simple', 'ddim_uniform', 'beta'],
                        help="noise scheduler, default is normal", default="normal")
    parser.add_argument("--denoise", type=denoise,
                        help="denoise level between 0 and 1, default=1.0", default=1.0)
    return parser
def prompting_parse_args(args: str, parser: CustomArgParser):
    """Split *args* the way a POSIX shell would and parse the tokens with *parser*.

    Returns the parsed options as a plain ``dict`` (the namespace's attribute dict).
    Errors from ``shlex.split`` (unbalanced quotes) and from *parser* propagate.
    """
    tokens = shlex.split(args)
    return vars(parser.parse_args(tokens))
async def _reply_code_block(update: Update, context: ContextTypes.DEFAULT_TYPE, text: str):
    """Reply in the current chat with *text* shown as a MarkdownV2 code block."""
    return await context.bot.send_message(chat_id=update.effective_chat.id,
                                          text="```\n" + escape_markdown(text) + "\n```",
                                          parse_mode="MarkdownV2")


async def gen_image(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Handle a ``prompting`` command: parse its options and enqueue a txt2img job.

    On a parse error, the parser's help text (if it set one) or the error text is sent
    back as a code block. On success, the returned prompt id is recorded in
    ``scheduled_prompts`` so the WebSocket listener can deliver the image to this chat.
    """
    chat_id = update.effective_chat.id
    query = (
        update.message.text
        # Drop only the command word, not occurrences inside the prompt text.
        .replace("prompting", "", 1)
        # Some clients presumably auto-convert "--" into an em dash (U+2014); undo that.
        .replace("\u2014", "--")
        # Newlines separate arguments like spaces do; deleting them would glue
        # "--width=1024\n--height=1024" into a single token.
        .replace("\n", " ")
    )
    parser = create_prompting_parser()
    try:
        args = prompting_parse_args(query, parser)
    except (RuntimeError, ValueError) as e:
        # RuntimeError: parse failure (presumably raised by CustomArgParser instead of
        # exiting - confirm); ValueError: shlex.split on unbalanced quotes.
        help_message = getattr(parser, 'help_message', None)
        return await _reply_code_block(update, context, help_message if help_message else str(e))

    try:
        result = await enqueue_prompt_t2i(args)
    except (aiohttp.ClientError, asyncio.TimeoutError, ValueError) as e:
        # Network failure or an undecodable reply: report it instead of letting the
        # handler crash without answering the user.
        print(f"Failed to enqueue prompt: {e}")
        result = None

    if isinstance(result, dict) and result.get("error") == 0:
        prompt_id = result['data']
        # The WebSocket listener (another thread) matches finished jobs against this id.
        scheduled_prompts[prompt_id] = chat_id
        print(scheduled_prompts)
        return await context.bot.send_message(chat_id=chat_id, text=f"prompt_id: {prompt_id}")
    return await context.bot.send_message(chat_id=chat_id,
                                          text="Some weird incomprehensible error occurred :<")