214 lines
7.7 KiB
Python
214 lines
7.7 KiB
Python
import argparse
|
|
import asyncio
|
|
import json
|
|
import random
|
|
import shlex
|
|
import threading
|
|
|
|
import aiohttp
|
|
import telegram
|
|
import websockets
|
|
from telegram import Update
|
|
from telegram.ext import ContextTypes
|
|
|
|
from tgbot.config import TOKEN, IMAGE_GENERATOR_API_HOST, IMAGE_GENERATOR_TOKEN, CDN_URL
|
|
from tgbot.shit.custom_argparser import CustomArgParser
|
|
from tgbot.shit.handlers import escape_markdown
|
|
|
|
|
|
def filename2cdn(filename: str) -> str | None:
|
|
if not filename:
|
|
return None
|
|
return CDN_URL + filename
|
|
|
|
|
|
# Pending generation requests: maps the prompt id stored by gen_image (taken
# from the enqueue response) to the Telegram chat id that asked for it.
# ws_task looks the chat up by prompt id when a result arrives and then
# removes the entry.
scheduled_prompts: dict = {}
|
|
|
|
|
|
async def _handle_ws_message(message):
    """Forward one generator notification to the chat that requested it.

    Malformed payloads, notifications for unknown prompt ids and Telegram
    delivery failures are logged and swallowed here so that a single bad
    message never tears down the WebSocket connection.
    """
    try:
        message_dict = json.loads(message)
        prompt_id = message_dict["id"]
        image_filename = message_dict["image_filename"]
        # pop() both looks up and forgets the request, so entries cannot
        # accumulate when delivery fails later on.
        chat_id = scheduled_prompts.pop(prompt_id, None)
    except (ValueError, KeyError, TypeError) as e:
        print(f"Ignoring malformed WebSocket message: {e!r}")
        return

    if chat_id is None:
        # Result for a prompt this process did not schedule (or already served).
        print(f"No chat is waiting for prompt {prompt_id}, skipping")
        return

    image_url = filename2cdn(image_filename)
    bot = telegram.Bot(TOKEN)

    try:
        async with bot:
            if image_url is None:
                # No image was produced; still tell the user what came back.
                # Telegram limits text messages to 4096 characters.
                await bot.send_message(chat_id=chat_id, text=str(message)[:4096])
            else:
                # Telegram limits photo captions to 1024 characters.
                await bot.send_photo(chat_id=chat_id, photo=image_url, caption=str(message)[:1024])
    except telegram.error.TelegramError as e:
        print(f"Failed to deliver result for prompt {prompt_id}: {e}")


async def ws_task():
    """Listen for finished generations on the generator's WebSocket, forever.

    Connects to the image generator's system WebSocket and hands every
    received message to ``_handle_ws_message``. When the connection drops or
    any other error escapes, it reconnects after an exponentially growing
    delay (1s doubling up to 60s); the delay is reset after each successful
    connection. This coroutine never returns.
    """
    ws_url = f"wss://{IMAGE_GENERATOR_API_HOST}/system/ws?token={IMAGE_GENERATOR_TOKEN}"
    retry_delay = 1  # Initial delay (1s), will increase with backoff

    while True:
        try:
            print("Connecting to WebSocket...")
            async with websockets.connect(ws_url) as ws:
                retry_delay = 1  # Reset delay on successful connection

                while True:
                    message = await ws.recv()
                    print(f"Received message: {message}")
                    await _handle_ws_message(message)

        except websockets.ConnectionClosed as e:
            print(f"WebSocket connection closed: {e}. Reconnecting in {retry_delay}s...")
        except Exception as e:
            # Top-level boundary of the listener: log and reconnect rather
            # than let the background thread die.
            print(f"WebSocket error: {e}. Reconnecting in {retry_delay}s...")

        await asyncio.sleep(retry_delay)
        retry_delay = min(retry_delay * 2, 60)  # Exponential backoff, max 60s
|
|
|
|
|
|
def start_ws_task():
    """Thread entry point: drive ``ws_task`` on a dedicated event loop.

    A fresh asyncio loop is created and installed as the current loop of the
    calling thread, then used to run the WebSocket listener, so the listener
    does not share an event loop with the rest of the application.
    """
    print("starting the websocket task")
    event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)
    event_loop.run_until_complete(ws_task())
