add prompting

This commit is contained in:
Owl 2025-03-03 04:10:37 +07:00
parent a6d58ebf6b
commit 2bb1e2dccb
9 changed files with 277 additions and 48 deletions

View File

@ -1,8 +1,10 @@
FROM python:3.11 FROM python:3.11
COPY tgbot/requirements.txt .
RUN pip install -r requirements.txt
COPY tgbot tgbot COPY tgbot tgbot
COPY bot.py bot.py COPY bot.py bot.py
RUN pip install -r tgbot/requirements.txt
CMD python3 bot.py CMD python3 bot.py

13
bot.py
View File

@ -5,6 +5,7 @@ import logging
from tgbot.config import TOKEN from tgbot.config import TOKEN
from tgbot.shit.handlers import handle_xitter, handle_red_ebalo, handle_hentai, handle_cute_button from tgbot.shit.handlers import handle_xitter, handle_red_ebalo, handle_hentai, handle_cute_button
from tgbot.shit.prompting import gen_image
from tgbot.shit.render import render_text_on_image, image_to_string from tgbot.shit.render import render_text_on_image, image_to_string
from io import BytesIO from io import BytesIO
@ -20,9 +21,13 @@ async def start(update: Update, context: ContextTypes.DEFAULT_TYPE):
'AND I AM A PERSON OF COMPLETE AND TOTAL DELUSION!') 'AND I AM A PERSON OF COMPLETE AND TOTAL DELUSION!')
async def respond_with_picture(update: Update, context: ContextTypes.DEFAULT_TYPE): async def handle_text(update: Update, context: ContextTypes.DEFAULT_TYPE):
if not update.message.text.startswith("prompting"):
file = render_text_on_image("tgbot/assets/red_ebalo.png", update.message.text) file = render_text_on_image("tgbot/assets/red_ebalo.png", update.message.text)
await context.bot.send_sticker(chat_id=update.effective_chat.id, sticker=file) await context.bot.send_sticker(chat_id=update.effective_chat.id, sticker=file)
else:
await gen_image(update, context)
async def image2string(update: Update, context: ContextTypes.DEFAULT_TYPE): async def image2string(update: Update, context: ContextTypes.DEFAULT_TYPE):
@ -52,7 +57,6 @@ async def callback_query_handler(update: Update, context: CallbackContext):
else: else:
await query.answer() await query.answer()
async def handle_inline(update: Update, context: ContextTypes.DEFAULT_TYPE): async def handle_inline(update: Update, context: ContextTypes.DEFAULT_TYPE):
query = update.inline_query.query query = update.inline_query.query
if not query: if not query:
@ -72,17 +76,16 @@ async def handle_inline(update: Update, context: ContextTypes.DEFAULT_TYPE):
await handle_red_ebalo(update, context) await handle_red_ebalo(update, context)
def main(): def main():
application = ApplicationBuilder().token(TOKEN).build() application = ApplicationBuilder().token(TOKEN).build()
start_handler = CommandHandler('start', start) start_handler = CommandHandler('start', start)
picture_handler = MessageHandler(filters.TEXT & ~filters.COMMAND, respond_with_picture) text_handler = MessageHandler(filters.TEXT & ~filters.COMMAND, handle_text)
application.add_handler(start_handler) application.add_handler(start_handler)
application.add_handler(MessageHandler(filters.PHOTO, image2string)) application.add_handler(MessageHandler(filters.PHOTO, image2string))
application.add_handler(picture_handler) application.add_handler(text_handler)
application.add_handler(InlineQueryHandler(handle_inline)) application.add_handler(InlineQueryHandler(handle_inline))
application.add_handler(CallbackQueryHandler(callback_query_handler)) application.add_handler(CallbackQueryHandler(callback_query_handler))

View File

@ -18,6 +18,7 @@ services:
- 8.8.8.8 - 8.8.8.8
owlrandomshitbot: owlrandomshitbot:
restart: unless-stopped
image: owlrandomshitbot image: owlrandomshitbot
build: build:
context: . context: .
@ -28,6 +29,7 @@ services:
- vpn - vpn
booru-api: booru-api:
restart: unless-stopped
image: booru-api image: booru-api
build: build:
context: . context: .

