import argparse
import asyncio
import json
import random
import shlex
import threading

import aiohttp
import telegram
import websockets
from telegram import Update
from telegram.ext import ContextTypes

from tgbot.config import TOKEN, IMAGE_GENERATOR_API_HOST, IMAGE_GENERATOR_TOKEN, CDN_URL
from tgbot.shit.custom_argparser import CustomArgParser
from tgbot.shit.handlers import escape_markdown

# Telegram Bot API limits: photo captions are at most 1024 characters,
# text messages at most 4096 characters.
_TELEGRAM_CAPTION_LIMIT = 1024
_TELEGRAM_TEXT_LIMIT = 4096


def filename2cdn(filename: str) -> str | None:
    """Return the CDN URL for *filename*, or ``None`` when *filename* is empty/None."""
    if not filename:
        return None
    return CDN_URL + filename


# prompt id (as returned by the generator's txt2img endpoint) -> Telegram chat id
# that requested it. Written by gen_image and consumed by ws_task, which runs
# in its own thread / event loop.
scheduled_prompts: dict = {}


async def _deliver_result(message) -> None:
    """Handle one websocket notification by sending its result to the waiting chat.

    Notifications whose id is not in ``scheduled_prompts`` (e.g. prompts
    enqueued by another client, or before this process started) are logged
    and ignored instead of raising.
    """
    message_dict = json.loads(message)
    prompt_id = message_dict.get("id")
    # pop() instead of get()+del: the entry is dropped even if sending fails,
    # and an unknown id does not raise KeyError.
    chat_id = scheduled_prompts.pop(prompt_id, None)
    print(f"Received message: {message}")
    if chat_id is None:
        print(f"No chat is waiting for prompt {prompt_id!r}; ignoring")
        return

    image_url = filename2cdn(message_dict.get("image_filename"))
    bot = telegram.Bot(TOKEN)
    async with bot:
        if image_url is None:
            # No image to send (missing/empty filename): report the raw notification instead.
            await bot.send_message(chat_id=chat_id, text=str(message)[:_TELEGRAM_TEXT_LIMIT])
        else:
            await bot.send_photo(
                chat_id=chat_id,
                photo=image_url,
                caption=str(message)[:_TELEGRAM_CAPTION_LIMIT],
            )


async def ws_task():
    """Listen to the generator's websocket forever, reconnecting with exponential backoff.

    Each received message is handled by ``_deliver_result``; a failure while
    handling one message is logged and does not drop the connection.
    """
    ws_url = f"wss://{IMAGE_GENERATOR_API_HOST}/system/ws?token={IMAGE_GENERATOR_TOKEN}"
    retry_delay = 1  # Initial delay (1s), will increase with backoff
    while True:
        try:
            print("Connecting to WebSocket...")
            async with websockets.connect(ws_url) as ws:
                retry_delay = 1  # Reset delay on successful connection
                while True:
                    message = await ws.recv()
                    try:
                        await _deliver_result(message)
                    except Exception as e:
                        # A malformed message or a failed Telegram call must not
                        # tear down the websocket connection.
                        print(f"Failed to handle message {message!r}: {e}")
        except websockets.ConnectionClosed as e:
            print(f"WebSocket connection closed: {e}. Reconnecting in {retry_delay}s...")
        except Exception as e:
            print(f"WebSocket error: {e}. Reconnecting in {retry_delay}s...")
        await asyncio.sleep(retry_delay)
        retry_delay = min(retry_delay * 2, 60)  # Exponential backoff, max 60s


def start_ws_task():
    """Run WebSocket listener in a separate thread to avoid blocking FastAPI & Telegram bot."""
    print("starting the websocket task")
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(ws_task())


# Started at import time; daemon so it does not keep the process alive on shutdown.
thread = threading.Thread(target=start_ws_task, daemon=True)
thread.start()