|
|
|
|
|
|
# Start the WebSocket listener in the background as soon as this module is
# imported. The thread is a daemon so it never keeps the process alive.
thread = threading.Thread(daemon=True, target=start_ws_task)
thread.start()
|
|
|
|
|
|
async def enqueue_prompt_t2i(args: dict):
    """Submit a text-to-image job to the image generator.

    Args:
        args: Generation parameters; sent as the JSON body of the request.

    Returns:
        The decoded JSON body of the generator's response. ``gen_image``
        reads its ``"error"`` and ``"data"`` keys.
    """
    url = f"https://{IMAGE_GENERATOR_API_HOST}/prompting/txt2img/"
    headers = {"Authorization": "Bearer " + IMAGE_GENERATOR_TOKEN}

    async with aiohttp.ClientSession() as session:
        # Hold the response in a context manager so its connection is
        # released even if decoding the body raises.
        async with session.post(url=url, headers=headers, json=args) as response:
            return await response.json()
|
|
|
|
|
|
def sdxl_image_dimension(value):
    """argparse ``type=`` converter for an image width or height.

    Args:
        value: Raw command-line value (normally a string).

    Returns:
        The dimension as a positive ``int`` divisible by 8.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not an integer, is not
            positive, or is not divisible by 8.
    """
    try:
        pixels = int(value)
    except (TypeError, ValueError):
        raise argparse.ArgumentTypeError(f"{value} is not an integer") from None

    # 0 and negative multiples of 8 pass the modulo test below but are not
    # usable image sizes, so reject them explicitly.
    if pixels <= 0:
        raise argparse.ArgumentTypeError(f"{pixels} is not a positive size")
    if pixels % 8 != 0:
        raise argparse.ArgumentTypeError(f"{pixels} is not divisible by 8")
    return pixels
|
|
|
|
|
|
def denoise(value):
    """argparse ``type=`` converter for the denoise level.

    Args:
        value: Raw command-line value (normally a string).

    Returns:
        The level as a ``float`` in the closed range [0, 1].

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a float or lies
            outside [0, 1] (NaN included).
    """
    try:
        level = float(value)
    except (TypeError, ValueError):
        raise argparse.ArgumentTypeError(f"{value} is not a float") from None

    # A chained comparison is False for NaN, so NaN is rejected as well;
    # separate "> 1.0 or < 0" checks would let it through.
    if not 0.0 <= level <= 1.0:
        raise argparse.ArgumentTypeError(f"{level} is not a valid denoise level")
    return level
