|
|
|
|
|
import gradio as gr |
|
|
from typing import List, Tuple, Dict, Any |
|
|
|
|
|
# Hugging Face model id for the conversational model used by this app.
MODEL_ID = "facebook/blenderbot-400M-distill"

# Lazily-populated module state: ensure_model_loaded() fills both on first
# use so the app can start without downloading the model up front.
model = None

tokenizer = None
|
|
|
|
|
def ensure_model_loaded():
    """Lazily load the Blenderbot tokenizer and model on first use.

    Populates the module-level ``model`` and ``tokenizer`` globals; calls
    after the first successful load are no-ops. The transformers import is
    deferred so importing this module stays cheap.
    """
    global model, tokenizer
    if model is not None and tokenizer is not None:
        return  # already loaded
    from transformers import (
        BlenderbotForConditionalGeneration,
        BlenderbotTokenizer,
    )
    tokenizer = BlenderbotTokenizer.from_pretrained(MODEL_ID)
    model = BlenderbotForConditionalGeneration.from_pretrained(MODEL_ID)
|
|
|
|
|
def generate_reply(context: str) -> str:
    """Generate one bot reply for the given conversation *context*.

    Loads the model on demand, runs a single `generate` call, and decodes
    the first returned sequence with special tokens stripped.
    """
    ensure_model_loaded()
    encoded = tokenizer(context, return_tensors="pt")
    generated = model.generate(
        **encoded,
        max_length=120,
        no_repeat_ngram_size=2,  # discourage verbatim repetition
    )
    return tokenizer.decode(generated[0], skip_special_tokens=True)
|
|
|
|
|
def history_to_context_from_tuples(history: List[Tuple[str, str]]) -> str:
    """Flatten tuple-style chat history into a "User:/Bot:" transcript.

    Args:
        history: Pairs of ``(user_message, bot_message)``.

    Returns:
        One string with each exchange rendered as ``User: ...\\nBot: ...\\n``;
        empty string for empty history.
    """
    # str.join is linear; the original repeated `+=` on a string, which is
    # quadratic in the worst case.
    return "".join(f"User: {u}\nBot: {b}\n" for u, b in history)
|
|
|
|
|
def history_to_context_from_messages(history: List[Dict[str, str]]) -> str:
    """Flatten messages-style chat history into a "User:/Bot:" transcript.

    Args:
        history: Dicts with ``"role"`` and ``"content"`` keys (Gradio
            "messages" format). Entries with a missing or empty role or
            content are skipped.

    Returns:
        One string where roles starting with "user" (case-insensitive)
        become ``User: ...`` lines and every other role becomes
        ``Bot: ...`` lines.
    """
    parts = []
    for msg in history:
        role = msg.get("role", "")
        content = msg.get("content", "")
        if not (role and content):
            continue  # skip malformed/empty entries
        speaker = "User" if role.lower().startswith("user") else "Bot"
        parts.append(f"{speaker}: {content}\n")
    # join once instead of quadratic string `+=` in the loop
    return "".join(parts)
|
|
|
|
|
def chat(state: List[Any], message: str):
    """Handle one chat turn: build context, generate a reply, update state.

    Supports both Gradio history formats:
      - tuples:   [("hi", "hello"), ...]
      - messages: [{"role": "user", "content": "hi"},
                   {"role": "assistant", "content": "hello"}, ...]

    Args:
        state: The Gradio chatbot history (passed back as-is by Gradio).
        message: The new user message.

    Returns:
        ``(state, state)`` — the updated history, duplicated because the
        caller wires it to both the gr.State and the gr.Chatbot outputs.
    """
    # Decide the history format ONCE. Bug fix: the original defaulted an
    # EMPTY history to the tuple format, so the very first turn appended a
    # (user, bot) tuple — which gr.Chatbot(type="messages") rejects. An
    # empty history now defaults to the messages format to match the UI.
    if state and isinstance(state[0], (list, tuple)):
        uses_messages = False
        context = history_to_context_from_tuples(state)
    elif state and isinstance(state[0], dict):
        uses_messages = True
        context = history_to_context_from_messages(state)
    else:
        uses_messages = True
        context = ""

    # Append the new user turn and prompt the model for the bot turn.
    context += f"User: {message}\nBot:"
    reply = generate_reply(context)

    if uses_messages:
        state.append({"role": "user", "content": message})
        state.append({"role": "assistant", "content": reply})
    else:
        state.append((message, reply))
    return state, state
|
|
|
|
|
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("## π€ Anuj's Chatbot β stable (messages format)")

    # Chat display using the "messages" format (role/content dicts).
    chatbot = gr.Chatbot(type="messages")
    # Server-side conversation history; passed to chat() on every turn.
    state = gr.State([])

    with gr.Row():
        msg = gr.Textbox(show_label=False, placeholder="Type a message and press Enter...")

    # On Enter: chat(state, msg) -> (state, chatbot).
    msg.submit(chat, [state, msg], [state, chatbot])


if __name__ == "__main__":
    # Bind to all interfaces on port 7860 (container/Spaces-friendly).
    demo.launch(server_name="0.0.0.0", server_port=7860)
|
|
|