View File

@ -8,3 +8,7 @@ STICKER_SHITHOLE = -1002471390283
DANBOORU_USERNAME = "owlrandomshitbot" DANBOORU_USERNAME = "owlrandomshitbot"
DANBOORU_API_KEY = os.getenv("DANBOORU_API_KEY") DANBOORU_API_KEY = os.getenv("DANBOORU_API_KEY")
IMAGE_GENERATOR_TOKEN: str = os.getenv("IMAGE_GENERATOR_TOKEN")
IMAGE_GENERATOR_API_HOST: str = os.getenv("IMAGE_GENERATOR_API_HOST")
CDN_URL: str = os.getenv("CDN_URL")

View File

@ -9,3 +9,5 @@ python-telegram-bot==21.6
sniffio==1.3.1 sniffio==1.3.1
requests==2.32.3 requests==2.32.3
Pybooru==4.2.2 Pybooru==4.2.2
aiohttp==3.11.13
websockets==15.0

View File

@ -0,0 +1,35 @@
import argparse
import io
class CustomArgParser(argparse.ArgumentParser):
def __init__(
self,
*args, **kwargs,
):
super().__init__(*args, **kwargs)
self.help_message = None
def print_help(self, file=None):
# Store the help message in a buffer instead of printing
if file is None:
help_buffer = io.StringIO()
super().print_help(file=help_buffer)
self.help_message = help_buffer.getvalue()
help_buffer.close()
else:
super().print_help(file=file)
def parse_args(self, args=None, namespace=None):
# Check for --help manually to avoid stdout output
if '--help' in args:
self.print_help()
raise RuntimeError("Help requested")
return super().parse_args(args, namespace)
def exit(self, status=0, message=None):
raise RuntimeError(message)
def error(self, message):
raise RuntimeError(f"Error: {message}")

View File

@ -7,7 +7,7 @@ from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup, InlineQ
from telegram.ext import ContextTypes from telegram.ext import ContextTypes
from tgbot.config import STICKER_SHITHOLE, BOORU_API_URL, INLINE_QUERY_CACHE_SECONDS, DANBOORU_API_KEY from tgbot.config import STICKER_SHITHOLE, BOORU_API_URL, INLINE_QUERY_CACHE_SECONDS, DANBOORU_API_KEY
from tgbot.shit.hentai_argparser import parse_args, create_parser from tgbot.shit.hentai_argparser import hentai_parse_args, create_hentai_parser
from tgbot.shit.render import render_text_on_image from tgbot.shit.render import render_text_on_image
import re import re
@ -74,11 +74,11 @@ async def handle_cute_button(update: Update, context: ContextTypes.DEFAULT_TYPE)
async def handle_hentai(update: Update, context: ContextTypes.DEFAULT_TYPE): async def handle_hentai(update: Update, context: ContextTypes.DEFAULT_TYPE):
query = update.inline_query.query.replace("hentai", "").replace("", "--") query = update.inline_query.query.replace("hentai", "").replace("", "--")
parser = create_parser() parser = create_hentai_parser()
try: try:
args = parse_args(query, parser) args = hentai_parse_args(query, parser)
except RuntimeError as e: except RuntimeError as e:

View File

