"""Telegram handlers that submit txt2img prompts to the image-generator API.

Prompts are enqueued over HTTP (``enqueue_prompt_t2i``). Results are received
on the generator's WebSocket by ``ws_task``, which runs in a background thread
and forwards each image to the chat recorded in ``scheduled_prompts``.
"""
import argparse
import asyncio
import json
import random
import shlex
import threading

import aiohttp
import telegram
import websockets
from telegram import Update
from telegram.ext import ContextTypes

from tgbot.config import TOKEN, IMAGE_GENERATOR_API_HOST, IMAGE_GENERATOR_TOKEN, CDN_URL
from tgbot.shit.custom_argparser import CustomArgParser
from tgbot.shit.handlers import escape_markdown

# Telegram Bot API limits: photo captions are at most 1024 characters,
# text messages at most 4096. Longer values make the API call fail.
_TELEGRAM_CAPTION_LIMIT = 1024
_TELEGRAM_TEXT_LIMIT = 4096

_CHECKPOINTS = [
    "Illustrious-XL-v1.0.safetensors",
    "PVCStyleModelMovable_illustriousxl10.safetensors",
]
_SAMPLER_NAMES = [
    'euler', 'euler_cfg_pp', 'euler_ancestral', 'euler_ancestral_cfg_pp', 'heun', 'heunpp2',
    'dpm_2', 'dpm_2_ancestral', 'lms', 'dpm_fast', 'dpm_adaptive', 'dpmpp_2s_ancestral',
    'dpmpp_2s_ancestral_cfg_pp', 'dpmpp_sde', 'dpmpp_sde_gpu', 'dpmpp_2m', 'dpmpp_2m_cfg_pp',
    'dpmpp_2m_sde', 'dpmpp_2m_sde_gpu', 'dpmpp_3m_sde', 'dpmpp_3m_sde_gpu', 'ddpm', 'lcm',
    'ipndm', 'ipndm_v', 'deis', 'ddim', 'uni_pc', 'uni_pc_bh2',
]
_SCHEDULERS = ['normal', 'karras', 'exponential', 'sgm_uniform', 'simple', 'ddim_uniform', 'beta']

_GENERIC_ERROR_TEXT = "Some weird incomprehensible error occurred :<"


def filename2cdn(filename: str | None) -> str | None:
    """Return the CDN URL for *filename*, or None when it is empty or None."""
    if not filename:
        return None
    return CDN_URL + filename


# prompt id (the "data" field returned by the txt2img endpoint) -> chat id that
# requested it. Filled by gen_image, consumed by ws_task in the listener thread.
scheduled_prompts: dict = {}


async def _deliver_result(message) -> None:
    """Parse one WebSocket notification and send its image to the waiting chat.

    Raises whatever json/Telegram raise; ws_task logs those per message so a
    single bad notification does not drop the WebSocket connection.
    """
    message_dict = json.loads(message)
    prompt_id = message_dict["id"]
    # pop() rather than del: unknown or duplicate ids must not raise KeyError.
    # NOTE(review): if a result arrives before gen_image has registered the id,
    # it is dropped here — confirm the API can't finish that fast.
    chat_id = scheduled_prompts.pop(prompt_id, None)
    if chat_id is None:
        print(f"No chat is waiting for prompt {prompt_id!r}; dropping message")
        return

    image_url = filename2cdn(message_dict.get("image_filename"))
    bot = telegram.Bot(TOKEN)
    async with bot:
        if image_url is None:
            # No image to send: still tell the user what the generator reported.
            await bot.send_message(chat_id=chat_id, text=str(message)[:_TELEGRAM_TEXT_LIMIT])
        else:
            await bot.send_photo(
                chat_id=chat_id,
                photo=image_url,
                caption=str(message)[:_TELEGRAM_CAPTION_LIMIT],
            )


async def ws_task():
    """Listen on the generator's WebSocket forever, reconnecting with backoff."""
    ws_url = f"wss://{IMAGE_GENERATOR_API_HOST}/system/ws?token={IMAGE_GENERATOR_TOKEN}"
    retry_delay = 1  # Initial delay (1s), will increase with backoff

    while True:
        try:
            print("Connecting to WebSocket...")
            async with websockets.connect(ws_url) as ws:
                retry_delay = 1  # Reset delay on successful connection
                while True:
                    # recv() stays outside the per-message try so that
                    # ConnectionClosed still reaches the reconnect handler.
                    message = await ws.recv()
                    print(f"Received message: {message}")
                    try:
                        await _deliver_result(message)
                    except Exception as e:
                        print(f"Failed to handle message {message!r}: {e}")
        except websockets.ConnectionClosed as e:
            print(f"WebSocket connection closed: {e}. Reconnecting in {retry_delay}s...")
        except Exception as e:
            print(f"WebSocket error: {e}. Reconnecting in {retry_delay}s...")

        await asyncio.sleep(retry_delay)
        retry_delay = min(retry_delay * 2, 60)  # Exponential backoff, max 60s


def start_ws_task():
    """Run WebSocket listener in a separate thread to avoid blocking FastAPI & Telegram bot."""
    print("starting the websocket task")
    # The thread gets its own event loop; ws_task never returns.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(ws_task())


