From 2bb1e2dccb7737560770011f02a9234fdd6ee806 Mon Sep 17 00:00:00 2001 From: owl Date: Mon, 3 Mar 2025 04:10:37 +0700 Subject: [PATCH] add prompting --- Dockerfile.tgbot | 6 +- bot.py | 17 +-- docker-compose.yml | 2 + tgbot/config.py | 4 + tgbot/requirements.txt | 4 +- tgbot/shit/custom_argparser.py | 35 ++++++ tgbot/shit/handlers.py | 6 +- tgbot/shit/hentai_argparser.py | 38 +----- tgbot/shit/prompting.py | 213 +++++++++++++++++++++++++++++++++ 9 files changed, 277 insertions(+), 48 deletions(-) create mode 100644 tgbot/shit/custom_argparser.py create mode 100644 tgbot/shit/prompting.py diff --git a/Dockerfile.tgbot b/Dockerfile.tgbot index cf7df1e..d89c754 100644 --- a/Dockerfile.tgbot +++ b/Dockerfile.tgbot @@ -1,8 +1,10 @@ FROM python:3.11 +COPY tgbot/requirements.txt . + +RUN pip install -r requirements.txt + COPY tgbot tgbot COPY bot.py bot.py -RUN pip install -r tgbot/requirements.txt - CMD python3 bot.py diff --git a/bot.py b/bot.py index cd3c86a..480c993 100644 --- a/bot.py +++ b/bot.py @@ -5,6 +5,7 @@ import logging from tgbot.config import TOKEN from tgbot.shit.handlers import handle_xitter, handle_red_ebalo, handle_hentai, handle_cute_button +from tgbot.shit.prompting import gen_image from tgbot.shit.render import render_text_on_image, image_to_string from io import BytesIO @@ -20,9 +21,13 @@ async def start(update: Update, context: ContextTypes.DEFAULT_TYPE): 'AND I AM A PERSON OF COMPLETE AND TOTAL DELUSION!') -async def respond_with_picture(update: Update, context: ContextTypes.DEFAULT_TYPE): - file = render_text_on_image("tgbot/assets/red_ebalo.png", update.message.text) - await context.bot.send_sticker(chat_id=update.effective_chat.id, sticker=file) +async def handle_text(update: Update, context: ContextTypes.DEFAULT_TYPE): + + if not update.message.text.startswith("prompting"): + file = render_text_on_image("tgbot/assets/red_ebalo.png", update.message.text) + await context.bot.send_sticker(chat_id=update.effective_chat.id, sticker=file) + else: + await gen_image(update, context) async def image2string(update: Update, context: ContextTypes.DEFAULT_TYPE): @@ -52,7 +57,6 @@ async def callback_query_handler(update: Update, context: CallbackContext): else: await query.answer() - async def handle_inline(update: Update, context: ContextTypes.DEFAULT_TYPE): query = update.inline_query.query if not query: @@ -72,17 +76,16 @@ async def handle_inline(update: Update, context: ContextTypes.DEFAULT_TYPE): await handle_red_ebalo(update, context) - def main(): application = ApplicationBuilder().token(TOKEN).build() start_handler = CommandHandler('start', start) - picture_handler = MessageHandler(filters.TEXT & ~filters.COMMAND, respond_with_picture) + text_handler = MessageHandler(filters.TEXT & ~filters.COMMAND, handle_text) application.add_handler(start_handler) application.add_handler(MessageHandler(filters.PHOTO, image2string)) - application.add_handler(picture_handler) + application.add_handler(text_handler) application.add_handler(InlineQueryHandler(handle_inline)) application.add_handler(CallbackQueryHandler(callback_query_handler)) diff --git a/docker-compose.yml b/docker-compose.yml index f182347..7ce451a 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -18,6 +18,7 @@ services: - 8.8.8.8 owlrandomshitbot: + restart: unless-stopped image: owlrandomshitbot build: context: . @@ -28,6 +29,7 @@ services: - vpn booru-api: + restart: unless-stopped image: booru-api build: context: . diff --git a/tgbot/config.py b/tgbot/config.py index f8eaeba..63d4927 100644 --- a/tgbot/config.py +++ b/tgbot/config.py @@ -8,3 +8,7 @@ STICKER_SHITHOLE = -1002471390283 DANBOORU_USERNAME = "owlrandomshitbot" DANBOORU_API_KEY = os.getenv("DANBOORU_API_KEY") + +IMAGE_GENERATOR_TOKEN: str = os.getenv("IMAGE_GENERATOR_TOKEN") +IMAGE_GENERATOR_API_HOST: str = os.getenv("IMAGE_GENERATOR_API_HOST") +CDN_URL: str = os.getenv("CDN_URL") diff --git a/tgbot/requirements.txt b/tgbot/requirements.txt index d5c94fa..d2243c3 100644 --- a/tgbot/requirements.txt +++ b/tgbot/requirements.txt @@ -8,4 +8,6 @@ pillow==11.0.0 python-telegram-bot==21.6 sniffio==1.3.1 requests==2.32.3 -Pybooru==4.2.2 \ No newline at end of file +Pybooru==4.2.2 +aiohttp==3.11.13 +websockets==15.0 \ No newline at end of file diff --git a/tgbot/shit/custom_argparser.py b/tgbot/shit/custom_argparser.py new file mode 100644 index 0000000..c6931f2 --- /dev/null +++ b/tgbot/shit/custom_argparser.py @@ -0,0 +1,35 @@ +import argparse +import io + + +class CustomArgParser(argparse.ArgumentParser): + def __init__( + self, + *args, **kwargs, + ): + super().__init__(*args, **kwargs) + self.help_message = None + + def print_help(self, file=None): + # Store the help message in a buffer instead of printing + if file is None: + help_buffer = io.StringIO() + super().print_help(file=help_buffer) + self.help_message = help_buffer.getvalue() + help_buffer.close() + else: + super().print_help(file=file) + + def parse_args(self, args=None, namespace=None): + # Check for --help manually to avoid stdout output + if '--help' in args: + self.print_help() + raise RuntimeError("Help requested") + return super().parse_args(args, namespace) + + def exit(self, status=0, message=None): + raise RuntimeError(message) + + def error(self, message): + raise RuntimeError(f"Error: {message}") + diff --git a/tgbot/shit/handlers.py b/tgbot/shit/handlers.py index 8a49762..3d3ceba 100644 --- a/tgbot/shit/handlers.py +++ b/tgbot/shit/handlers.py @@ -7,7 +7,7 @@ from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup, InlineQ from telegram.ext import ContextTypes from tgbot.config import STICKER_SHITHOLE, BOORU_API_URL, INLINE_QUERY_CACHE_SECONDS, DANBOORU_API_KEY -from tgbot.shit.hentai_argparser import parse_args, create_parser +from tgbot.shit.hentai_argparser import hentai_parse_args, create_hentai_parser from tgbot.shit.render import render_text_on_image import re @@ -74,11 +74,11 @@ async def handle_cute_button(update: Update, context: ContextTypes.DEFAULT_TYPE) async def handle_hentai(update: Update, context: ContextTypes.DEFAULT_TYPE): query = update.inline_query.query.replace("hentai", "").replace("—", "--") - parser = create_parser() + parser = create_hentai_parser() try: - args = parse_args(query, parser) + args = hentai_parse_args(query, parser) except RuntimeError as e: diff --git a/tgbot/shit/hentai_argparser.py b/tgbot/shit/hentai_argparser.py index cbd73aa..7d9e1b8 100644 --- a/tgbot/shit/hentai_argparser.py +++ b/tgbot/shit/hentai_argparser.py @@ -1,41 +1,9 @@ -import argparse -import io import shlex - -class CustomArgParser(argparse.ArgumentParser): - def __init__( - self, - *args, **kwargs, - ): - super().__init__(*args, **kwargs) - self.help_message = None - - def print_help(self, file=None): - # Store the help message in a buffer instead of printing - if file is None: - help_buffer = io.StringIO() - super().print_help(file=help_buffer) - self.help_message = help_buffer.getvalue() - help_buffer.close() - else: - super().print_help(file=file) - - def parse_args(self, args=None, namespace=None): - # Check for --help manually to avoid stdout output - if '--help' in args: - self.print_help() - raise RuntimeError("Help requested") - return super().parse_args(args, namespace) - - def exit(self, status=0, message=None): - raise RuntimeError(message) - - def error(self, message): - raise RuntimeError(f"Error: {message}") +from tgbot.shit.custom_argparser import CustomArgParser -def create_parser(): +def create_hentai_parser(): parser = CustomArgParser( description="A command to search random stuff on various image boorus. Feel free to use it irresponsibly!", usage="@owlrandomshitbot hentai BOORU_NAME --tags TAG1,TAG2,TAG3 -l LIMIT --random -s", @@ -68,7 +36,7 @@ def create_parser(): return parser -def parse_args(args: str, parser: CustomArgParser): +def hentai_parse_args(args: str, parser: CustomArgParser): args = parser.parse_args(shlex.split(args)).__dict__ args["tags"] = args["tags"].split(",") diff --git a/tgbot/shit/prompting.py b/tgbot/shit/prompting.py new file mode 100644 index 0000000..c426738 --- /dev/null +++ b/tgbot/shit/prompting.py @@ -0,0 +1,213 @@ +import argparse +import asyncio +import json +import random +import shlex +import threading + +import aiohttp +import telegram +import websockets +from telegram import Update +from telegram.ext import ContextTypes + +from tgbot.config import TOKEN, IMAGE_GENERATOR_API_HOST, IMAGE_GENERATOR_TOKEN, CDN_URL +from tgbot.shit.custom_argparser import CustomArgParser +from tgbot.shit.handlers import escape_markdown + + +def filename2cdn(filename: str) -> str | None: + if not filename: + return None + return CDN_URL + filename + + +scheduled_prompts: dict = {} + + +async def ws_task(): + ws_url = f"wss://{IMAGE_GENERATOR_API_HOST}/ws?token={IMAGE_GENERATOR_TOKEN}" + retry_delay = 1 # Initial delay (1s), will increase with backoff + + while True: + try: + print(f"Connecting to WebSocket...") + async with websockets.connect(ws_url) as ws: + retry_delay = 1 # Reset delay on successful connection + + while True: + message = await ws.recv() + message_dict = json.loads(message) + + prompt_id = message_dict["id"] + image_url = filename2cdn(message_dict["image_filename"]) + chat_id = scheduled_prompts.get(prompt_id) + + print(f"Received message: {message}") + + bot = telegram.Bot(TOKEN) + + async with bot: + await bot.send_photo(chat_id=chat_id, photo=image_url, caption=str(message)) + + del scheduled_prompts[prompt_id] + + except websockets.ConnectionClosed as e: + print(f"WebSocket connection closed: {e}. Reconnecting in {retry_delay}s...") + except Exception as e: + print(f"WebSocket error: {e}. Reconnecting in {retry_delay}s...") + + await asyncio.sleep(retry_delay) + retry_delay = min(retry_delay * 2, 60) # Exponential backoff, max 60s + + +def start_ws_task(): + """Run WebSocket listener in a separate thread to avoid blocking FastAPI & Telegram bot.""" + print("starting the websocket task") + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(ws_task()) + + +thread = threading.Thread(target=start_ws_task, daemon=True) +thread.start() + + +async def enqueue_prompt_t2i(args: dict): + url = f"https://{IMAGE_GENERATOR_API_HOST}/prompt/generic/txt2img/" + async with aiohttp.ClientSession() as session: + headers = {"Authorization": + "Bearer " + IMAGE_GENERATOR_TOKEN + } + + result = await session.post( + url=url, + headers=headers, + json=args, + ) + + return await result.json() + + +def sdxl_image_dimension(value): + try: + value = int(value) + if value % 8 != 0: + raise argparse.ArgumentTypeError(f"{value} is not divisible by 8") + return value + except ValueError: + raise argparse.ArgumentTypeError(f"{value} is not an integer") + + +def denoise(value): + try: + value = float(value) + if value > 1.0 or value < 0: + raise argparse.ArgumentTypeError(f"{value} is not a valid denoise level") + return value + except ValueError: + raise argparse.ArgumentTypeError(f"{value} is not a float") + + +def create_prompting_parser(): + parser = CustomArgParser( + description="prompting - a command for interacting with image generators", + usage="""prompting --checkpoint=Illustrious-XL-v1.0.safetensors --prompt='fox lady' --negative_prompt='bad + and retarded' --width=1024 --height=1024 --seed=123456 --steps=25 --cfg_scale=7.5 --sampler_name='euler' + --scheduler=normal --denoise=1.0""", + ) + + parser.add_argument("--checkpoint", type=str, + choices=["Illustrious-XL-v1.0.safetensors", "PVCStyleModelMovable_illustriousxl10.safetensors"], + help="checkpoint to use", required=True) + + parser.add_argument("--prompt", type=str, + help="positive prompt, e.g. 'fox lady'", required=True) + + parser.add_argument("--negative_prompt", type=str, + help="negative prompt, e.g. 'bad, huge dick'. Could be none", default="", required=False) + + parser.add_argument("--width", type=sdxl_image_dimension, + help="width, default=1024", default=1024) + + parser.add_argument("--height", type=sdxl_image_dimension, + help="height, default=1024", default=1024) + + parser.add_argument("--seed", type=int, + help="seed for noize generation, default is random", default=random.randint(0, 99999999)) + + parser.add_argument("--steps", type=int, + help="denoise steps, default=20", default=20) + + parser.add_argument("--cfg_scale", "-cfg", type=float, + help="classifier-free guidance scale, default=8", default=8.0) + + parser.add_argument("--sampler_name", type=str, + choices=['euler', 'euler_cfg_pp', 'euler_ancestral', 'euler_ancestral_cfg_pp', 'heun', + 'heunpp2', 'dpm_2', 'dpm_2_ancestral', 'lms', 'dpm_fast', 'dpm_adaptive', + 'dpmpp_2s_ancestral', 'dpmpp_2s_ancestral_cfg_pp', 'dpmpp_sde', 'dpmpp_sde_gpu', + 'dpmpp_2m', 'dpmpp_2m_cfg_pp', 'dpmpp_2m_sde', 'dpmpp_2m_sde_gpu', 'dpmpp_3m_sde', + 'dpmpp_3m_sde_gpu', 'ddpm', 'lcm', 'ipndm', 'ipndm_v', 'deis', 'ddim', 'uni_pc', + 'uni_pc_bh2'], + help="sampler name, default is euler", default="euler") + + parser.add_argument("--scheduler", type=str, + choices=['normal', 'karras', 'exponential', 'sgm_uniform', 'simple', 'ddim_uniform', 'beta'], + help="noise scheduler, default is normal", default="normal") + + parser.add_argument("--denoise", type=denoise, + help="denoise level, default is 1, default=8", default=1.0) + + return parser + + +def prompting_parse_args(args: str, parser: CustomArgParser): + args = parser.parse_args(shlex.split(args)).__dict__ + return args + + +async def gen_image(update: Update, context: ContextTypes.DEFAULT_TYPE): + query = update.message.text.replace("prompting", "").replace("—", "--").replace("\n", "") + + parser = create_prompting_parser() + + try: + + args = prompting_parse_args(query, parser) + + except RuntimeError as e: + + if getattr(parser, 'help_message', None): + help_message = parser.help_message + + escaped_help_message = escape_markdown(help_message) + + return await context.bot.send_message(chat_id=update.effective_chat.id, + text="```\n" + escaped_help_message + "\n```", + parse_mode="MarkdownV2") + else: + + error_message = str(e) + + escaped_error_message = escape_markdown(error_message) + + return await context.bot.send_message(chat_id=update.effective_chat.id, + text="```\n" + escaped_error_message + "\n```", + parse_mode="MarkdownV2") + + result = await enqueue_prompt_t2i(args) + + if result["error"] == 0: + + prompt_id = result['data'] + + scheduled_prompts[prompt_id] = update.effective_chat.id + + print(scheduled_prompts) + + return await context.bot.send_message(chat_id=update.effective_chat.id, text=f"prompt_id: {prompt_id}") + + else: + return await context.bot.send_message(chat_id=update.effective_chat.id, + text="Some weird incomprehensible error occurred :<") +