add prompting
This commit is contained in:
parent
a6d58ebf6b
commit
2bb1e2dccb
@ -1,8 +1,10 @@
|
|||||||
FROM python:3.11
|
FROM python:3.11
|
||||||
|
|
||||||
|
COPY tgbot/requirements.txt .
|
||||||
|
|
||||||
|
RUN pip install -r requirements.txt
|
||||||
|
|
||||||
COPY tgbot tgbot
|
COPY tgbot tgbot
|
||||||
COPY bot.py bot.py
|
COPY bot.py bot.py
|
||||||
|
|
||||||
RUN pip install -r tgbot/requirements.txt
|
|
||||||
|
|
||||||
CMD python3 bot.py
|
CMD python3 bot.py
|
||||||
|
17
bot.py
17
bot.py
@ -5,6 +5,7 @@ import logging
|
|||||||
|
|
||||||
from tgbot.config import TOKEN
|
from tgbot.config import TOKEN
|
||||||
from tgbot.shit.handlers import handle_xitter, handle_red_ebalo, handle_hentai, handle_cute_button
|
from tgbot.shit.handlers import handle_xitter, handle_red_ebalo, handle_hentai, handle_cute_button
|
||||||
|
from tgbot.shit.prompting import gen_image
|
||||||
from tgbot.shit.render import render_text_on_image, image_to_string
|
from tgbot.shit.render import render_text_on_image, image_to_string
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
@ -20,9 +21,13 @@ async def start(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
|||||||
'AND I AM A PERSON OF COMPLETE AND TOTAL DELUSION!')
|
'AND I AM A PERSON OF COMPLETE AND TOTAL DELUSION!')
|
||||||
|
|
||||||
|
|
||||||
async def respond_with_picture(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
async def handle_text(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||||
file = render_text_on_image("tgbot/assets/red_ebalo.png", update.message.text)
|
|
||||||
await context.bot.send_sticker(chat_id=update.effective_chat.id, sticker=file)
|
if not update.message.text.startswith("prompting"):
|
||||||
|
file = render_text_on_image("tgbot/assets/red_ebalo.png", update.message.text)
|
||||||
|
await context.bot.send_sticker(chat_id=update.effective_chat.id, sticker=file)
|
||||||
|
else:
|
||||||
|
await gen_image(update, context)
|
||||||
|
|
||||||
|
|
||||||
async def image2string(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
async def image2string(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||||
@ -52,7 +57,6 @@ async def callback_query_handler(update: Update, context: CallbackContext):
|
|||||||
else:
|
else:
|
||||||
await query.answer()
|
await query.answer()
|
||||||
|
|
||||||
|
|
||||||
async def handle_inline(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
async def handle_inline(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||||
query = update.inline_query.query
|
query = update.inline_query.query
|
||||||
if not query:
|
if not query:
|
||||||
@ -72,17 +76,16 @@ async def handle_inline(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
|||||||
|
|
||||||
await handle_red_ebalo(update, context)
|
await handle_red_ebalo(update, context)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
application = ApplicationBuilder().token(TOKEN).build()
|
application = ApplicationBuilder().token(TOKEN).build()
|
||||||
|
|
||||||
start_handler = CommandHandler('start', start)
|
start_handler = CommandHandler('start', start)
|
||||||
picture_handler = MessageHandler(filters.TEXT & ~filters.COMMAND, respond_with_picture)
|
text_handler = MessageHandler(filters.TEXT & ~filters.COMMAND, handle_text)
|
||||||
|
|
||||||
application.add_handler(start_handler)
|
application.add_handler(start_handler)
|
||||||
|
|
||||||
application.add_handler(MessageHandler(filters.PHOTO, image2string))
|
application.add_handler(MessageHandler(filters.PHOTO, image2string))
|
||||||
application.add_handler(picture_handler)
|
application.add_handler(text_handler)
|
||||||
application.add_handler(InlineQueryHandler(handle_inline))
|
application.add_handler(InlineQueryHandler(handle_inline))
|
||||||
application.add_handler(CallbackQueryHandler(callback_query_handler))
|
application.add_handler(CallbackQueryHandler(callback_query_handler))
|
||||||
|
|
||||||
|
@ -18,6 +18,7 @@ services:
|
|||||||
- 8.8.8.8
|
- 8.8.8.8
|
||||||
|
|
||||||
owlrandomshitbot:
|
owlrandomshitbot:
|
||||||
|
restart: unless-stopped
|
||||||
image: owlrandomshitbot
|
image: owlrandomshitbot
|
||||||
build:
|
build:
|
||||||
context: .
|
context: .
|
||||||
@ -28,6 +29,7 @@ services:
|
|||||||
- vpn
|
- vpn
|
||||||
|
|
||||||
booru-api:
|
booru-api:
|
||||||
|
restart: unless-stopped
|
||||||
image: booru-api
|
image: booru-api
|
||||||
build:
|
build:
|
||||||
context: .
|
context: .
|
||||||
|
@ -8,3 +8,7 @@ STICKER_SHITHOLE = -1002471390283
|
|||||||
|
|
||||||
DANBOORU_USERNAME = "owlrandomshitbot"
|
DANBOORU_USERNAME = "owlrandomshitbot"
|
||||||
DANBOORU_API_KEY = os.getenv("DANBOORU_API_KEY")
|
DANBOORU_API_KEY = os.getenv("DANBOORU_API_KEY")
|
||||||
|
|
||||||
|
IMAGE_GENERATOR_TOKEN: str = os.getenv("IMAGE_GENERATOR_TOKEN")
|
||||||
|
IMAGE_GENERATOR_API_HOST: str = os.getenv("IMAGE_GENERATOR_API_HOST")
|
||||||
|
CDN_URL: str = os.getenv("CDN_URL")
|
||||||
|
@ -8,4 +8,6 @@ pillow==11.0.0
|
|||||||
python-telegram-bot==21.6
|
python-telegram-bot==21.6
|
||||||
sniffio==1.3.1
|
sniffio==1.3.1
|
||||||
requests==2.32.3
|
requests==2.32.3
|
||||||
Pybooru==4.2.2
|
Pybooru==4.2.2
|
||||||
|
aiohttp==3.11.13
|
||||||
|
websockets==15.0
|
35
tgbot/shit/custom_argparser.py
Normal file
35
tgbot/shit/custom_argparser.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
import argparse
|
||||||
|
import io
|
||||||
|
|
||||||
|
|
||||||
|
class CustomArgParser(argparse.ArgumentParser):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*args, **kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.help_message = None
|
||||||
|
|
||||||
|
def print_help(self, file=None):
|
||||||
|
# Store the help message in a buffer instead of printing
|
||||||
|
if file is None:
|
||||||
|
help_buffer = io.StringIO()
|
||||||
|
super().print_help(file=help_buffer)
|
||||||
|
self.help_message = help_buffer.getvalue()
|
||||||
|
help_buffer.close()
|
||||||
|
else:
|
||||||
|
super().print_help(file=file)
|
||||||
|
|
||||||
|
def parse_args(self, args=None, namespace=None):
|
||||||
|
# Check for --help manually to avoid stdout output
|
||||||
|
if '--help' in args:
|
||||||
|
self.print_help()
|
||||||
|
raise RuntimeError("Help requested")
|
||||||
|
return super().parse_args(args, namespace)
|
||||||
|
|
||||||
|
def exit(self, status=0, message=None):
|
||||||
|
raise RuntimeError(message)
|
||||||
|
|
||||||
|
def error(self, message):
|
||||||
|
raise RuntimeError(f"Error: {message}")
|
||||||
|
|
@ -7,7 +7,7 @@ from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup, InlineQ
|
|||||||
from telegram.ext import ContextTypes
|
from telegram.ext import ContextTypes
|
||||||
|
|
||||||
from tgbot.config import STICKER_SHITHOLE, BOORU_API_URL, INLINE_QUERY_CACHE_SECONDS, DANBOORU_API_KEY
|
from tgbot.config import STICKER_SHITHOLE, BOORU_API_URL, INLINE_QUERY_CACHE_SECONDS, DANBOORU_API_KEY
|
||||||
from tgbot.shit.hentai_argparser import parse_args, create_parser
|
from tgbot.shit.hentai_argparser import hentai_parse_args, create_hentai_parser
|
||||||
from tgbot.shit.render import render_text_on_image
|
from tgbot.shit.render import render_text_on_image
|
||||||
import re
|
import re
|
||||||
|
|
||||||
@ -74,11 +74,11 @@ async def handle_cute_button(update: Update, context: ContextTypes.DEFAULT_TYPE)
|
|||||||
async def handle_hentai(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
async def handle_hentai(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||||
query = update.inline_query.query.replace("hentai", "").replace("—", "--")
|
query = update.inline_query.query.replace("hentai", "").replace("—", "--")
|
||||||
|
|
||||||
parser = create_parser()
|
parser = create_hentai_parser()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
args = parse_args(query, parser)
|
args = hentai_parse_args(query, parser)
|
||||||
|
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
|
|
||||||
|
@ -1,41 +1,9 @@
|
|||||||
import argparse
|
|
||||||
import io
|
|
||||||
import shlex
|
import shlex
|
||||||
|
|
||||||
|
from tgbot.shit.custom_argparser import CustomArgParser
|
||||||
class CustomArgParser(argparse.ArgumentParser):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*args, **kwargs,
|
|
||||||
):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.help_message = None
|
|
||||||
|
|
||||||
def print_help(self, file=None):
|
|
||||||
# Store the help message in a buffer instead of printing
|
|
||||||
if file is None:
|
|
||||||
help_buffer = io.StringIO()
|
|
||||||
super().print_help(file=help_buffer)
|
|
||||||
self.help_message = help_buffer.getvalue()
|
|
||||||
help_buffer.close()
|
|
||||||
else:
|
|
||||||
super().print_help(file=file)
|
|
||||||
|
|
||||||
def parse_args(self, args=None, namespace=None):
|
|
||||||
# Check for --help manually to avoid stdout output
|
|
||||||
if '--help' in args:
|
|
||||||
self.print_help()
|
|
||||||
raise RuntimeError("Help requested")
|
|
||||||
return super().parse_args(args, namespace)
|
|
||||||
|
|
||||||
def exit(self, status=0, message=None):
|
|
||||||
raise RuntimeError(message)
|
|
||||||
|
|
||||||
def error(self, message):
|
|
||||||
raise RuntimeError(f"Error: {message}")
|
|
||||||
|
|
||||||
|
|
||||||
def create_parser():
|
def create_hentai_parser():
|
||||||
parser = CustomArgParser(
|
parser = CustomArgParser(
|
||||||
description="A command to search random stuff on various image boorus. Feel free to use it irresponsibly!",
|
description="A command to search random stuff on various image boorus. Feel free to use it irresponsibly!",
|
||||||
usage="@owlrandomshitbot hentai BOORU_NAME --tags TAG1,TAG2,TAG3 -l LIMIT --random -s",
|
usage="@owlrandomshitbot hentai BOORU_NAME --tags TAG1,TAG2,TAG3 -l LIMIT --random -s",
|
||||||
@ -68,7 +36,7 @@ def create_parser():
|
|||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def parse_args(args: str, parser: CustomArgParser):
|
def hentai_parse_args(args: str, parser: CustomArgParser):
|
||||||
args = parser.parse_args(shlex.split(args)).__dict__
|
args = parser.parse_args(shlex.split(args)).__dict__
|
||||||
|
|
||||||
args["tags"] = args["tags"].split(",")
|
args["tags"] = args["tags"].split(",")
|
||||||
|
213
tgbot/shit/prompting.py
Normal file
213
tgbot/shit/prompting.py
Normal file
@ -0,0 +1,213 @@
|
|||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
import shlex
|
||||||
|
import threading
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import telegram
|
||||||
|
import websockets
|
||||||
|
from telegram import Update
|
||||||
|
from telegram.ext import ContextTypes
|
||||||
|
|
||||||
|
from tgbot.config import TOKEN, IMAGE_GENERATOR_API_HOST, IMAGE_GENERATOR_TOKEN, CDN_URL
|
||||||
|
from tgbot.shit.custom_argparser import CustomArgParser
|
||||||
|
from tgbot.shit.handlers import escape_markdown
|
||||||
|
|
||||||
|
|
||||||
|
def filename2cdn(filename: str) -> str | None:
|
||||||
|
if not filename:
|
||||||
|
return None
|
||||||
|
return CDN_URL + filename
|
||||||
|
|
||||||
|
|
||||||
|
scheduled_prompts: dict = {}
|
||||||
|
|
||||||
|
|
||||||
|
async def ws_task():
|
||||||
|
ws_url = f"wss://{IMAGE_GENERATOR_API_HOST}/ws?token={IMAGE_GENERATOR_TOKEN}"
|
||||||
|
retry_delay = 1 # Initial delay (1s), will increase with backoff
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
print(f"Connecting to WebSocket...")
|
||||||
|
async with websockets.connect(ws_url) as ws:
|
||||||
|
retry_delay = 1 # Reset delay on successful connection
|
||||||
|
|
||||||
|
while True:
|
||||||
|
message = await ws.recv()
|
||||||
|
message_dict = json.loads(message)
|
||||||
|
|
||||||
|
prompt_id = message_dict["id"]
|
||||||
|
image_url = filename2cdn(message_dict["image_filename"])
|
||||||
|
chat_id = scheduled_prompts.get(prompt_id)
|
||||||
|
|
||||||
|
print(f"Received message: {message}")
|
||||||
|
|
||||||
|
bot = telegram.Bot(TOKEN)
|
||||||
|
|
||||||
|
async with bot:
|
||||||
|
await bot.send_photo(chat_id=chat_id, photo=image_url, caption=str(message))
|
||||||
|
|
||||||
|
del scheduled_prompts[prompt_id]
|
||||||
|
|
||||||
|
except websockets.ConnectionClosed as e:
|
||||||
|
print(f"WebSocket connection closed: {e}. Reconnecting in {retry_delay}s...")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"WebSocket error: {e}. Reconnecting in {retry_delay}s...")
|
||||||
|
|
||||||
|
await asyncio.sleep(retry_delay)
|
||||||
|
retry_delay = min(retry_delay * 2, 60) # Exponential backoff, max 60s
|
||||||
|
|
||||||
|
|
||||||
|
def start_ws_task():
|
||||||
|
"""Run WebSocket listener in a separate thread to avoid blocking FastAPI & Telegram bot."""
|
||||||
|
print("starting the websocket task")
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
loop.run_until_complete(ws_task())
|
||||||
|
|
||||||
|
|
||||||
|
thread = threading.Thread(target=start_ws_task, daemon=True)
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
|
||||||
|
async def enqueue_prompt_t2i(args: dict):
|
||||||
|
url = f"https://{IMAGE_GENERATOR_API_HOST}/prompt/generic/txt2img/"
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
headers = {"Authorization":
|
||||||
|
"Bearer " + IMAGE_GENERATOR_TOKEN
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await session.post(
|
||||||
|
url=url,
|
||||||
|
headers=headers,
|
||||||
|
json=args,
|
||||||
|
)
|
||||||
|
|
||||||
|
return await result.json()
|
||||||
|
|
||||||
|
|
||||||
|
def sdxl_image_dimension(value):
|
||||||
|
try:
|
||||||
|
value = int(value)
|
||||||
|
if value % 8 != 0:
|
||||||
|
raise argparse.ArgumentTypeError(f"{value} is not divisible by 8")
|
||||||
|
return value
|
||||||
|
except ValueError:
|
||||||
|
raise argparse.ArgumentTypeError(f"{value} is not an integer")
|
||||||
|
|
||||||
|
|
||||||
|
def denoise(value):
|
||||||
|
try:
|
||||||
|
value = float(value)
|
||||||
|
if value > 1.0 or value < 0:
|
||||||
|
raise argparse.ArgumentTypeError(f"{value} is not a valid denoise level")
|
||||||
|
return value
|
||||||
|
except ValueError:
|
||||||
|
raise argparse.ArgumentTypeError(f"{value} is not a float")
|
||||||
|
|
||||||
|
|
||||||
|
def create_prompting_parser():
|
||||||
|
parser = CustomArgParser(
|
||||||
|
description="prompting - a command for interacting with image generators",
|
||||||
|
usage="""prompting --checkpoint=Illustrious-XL-v1.0.safetensors --prompt='fox lady' --negative_prompt='bad
|
||||||
|
and retarded' --width=1024 --height=1024 --seed=123456 --steps=25 --cfg_scale=7.5 --sampler_name='euler'
|
||||||
|
--scheduler=normal --denoise=1.0""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("--checkpoint", type=str,
|
||||||
|
choices=["Illustrious-XL-v1.0.safetensors", "PVCStyleModelMovable_illustriousxl10.safetensors"],
|
||||||
|
help="checkpoint to use", required=True)
|
||||||
|
|
||||||
|
parser.add_argument("--prompt", type=str,
|
||||||
|
help="positive prompt, e.g. 'fox lady'", required=True)
|
||||||
|
|
||||||
|
parser.add_argument("--negative_prompt", type=str,
|
||||||
|
help="negative prompt, e.g. 'bad, huge dick'. Could be none", default="", required=False)
|
||||||
|
|
||||||
|
parser.add_argument("--width", type=sdxl_image_dimension,
|
||||||
|
help="width, default=1024", default=1024)
|
||||||
|
|
||||||
|
parser.add_argument("--height", type=sdxl_image_dimension,
|
||||||
|
help="height, default=1024", default=1024)
|
||||||
|
|
||||||
|
parser.add_argument("--seed", type=int,
|
||||||
|
help="seed for noize generation, default is random", default=random.randint(0, 99999999))
|
||||||
|
|
||||||
|
parser.add_argument("--steps", type=int,
|
||||||
|
help="denoise steps, default=20", default=20)
|
||||||
|
|
||||||
|
parser.add_argument("--cfg_scale", "-cfg", type=float,
|
||||||
|
help="classifier-free guidance scale, default=8", default=8.0)
|
||||||
|
|
||||||
|
parser.add_argument("--sampler_name", type=str,
|
||||||
|
choices=['euler', 'euler_cfg_pp', 'euler_ancestral', 'euler_ancestral_cfg_pp', 'heun',
|
||||||
|
'heunpp2', 'dpm_2', 'dpm_2_ancestral', 'lms', 'dpm_fast', 'dpm_adaptive',
|
||||||
|
'dpmpp_2s_ancestral', 'dpmpp_2s_ancestral_cfg_pp', 'dpmpp_sde', 'dpmpp_sde_gpu',
|
||||||
|
'dpmpp_2m', 'dpmpp_2m_cfg_pp', 'dpmpp_2m_sde', 'dpmpp_2m_sde_gpu', 'dpmpp_3m_sde',
|
||||||
|
'dpmpp_3m_sde_gpu', 'ddpm', 'lcm', 'ipndm', 'ipndm_v', 'deis', 'ddim', 'uni_pc',
|
||||||
|
'uni_pc_bh2'],
|
||||||
|
help="sampler name, default is euler", default="euler")
|
||||||
|
|
||||||
|
parser.add_argument("--scheduler", type=str,
|
||||||
|
choices=['normal', 'karras', 'exponential', 'sgm_uniform', 'simple', 'ddim_uniform', 'beta'],
|
||||||
|
help="noise scheduler, default is normal", default="normal")
|
||||||
|
|
||||||
|
parser.add_argument("--denoise", type=denoise,
|
||||||
|
help="denoise level, default is 1, default=8", default=1.0)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def prompting_parse_args(args: str, parser: CustomArgParser):
|
||||||
|
args = parser.parse_args(shlex.split(args)).__dict__
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
async def gen_image(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||||
|
query = update.message.text.replace("prompting", "").replace("—", "--").replace("\n", "")
|
||||||
|
|
||||||
|
parser = create_prompting_parser()
|
||||||
|
|
||||||
|
try:
|
||||||
|
|
||||||
|
args = prompting_parse_args(query, parser)
|
||||||
|
|
||||||
|
except RuntimeError as e:
|
||||||
|
|
||||||
|
if getattr(parser, 'help_message', None):
|
||||||
|
help_message = parser.help_message
|
||||||
|
|
||||||
|
escaped_help_message = escape_markdown(help_message)
|
||||||
|
|
||||||
|
return await context.bot.send_message(chat_id=update.effective_chat.id,
|
||||||
|
text="```\n" + escaped_help_message + "\n```",
|
||||||
|
parse_mode="MarkdownV2")
|
||||||
|
else:
|
||||||
|
|
||||||
|
error_message = str(e)
|
||||||
|
|
||||||
|
escaped_error_message = escape_markdown(error_message)
|
||||||
|
|
||||||
|
return await context.bot.send_message(chat_id=update.effective_chat.id,
|
||||||
|
text="```\n" + escaped_error_message + "\n```",
|
||||||
|
parse_mode="MarkdownV2")
|
||||||
|
|
||||||
|
result = await enqueue_prompt_t2i(args)
|
||||||
|
|
||||||
|
if result["error"] == 0:
|
||||||
|
|
||||||
|
prompt_id = result['data']
|
||||||
|
|
||||||
|
scheduled_prompts[prompt_id] = update.effective_chat.id
|
||||||
|
|
||||||
|
print(scheduled_prompts)
|
||||||
|
|
||||||
|
return await context.bot.send_message(chat_id=update.effective_chat.id, text=f"prompt_id: {prompt_id}")
|
||||||
|
|
||||||
|
else:
|
||||||
|
return await context.bot.send_message(chat_id=update.effective_chat.id,
|
||||||
|
text="Some weird incomprehensible error occurred :<")
|
||||||
|
|
Loading…
x
Reference in New Issue
Block a user