thread = threading.Thread(target=start_ws_task, daemon=True)
thread.start()


async def enqueue_prompt_t2i(args: dict):
    """POST *args* to the txt2img endpoint and return the decoded JSON reply.

    Raises aiohttp.ClientError (including ContentTypeError for non-JSON
    replies) on transport/response problems.
    """
    url = f"https://{IMAGE_GENERATOR_API_HOST}/prompting/txt2img/"
    headers = {"Authorization": "Bearer " + IMAGE_GENERATOR_TOKEN}
    async with aiohttp.ClientSession() as session:
        # Use the response as a context manager so its connection is released.
        async with session.post(url=url, headers=headers, json=args) as response:
            return await response.json()


def sdxl_image_dimension(value):
    """argparse type: a positive integer divisible by 8."""
    try:
        value = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"{value} is not an integer")
    if value <= 0:
        raise argparse.ArgumentTypeError(f"{value} is not a positive integer")
    if value % 8 != 0:
        raise argparse.ArgumentTypeError(f"{value} is not divisible by 8")
    return value


def denoise(value):
    """argparse type: a float in the closed range [0, 1]."""
    try:
        value = float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"{value} is not a float")
    # Written as a range check so NaN (for which every comparison is False)
    # is rejected too.
    if not 0.0 <= value <= 1.0:
        raise argparse.ArgumentTypeError(f"{value} is not a valid denoise level")
    return value


def create_prompting_parser():
    """Build a fresh parser for the `prompting` command.

    A new parser is built per call, so the random --seed default is
    re-drawn for every request.
    """
    parser = CustomArgParser(
        description="prompting - a command for interacting with image generators",
        usage="""prompting --checkpoint=Illustrious-XL-v1.0.safetensors --prompt='fox lady' --negative_prompt='bad and retarded' --width=1024 --height=1024 --seed=123456 --steps=25 --cfg_scale=7.5 --sampler_name='euler' --scheduler=normal --denoise=1.0""",
    )
    parser.add_argument("--checkpoint", type=str, choices=_CHECKPOINTS,
                        help="checkpoint to use", required=True)
    parser.add_argument("--prompt", type=str, help="positive prompt, e.g. 'fox lady'", required=True)
    parser.add_argument("--negative_prompt", type=str,
                        help="negative prompt, e.g. 'bad, huge dick'. Could be none",
                        default="", required=False)
    parser.add_argument("--width", type=sdxl_image_dimension, help="width, default=1024", default=1024)
    parser.add_argument("--height", type=sdxl_image_dimension, help="height, default=1024", default=1024)
    parser.add_argument("--seed", type=int, help="seed for noize generation, default is random",
                        default=random.randint(0, 99999999))
    parser.add_argument("--steps", type=int, help="denoise steps, default=20", default=20)
    parser.add_argument("--cfg_scale", "-cfg", type=float,
                        help="classifier-free guidance scale, default=8", default=8.0)
    parser.add_argument("--sampler_name", type=str, choices=_SAMPLER_NAMES,
                        help="sampler name, default is euler", default="euler")
    parser.add_argument("--scheduler", type=str, choices=_SCHEDULERS,
                        help="noise scheduler, default is normal", default="normal")
    parser.add_argument("--denoise", type=denoise, help="denoise level, default=1.0", default=1.0)
    return parser


def prompting_parse_args(args: str, parser: CustomArgParser):
    """Split *args* shell-style and parse them into a plain dict.

    Raises ValueError from shlex on unbalanced quotes; parse failures surface
    however CustomArgParser reports them (gen_image expects RuntimeError).
    """
    return vars(parser.parse_args(shlex.split(args)))


async def gen_image(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Handle a `prompting ...` message: parse it, enqueue it, ack the prompt id."""
    chat_id = update.effective_chat.id
    raw_text = update.effective_message.text or ""
    query = (
        raw_text
        # Strip only the command word, not occurrences inside the prompt text.
        .replace("prompting", "", 1)
        # Telegram clients auto-convert "--" into an em dash.
        .replace("—", "--")
        # Newlines separate arguments; dropping them would glue flags together.
        .replace("\n", " ")
    )

    parser = create_prompting_parser()
    try:
        args = prompting_parse_args(query, parser)
    except (RuntimeError, ValueError) as e:
        # CustomArgParser exposes help_message when --help was requested.
        text = getattr(parser, 'help_message', None) or str(e)
        return await context.bot.send_message(
            chat_id=chat_id,
            text="```\n" + escape_markdown(text) + "\n```",
            parse_mode="MarkdownV2",
        )

    try:
        result = await enqueue_prompt_t2i(args)
    except (aiohttp.ClientError, asyncio.TimeoutError, ValueError) as e:
        print(f"Failed to enqueue prompt: {e}")
        return await context.bot.send_message(chat_id=chat_id, text=_GENERIC_ERROR_TEXT)

    if isinstance(result, dict) and result.get("error") == 0 and "data" in result:
        prompt_id = result['data']
        scheduled_prompts[prompt_id] = chat_id
        print(scheduled_prompts)
        return await context.bot.send_message(chat_id=chat_id, text=f"prompt_id: {prompt_id}")
    return await context.bot.send_message(chat_id=chat_id, text=_GENERIC_ERROR_TEXT)