@ -1,41 +1,9 @@
import argparse
import io
import shlex import shlex
from tgbot.shit.custom_argparser import CustomArgParser
class CustomArgParser(argparse.ArgumentParser):
def __init__(
self,
*args, **kwargs,
):
super().__init__(*args, **kwargs)
self.help_message = None
def print_help(self, file=None):
# Store the help message in a buffer instead of printing
if file is None:
help_buffer = io.StringIO()
super().print_help(file=help_buffer)
self.help_message = help_buffer.getvalue()
help_buffer.close()
else:
super().print_help(file=file)
def parse_args(self, args=None, namespace=None):
# Check for --help manually to avoid stdout output
if '--help' in args:
self.print_help()
raise RuntimeError("Help requested")
return super().parse_args(args, namespace)
def exit(self, status=0, message=None):
raise RuntimeError(message)
def error(self, message):
raise RuntimeError(f"Error: {message}")
def create_parser(): def create_hentai_parser():
parser = CustomArgParser( parser = CustomArgParser(
description="A command to search random stuff on various image boorus. Feel free to use it irresponsibly!", description="A command to search random stuff on various image boorus. Feel free to use it irresponsibly!",
usage="@owlrandomshitbot hentai BOORU_NAME --tags TAG1,TAG2,TAG3 -l LIMIT --random -s", usage="@owlrandomshitbot hentai BOORU_NAME --tags TAG1,TAG2,TAG3 -l LIMIT --random -s",
@ -68,7 +36,7 @@ def create_parser():
return parser return parser
def parse_args(args: str, parser: CustomArgParser): def hentai_parse_args(args: str, parser: CustomArgParser):
args = parser.parse_args(shlex.split(args)).__dict__ args = parser.parse_args(shlex.split(args)).__dict__
args["tags"] = args["tags"].split(",") args["tags"] = args["tags"].split(",")

213
tgbot/shit/prompting.py Normal file
View File

@ -0,0 +1,213 @@
import argparse
import asyncio
import json
import random
import shlex
import threading
import aiohttp
import telegram
import websockets
from telegram import Update
from telegram.ext import ContextTypes
from tgbot.config import TOKEN, IMAGE_GENERATOR_API_HOST, IMAGE_GENERATOR_TOKEN, CDN_URL
from tgbot.shit.custom_argparser import CustomArgParser
from tgbot.shit.handlers import escape_markdown
def filename2cdn(filename: str) -> str | None:
if not filename:
return None
return CDN_URL + filename
scheduled_prompts: dict = {}
async def ws_task():
ws_url = f"wss://{IMAGE_GENERATOR_API_HOST}/ws?token={IMAGE_GENERATOR_TOKEN}"
retry_delay = 1 # Initial delay (1s), will increase with backoff
while True:
try:
print(f"Connecting to WebSocket...")
async with websockets.connect(ws_url) as ws:
retry_delay = 1 # Reset delay on successful connection
while True:
message = await ws.recv()
message_dict = json.loads(message)
prompt_id = message_dict["id"]
image_url = filename2cdn(message_dict["image_filename"])
chat_id = scheduled_prompts.get(prompt_id)
print(f"Received message: {message}")
bot = telegram.Bot(TOKEN)
async with bot:
await bot.send_photo(chat_id=chat_id, photo=image_url, caption=str(message))
del scheduled_prompts[prompt_id]
except websockets.ConnectionClosed as e:
print(f"WebSocket connection closed: {e}. Reconnecting in {retry_delay}s...")
except Exception as e:
print(f"WebSocket error: {e}. Reconnecting in {retry_delay}s...")
await asyncio.sleep(retry_delay)
retry_delay = min(retry_delay * 2, 60) # Exponential backoff, max 60s
def start_ws_task():
"""Run WebSocket listener in a separate thread to avoid blocking FastAPI & Telegram bot."""
print("starting the websocket task")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(ws_task())
thread = threading.Thread(target=start_ws_task, daemon=True)
thread.start()
async def enqueue_prompt_t2i(args: dict):
url = f"https://{IMAGE_GENERATOR_API_HOST}/prompt/generic/txt2img/"
async with aiohttp.ClientSession() as session:
headers = {"Authorization":
"Bearer " + IMAGE_GENERATOR_TOKEN
}
result = await session.post(
url=url,
headers=headers,
json=args,
)
return await result.json()
def sdxl_image_dimension(value):
try:
value = int(value)
if value % 8 != 0:
raise argparse.ArgumentTypeError(f"{value} is not divisible by 8")
return value
except ValueError:
raise argparse.ArgumentTypeError(f"{value} is not an integer")
def denoise(value):
try:
value = float(value)
if value > 1.0 or value < 0:
raise argparse.ArgumentTypeError(f"{value} is not a valid denoise level")
return value
except ValueError:
raise argparse.ArgumentTypeError(f"{value} is not a float")
def create_prompting_parser():
parser = CustomArgParser(
description="prompting - a command for interacting with image generators",
usage="""prompting --checkpoint=Illustrious-XL-v1.0.safetensors --prompt='fox lady' --negative_prompt='bad
and retarded' --width=1024 --height=1024 --seed=123456 --steps=25 --cfg_scale=7.5 --sampler_name='euler'
--scheduler=normal --denoise=1.0""",
)
parser.add_argument("--checkpoint", type=str,
choices=["Illustrious-XL-v1.0.safetensors", "PVCStyleModelMovable_illustriousxl10.safetensors"],
help="checkpoint to use", required=True)
parser.add_argument("--prompt", type=str,
help="positive prompt, e.g. 'fox lady'", required=True)
parser.add_argument("--negative_prompt", type=str,
help="negative prompt, e.g. 'bad, huge dick'. Could be none", default="", required=False)
parser.add_argument("--width", type=sdxl_image_dimension,
help="width, default=1024", default=1024)
parser.add_argument("--height", type=sdxl_image_dimension,
help="height, default=1024", default=1024)
parser.add_argument("--seed", type=int,
help="seed for noize generation, default is random", default=random.randint(0, 99999999))
parser.add_argument("--steps", type=int,
help="denoise steps, default=20", default=20)
parser.add_argument("--cfg_scale", "-cfg", type=float,
help="classifier-free guidance scale, default=8", default=8.0)
parser.add_argument("--sampler_name", type=str,
choices=['euler', 'euler_cfg_pp', 'euler_ancestral', 'euler_ancestral_cfg_pp', 'heun',
'heunpp2', 'dpm_2', 'dpm_2_ancestral', 'lms', 'dpm_fast', 'dpm_adaptive',
'dpmpp_2s_ancestral', 'dpmpp_2s_ancestral_cfg_pp', 'dpmpp_sde', 'dpmpp_sde_gpu',
'dpmpp_2m', 'dpmpp_2m_cfg_pp', 'dpmpp_2m_sde', 'dpmpp_2m_sde_gpu', 'dpmpp_3m_sde',
'dpmpp_3m_sde_gpu', 'ddpm', 'lcm', 'ipndm', 'ipndm_v', 'deis', 'ddim', 'uni_pc',
'uni_pc_bh2'],
help="sampler name, default is euler", default="euler")
parser.add_argument("--scheduler", type=str,
choices=['normal', 'karras', 'exponential', 'sgm_uniform', 'simple', 'ddim_uniform', 'beta'],
help="noise scheduler, default is normal", default="normal")
parser.add_argument("--denoise", type=denoise,
help="denoise level, default is 1, default=8", default=1.0)
return parser
def prompting_parse_args(args: str, parser: CustomArgParser):
args = parser.parse_args(shlex.split(args)).__dict__
return args
async def gen_image(update: Update, context: ContextTypes.DEFAULT_TYPE):
query = update.message.text.replace("prompting", "").replace("", "--").replace("\n", "")
parser = create_prompting_parser()
try:
args = prompting_parse_args(query, parser)
except RuntimeError as e:
if getattr(parser, 'help_message', None):
help_message = parser.help_message
escaped_help_message = escape_markdown(help_message)
return await context.bot.send_message(chat_id=update.effective_chat.id,
text="```\n" + escaped_help_message + "\n```",
parse_mode="MarkdownV2")
else:
error_message = str(e)
escaped_error_message = escape_markdown(error_message)
return await context.bot.send_message(chat_id=update.effective_chat.id,
text="```\n" + escaped_error_message + "\n```",
parse_mode="MarkdownV2")
result = await enqueue_prompt_t2i(args)
if result["error"] == 0:
prompt_id = result['data']
scheduled_prompts[prompt_id] = update.effective_chat.id
print(scheduled_prompts)
return await context.bot.send_message(chat_id=update.effective_chat.id, text=f"prompt_id: {prompt_id}")
else:
return await context.bot.send_message(chat_id=update.effective_chat.id,
text="Some weird incomprehensible error occurred :<")