|
|
|
|
|
|
def create_prompting_parser():
    """Build the argument parser for the ``prompting`` chat command.

    Returns:
        A new ``CustomArgParser`` with the checkpoint, prompt, image size,
        seed and sampling options registered.

    Note:
        The default ``--seed`` is drawn once, when the parser is built, so a
        fresh parser is needed per request to get a fresh random seed.
    """
    checkpoints = ["Illustrious-XL-v1.0.safetensors", "PVCStyleModelMovable_illustriousxl10.safetensors"]
    samplers = ['euler', 'euler_cfg_pp', 'euler_ancestral', 'euler_ancestral_cfg_pp', 'heun',
                'heunpp2', 'dpm_2', 'dpm_2_ancestral', 'lms', 'dpm_fast', 'dpm_adaptive',
                'dpmpp_2s_ancestral', 'dpmpp_2s_ancestral_cfg_pp', 'dpmpp_sde', 'dpmpp_sde_gpu',
                'dpmpp_2m', 'dpmpp_2m_cfg_pp', 'dpmpp_2m_sde', 'dpmpp_2m_sde_gpu', 'dpmpp_3m_sde',
                'dpmpp_3m_sde_gpu', 'ddpm', 'lcm', 'ipndm', 'ipndm_v', 'deis', 'ddim', 'uni_pc',
                'uni_pc_bh2']
    schedulers = ['normal', 'karras', 'exponential', 'sgm_uniform', 'simple', 'ddim_uniform', 'beta']

    parser = CustomArgParser(
        description="prompting - a command for interacting with image generators",
        usage="""prompting --checkpoint=Illustrious-XL-v1.0.safetensors --prompt='fox lady' --negative_prompt='bad
and retarded' --width=1024 --height=1024 --seed=123456 --steps=25 --cfg_scale=7.5 --sampler_name='euler'
--scheduler=normal --denoise=1.0""",
    )

    parser.add_argument("--checkpoint", type=str,
                        choices=checkpoints,
                        help="checkpoint to use", required=True)

    parser.add_argument("--prompt", type=str,
                        help="positive prompt, e.g. 'fox lady'", required=True)

    parser.add_argument("--negative_prompt", type=str,
                        help="negative prompt, e.g. 'bad, huge dick'. Could be none", default="", required=False)

    parser.add_argument("--width", type=sdxl_image_dimension,
                        help="width, default=1024", default=1024)

    parser.add_argument("--height", type=sdxl_image_dimension,
                        help="height, default=1024", default=1024)

    parser.add_argument("--seed", type=int,
                        help="seed for noize generation, default is random", default=random.randint(0, 99999999))

    parser.add_argument("--steps", type=int,
                        help="denoise steps, default=20", default=20)

    parser.add_argument("--cfg_scale", "-cfg", type=float,
                        help="classifier-free guidance scale, default=8", default=8.0)

    parser.add_argument("--sampler_name", type=str,
                        choices=samplers,
                        help="sampler name, default is euler", default="euler")

    parser.add_argument("--scheduler", type=str,
                        choices=schedulers,
                        help="noise scheduler, default is normal", default="normal")

    # The help text used to claim two different defaults ("default is 1,
    # default=8"); the actual default is 1.0.
    parser.add_argument("--denoise", type=denoise,
                        help="denoise level between 0 and 1, default=1.0", default=1.0)

    return parser
|
|
|
|
|
|
def prompting_parse_args(args: str, parser: CustomArgParser):
    """Parse a raw ``prompting`` command line into a plain dict.

    Args:
        args: The option string typed by the user, split shell-style so that
            quoted values stay together.
        parser: Parser that defines the accepted options.

    Returns:
        The attribute dict of the parsed namespace (option name -> value).
    """
    tokens = shlex.split(args)
    namespace = parser.parse_args(tokens)
    return vars(namespace)
|
|
|
|
|
|
async def gen_image(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Handle a ``prompting ...`` chat message by enqueueing an image job.

    The message text is parsed as a command line. On a parse problem the
    parser's help text (when it recorded one) or the error message is sent
    back in a code block. On success the job is submitted to the generator
    and its prompt id is stored in ``scheduled_prompts`` so that ``ws_task``
    can deliver the finished image to this chat.

    Args:
        update: Incoming Telegram update.
        context: Handler context; its bot is used to reply.

    Returns:
        The sent reply message, or ``None`` when the update carries no text.
    """
    message = update.message
    if message is None or not message.text:
        # Nothing to parse (e.g. an update without a text message).
        return None

    chat_id = update.effective_chat.id

    query = (
        message.text
        # Drop only the command word itself, not the same word inside a prompt.
        .replace("prompting", "", 1)
        # Chat clients tend to auto-replace "--" with an em dash.
        .replace("—", "--")
        # Keep arguments on separate lines separated instead of gluing them.
        .replace("\n", " ")
    )

    parser = create_prompting_parser()

    try:
        args = prompting_parse_args(query, parser)
    except (RuntimeError, ValueError) as e:
        # RuntimeError was already handled here and presumably comes from
        # CustomArgParser; ValueError is raised by shlex.split on unbalanced
        # quotes.
        report = getattr(parser, "help_message", None) or str(e)
        return await context.bot.send_message(
            chat_id=chat_id,
            text="```\n" + escape_markdown(report) + "\n```",
            parse_mode="MarkdownV2",
        )

    result = await enqueue_prompt_t2i(args)

    accepted = isinstance(result, dict) and result.get("error") == 0 and "data" in result
    if not accepted:
        return await context.bot.send_message(
            chat_id=chat_id,
            text="Some weird incomprehensible error occurred :<",
        )

    prompt_id = result["data"]
    scheduled_prompts[prompt_id] = chat_id
    print(scheduled_prompts)

    return await context.bot.send_message(chat_id=chat_id, text=f"prompt_id: {prompt_id}")
|
|
|