# owlrandomshitbot/tgbot/shit/prompting.py
import argparse
import asyncio
import json
import random
import shlex
import threading
from io import BytesIO
import aiohttp
import httpx
import telegram
import websockets
from telegram import Update
from telegram.ext import ContextTypes
from tgbot.config import TOKEN, IMAGE_GENERATOR_API_HOST, IMAGE_GENERATOR_TOKEN, CDN_URL
from tgbot.shit.custom_argparser import CustomArgParser
from tgbot.shit.handlers import escape_markdown
def filename2cdn(filename: str) -> str | None:
if not filename:
return None
return CDN_URL + filename
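# Example (illustrative only: the actual prefix comes from CDN_URL in tgbot.config):
#   filename2cdn("abc123.png") -> CDN_URL + "abc123.png"
#   filename2cdn("")           -> None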
# Maps prompt_id -> chat_id of the chat that is waiting for the generated image.
scheduled_prompts: dict = {}
async def ws_task():
    """Listen on the image generator's WebSocket and deliver finished prompts to Telegram."""
    ws_url = f"wss://{IMAGE_GENERATOR_API_HOST}/system/ws?token={IMAGE_GENERATOR_TOKEN}"
    retry_delay = 1  # initial delay (1s), grows with backoff
    while True:
        try:
            print("Connecting to WebSocket...")
            async with websockets.connect(ws_url) as ws:
                retry_delay = 1  # reset delay on successful connection
                while True:
                    message = await ws.recv()
                    message_dict = json.loads(message)
                    prompt = message_dict["data"]
                    prompt_id = prompt["id"]
                    image_url = filename2cdn(prompt["image_filename"])
                    chat_id = scheduled_prompts.get(prompt_id)
                    print(f"Received message: {message}")
                    if chat_id is None or image_url is None:
                        # No chat is waiting for this prompt (or it has no image): skip it.
                        continue
                    # Download the generated image from the CDN.
                    async with httpx.AsyncClient() as client:
                        response = await client.get(image_url)
                        response.raise_for_status()  # raise an exception for HTTP errors
                    # Wrap the downloaded bytes so python-telegram-bot can upload them.
                    image_bytes = BytesIO(response.content)
                    bot = telegram.Bot(TOKEN)
                    async with bot:
                        await bot.send_message(
                            chat_id=chat_id,
                            text=str(message),
                        )
                        await bot.send_document(
                            chat_id=chat_id,
                            document=image_bytes,
                            filename=prompt["image_filename"],
                        )
                    # The prompt has been delivered, forget about it.
                    scheduled_prompts.pop(prompt_id, None)
        except websockets.ConnectionClosed as e:
            print(f"WebSocket connection closed: {e}. Reconnecting in {retry_delay}s...")
        except Exception as e:
            print(f"WebSocket error: {e}. Reconnecting in {retry_delay}s...")
        await asyncio.sleep(retry_delay)
        retry_delay = min(retry_delay * 2, 60)  # exponential backoff, capped at 60s
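# A sketch of the message shape ws_task relies on, inferred from the fields it
# reads above (the real payload may carry additional keys):
#
#   {"data": {"id": "<prompt_id>", "image_filename": "abc123.png", ...}}
#
# "id" must match a prompt_id previously stored in scheduled_prompts by
# gen_image(), and "image_filename" is turned into a CDN URL via filename2cdn().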
def start_ws_task():
"""Run WebSocket listener in a separate thread to avoid blocking FastAPI & Telegram bot."""
print("starting the websocket task")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(ws_task())
# Start the listener as soon as this module is imported.
thread = threading.Thread(target=start_ws_task, daemon=True)
thread.start()
async def enqueue_prompt_t2i(args: dict):
    url = f"https://{IMAGE_GENERATOR_API_HOST}/prompting/txt2img/"
    headers = {"Authorization": "Bearer " + IMAGE_GENERATOR_TOKEN}
    async with aiohttp.ClientSession() as session:
        async with session.post(url=url, headers=headers, json=args) as result:
            return await result.json()
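# gen_image() below expects the txt2img endpoint to answer with an envelope of
# the form {"error": "OK", "data": "<prompt_id>"} on success (inferred from how
# the result is consumed there); any other "error" value is reported to the chat.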
async def list_models():
    url = f"https://{IMAGE_GENERATOR_API_HOST}/prompting/list_models/"
    headers = {"Authorization": "Bearer " + IMAGE_GENERATOR_TOKEN}
    async with aiohttp.ClientSession() as session:
        async with session.get(url=url, headers=headers) as result:
            return await result.text()
def sdxl_image_dimension(value):
try:
value = int(value)
if value % 8 != 0:
raise argparse.ArgumentTypeError(f"{value} is not divisible by 8")
return value
except ValueError:
raise argparse.ArgumentTypeError(f"{value} is not an integer")
def lora(value):
lora_ = value.split(":")
if len(lora_) != 2:
raise argparse.ArgumentTypeError(f"can't parse lora '{value}', "
f"use --lora lora_name.safetensors:lora_strength")
try:
lora_[1] = float(lora_[1])
except ValueError:
        raise argparse.ArgumentTypeError(f"lora strength {lora_[1]} is not a float")
return lora_
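# Example (the file name is illustrative, not a lora shipped with the bot):
#   lora("pixel_style.safetensors:0.8") -> ["pixel_style.safetensors", 0.8]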
def denoise(value):
try:
value = float(value)
if value > 1.0 or value < 0:
raise argparse.ArgumentTypeError(f"{value} is not a valid denoise level")
return value
except ValueError:
raise argparse.ArgumentTypeError(f"{value} is not a float")
def create_prompting_parser():
parser = CustomArgParser(
description="prompting - a command for interacting with image generators",
usage="""prompting --checkpoint=Illustrious-XL-v1.0.safetensors --prompt='fox lady' --negative_prompt='bad
and retarded' --width=1024 --height=1024 --seed=123456 --steps=25 --cfg_scale=7.5 --sampler_name='euler'
--scheduler=normal --denoise=1.0""",
)
parser.add_argument("--checkpoint", type=str,
help="checkpoint to use", required=True)
parser.add_argument("--prompt", type=str,
help="positive prompt, e.g. 'fox lady'", required=True)
parser.add_argument("--negative_prompt", type=str,
                        help="negative prompt, e.g. 'bad, huge dick'. Can be omitted", default="", required=False)
parser.add_argument("--width", type=sdxl_image_dimension,
help="width, default=1024", default=1024)
parser.add_argument("--height", type=sdxl_image_dimension,
help="height, default=1024", default=1024)
parser.add_argument("--seed", type=int,
                        help="seed for noise generation, default is random", default=random.randint(0, 99999999))
parser.add_argument("--steps", type=int,
help="denoise steps, default=20", default=20)
parser.add_argument("--cfg_scale", "-cfg", type=float,
help="classifier-free guidance scale, default=8", default=8.0)
parser.add_argument("--sampler_name", type=str,
choices=['euler', 'euler_cfg_pp', 'euler_ancestral', 'euler_ancestral_cfg_pp', 'heun',
'heunpp2', 'dpm_2', 'dpm_2_ancestral', 'lms', 'dpm_fast', 'dpm_adaptive',
'dpmpp_2s_ancestral', 'dpmpp_2s_ancestral_cfg_pp', 'dpmpp_sde', 'dpmpp_sde_gpu',
'dpmpp_2m', 'dpmpp_2m_cfg_pp', 'dpmpp_2m_sde', 'dpmpp_2m_sde_gpu', 'dpmpp_3m_sde',
'dpmpp_3m_sde_gpu', 'ddpm', 'lcm', 'ipndm', 'ipndm_v', 'deis', 'ddim', 'uni_pc',
'uni_pc_bh2'],
help="sampler name, default is euler", default="euler")
parser.add_argument("--scheduler", type=str,
choices=['normal', 'karras', 'exponential', 'sgm_uniform', 'simple', 'ddim_uniform', 'beta'],
help="noise scheduler, default is normal", default="normal")
parser.add_argument("--denoise", type=denoise,
help="denoise level, default is 1", default=1.0)
parser.add_argument("--stop_at_clip_layer", type=int,
                        help="clip skip, default is -2", default=-2)
    parser.add_argument("--lora", type=lora, action="append",
                        help="add a lora to the prompt, e.g. --lora lora_name.safetensors:0.8. "
                             "Can be repeated to add multiple loras")
return parser
def prompting_parse_args(args: str, parser: CustomArgParser):
args = parser.parse_args(shlex.split(args)).__dict__
    # Convert the parsed --lora entries into the payload shape the API expects.
    loras = []
if args['lora']:
for lora_ in args['lora']:
loras.append(
{
"name": lora_[0],
"strength_model": lora_[1],
"strength_clip": lora_[1]
}
)
del args["lora"]
args["loras"] = loras
return args
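# Usage sketch (hypothetical values; checkpoint and lora names are examples only):
#
#   parser = create_prompting_parser()
#   args = prompting_parse_args(
#       "--checkpoint=Illustrious-XL-v1.0.safetensors --prompt='fox lady' "
#       "--lora pixel_style.safetensors:0.8",
#       parser,
#   )
#   # args now holds plain keyword values plus a "loras" list, e.g.
#   #   {"checkpoint": "Illustrious-XL-v1.0.safetensors", "prompt": "fox lady", ...,
#   #    "loras": [{"name": "pixel_style.safetensors",
#   #               "strength_model": 0.8, "strength_clip": 0.8}]}
#   # and can be passed straight to enqueue_prompt_t2i(args).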
async def gen_image(update: Update, context: ContextTypes.DEFAULT_TYPE):
query = update.message.text
if query.startswith("prompting list models"):
models = await list_models()
return await context.bot.send_message(chat_id=update.effective_chat.id, text=models)
    # Strip the command word and map em dashes ("\u2014", which some keyboards
    # substitute for a double hyphen) back to "--" so argparse can see the flags.
    query = update.message.text.replace("prompting", "").replace("\u2014", "--").replace("\n", " ")
parser = create_prompting_parser()
try:
args = prompting_parse_args(query, parser)
except RuntimeError as e:
if getattr(parser, 'help_message', None):
help_message = parser.help_message
escaped_help_message = escape_markdown(help_message)
return await context.bot.send_message(chat_id=update.effective_chat.id,
text="```\n" + escaped_help_message + "\n```",
parse_mode="MarkdownV2")
else:
error_message = str(e)
escaped_error_message = escape_markdown(error_message)
return await context.bot.send_message(chat_id=update.effective_chat.id,
text="```\n" + escaped_error_message + "\n```",
parse_mode="MarkdownV2")
result = await enqueue_prompt_t2i(args)
    if result.get("error") == "OK":
prompt_id = result['data']
scheduled_prompts[prompt_id] = update.effective_chat.id
print(scheduled_prompts)
return await context.bot.send_message(chat_id=update.effective_chat.id, text=f"prompt_id: {prompt_id}")
else:
return await context.bot.send_message(chat_id=update.effective_chat.id,
text="Some weird incomprehensible error occurred :<\n"
+ json.dumps(result))