async def enqueue_prompt_t2i(args: dict):
    """POST *args* to the generator's txt2img endpoint and return the decoded JSON reply.

    Raises whatever aiohttp raises for connection errors or a non-JSON response.
    """
    url = f"https://{IMAGE_GENERATOR_API_HOST}/prompting/txt2img/"
    headers = {"Authorization": "Bearer " + IMAGE_GENERATOR_TOKEN}
    async with aiohttp.ClientSession() as session:
        # `async with` releases the response even if JSON decoding fails.
        async with session.post(url=url, headers=headers, json=args) as response:
            return await response.json()


def sdxl_image_dimension(value):
    """argparse type: a positive integer divisible by 8."""
    try:
        value = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"{value} is not an integer") from None
    if value <= 0:
        raise argparse.ArgumentTypeError(f"{value} is not a positive integer")
    if value % 8 != 0:
        raise argparse.ArgumentTypeError(f"{value} is not divisible by 8")
    return value


def lora(value):
    """argparse type: parse ``name:strength`` into ``[name, strength]`` (strength as float).

    The split is on the LAST colon, so only the strength must be colon-free.
    """
    lora_ = value.rsplit(":", 1)
    if len(lora_) != 2 or not lora_[0]:
        raise argparse.ArgumentTypeError(f"can't parse lora '{value}', "
                                         f"use --lora lora_name.safetensors:lora_strength")
    try:
        lora_[1] = float(lora_[1])
    except ValueError:
        # Previously the exception was constructed but never raised.
        raise argparse.ArgumentTypeError(f"lora strength {lora_[1]} is not a float") from None
    return lora_


def denoise(value):
    """argparse type: a float in the closed range [0, 1]."""
    try:
        value = float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"{value} is not a float") from None
    # Written as a chained comparison so NaN is rejected too.
    if not 0.0 <= value <= 1.0:
        raise argparse.ArgumentTypeError(f"{value} is not a valid denoise level")
    return value


def create_prompting_parser():
    """Build a fresh parser for the ``prompting`` command (a new random default seed each call)."""
    parser = CustomArgParser(
        description="prompting - a command for interacting with image generators",
        usage="""prompting --checkpoint=Illustrious-XL-v1.0.safetensors --prompt='fox lady' --negative_prompt='bad and retarded' --width=1024 --height=1024 --seed=123456 --steps=25 --cfg_scale=7.5 --sampler_name='euler' --scheduler=normal --denoise=1.0""",
    )
    parser.add_argument("--checkpoint", type=str,
                        choices=["Illustrious-XL-v1.0.safetensors",
                                 "PVCStyleModelMovable_illustriousxl10.safetensors"],
                        help="checkpoint to use", required=True)
    parser.add_argument("--prompt", type=str, help="positive prompt, e.g. 'fox lady'", required=True)
    parser.add_argument("--negative_prompt", type=str,
                        help="negative prompt, e.g. 'bad, huge dick'. Could be none",
                        default="", required=False)
    parser.add_argument("--width", type=sdxl_image_dimension, help="width, default=1024", default=1024)
    parser.add_argument("--height", type=sdxl_image_dimension, help="height, default=1024", default=1024)
    # Evaluated each time the parser is built, so every request gets its own random seed.
    parser.add_argument("--seed", type=int, help="seed for noise generation, default is random",
                        default=random.randint(0, 99999999))
    parser.add_argument("--steps", type=int, help="denoise steps, default=20", default=20)
    parser.add_argument("--cfg_scale", "-cfg", type=float,
                        help="classifier-free guidance scale, default=8", default=8.0)
    parser.add_argument("--sampler_name", type=str,
                        choices=['euler', 'euler_cfg_pp', 'euler_ancestral', 'euler_ancestral_cfg_pp',
                                 'heun', 'heunpp2', 'dpm_2', 'dpm_2_ancestral', 'lms', 'dpm_fast',
                                 'dpm_adaptive', 'dpmpp_2s_ancestral', 'dpmpp_2s_ancestral_cfg_pp',
                                 'dpmpp_sde', 'dpmpp_sde_gpu', 'dpmpp_2m', 'dpmpp_2m_cfg_pp',
                                 'dpmpp_2m_sde', 'dpmpp_2m_sde_gpu', 'dpmpp_3m_sde', 'dpmpp_3m_sde_gpu',
                                 'ddpm', 'lcm', 'ipndm', 'ipndm_v', 'deis', 'ddim', 'uni_pc',
                                 'uni_pc_bh2'],
                        help="sampler name, default is euler", default="euler")
    parser.add_argument("--scheduler", type=str,
                        choices=['normal', 'karras', 'exponential', 'sgm_uniform', 'simple',
                                 'ddim_uniform', 'beta'],
                        help="noise scheduler, default is normal", default="normal")
    parser.add_argument("--denoise", type=denoise, help="denoise level, default is 1", default=1.0)
    # Clip skip is an integer layer index with default -1. It previously used the
    # `denoise` validator, which only accepts floats in [0, 1] and so rejected every
    # negative value, including the documented default.
    parser.add_argument("--stop_at_clip_layer", type=int, help="clip skip, default is -1, ", default=-1)
    parser.add_argument("--lora", type=lora, action="append",
                        help="add lora to prompt, e.g. --lora lora_name.safetensors:0.8. "
                             "Multiple loras might be added")
    return parser


def prompting_parse_args(args: str, parser: CustomArgParser):
    """Parse a shell-like argument string into the txt2img request payload.

    The repeated ``--lora name:strength`` options are converted into a
    ``loras`` list of ``{"name", "strength_model", "strength_clip"}`` dicts,
    and the raw ``lora`` key is always removed from the payload.

    Raises:
        ValueError: from ``shlex.split`` on malformed quoting.
        Whatever ``parser.parse_args`` raises on invalid arguments.
    """
    parsed = vars(parser.parse_args(shlex.split(args)))
    raw_loras = parsed.pop("lora", None) or []
    parsed["loras"] = [
        {"name": name, "strength_model": strength, "strength_clip": strength}
        for name, strength in raw_loras
    ]
    return parsed


async def gen_image(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Telegram handler: parse a ``prompting ...`` message, enqueue it, and reply with the prompt id.

    On a parse error the parser's help/error text is sent back as a code block;
    on an enqueue failure a generic error message is sent.
    """
    chat_id = update.effective_chat.id
    query = (
        update.message.text
        # Strip only the command word, not occurrences of "prompting" inside the prompt.
        .replace("prompting", "", 1)
        # Some clients auto-convert "--" into an em dash.
        .replace("—", "--")
        # Newlines separate arguments: turn them into spaces instead of gluing tokens together.
        .replace("\n", " ")
    )
    parser = create_prompting_parser()
    try:
        args = prompting_parse_args(query, parser)
    except (RuntimeError, ValueError) as e:
        # ValueError comes from shlex.split (e.g. an unclosed quote). If the parser
        # recorded a help message (presumably for --help), show that instead of the error.
        text = getattr(parser, 'help_message', None) or str(e)
        return await context.bot.send_message(chat_id=chat_id,
                                              text="```\n" + escape_markdown(text) + "\n```",
                                              parse_mode="MarkdownV2")

    try:
        result = await enqueue_prompt_t2i(args)
    except (aiohttp.ClientError, asyncio.TimeoutError, ValueError) as e:
        print(f"Failed to enqueue prompt: {e}")
        result = None

    if not isinstance(result, dict) or result.get("error") != 0:
        return await context.bot.send_message(chat_id=chat_id,
                                              text="Some weird incomprehensible error occurred :<")

    prompt_id = result['data']
    # NOTE(review): if the generator finishes before this assignment, ws_task will
    # see no waiting chat for this id and drop the result.
    scheduled_prompts[prompt_id] = chat_id
    print(scheduled_prompts)
    return await context.bot.send_message(chat_id=chat_id, text=f"prompt_id: {prompt_id}")