Aftrr committed on
Commit
a1eaa82
·
verified ·
1 Parent(s): c488b3b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -16
app.py CHANGED
@@ -1,28 +1,90 @@
 
1
  import gradio as gr
2
- from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration
3
 
4
- # βœ… Better small chatbot model (normal replies)
5
  MODEL_ID = "facebook/blenderbot-400M-distill"
 
 
6
 
7
- tokenizer = BlenderbotTokenizer.from_pretrained(MODEL_ID)
8
- model = BlenderbotForConditionalGeneration.from_pretrained(MODEL_ID)
9
-
10
- def chat(history, message):
11
- # Convert chat history to context
12
- context = ""
13
- for user, bot in history:
14
- context += f"User: {user}\nBot: {bot}\n"
15
- context += f"User: {message}\nBot:"
16
 
 
 
17
  inputs = tokenizer(context, return_tensors="pt")
18
  reply_ids = model.generate(**inputs, max_length=120, no_repeat_ngram_size=2)
19
  reply = tokenizer.decode(reply_ids[0], skip_special_tokens=True)
 
20
 
21
- history.append((message, reply))
22
- return history, history
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  with gr.Blocks() as demo:
25
- gr.Markdown("## πŸ€– Anuj's Chatbot (Now Smarter!)")
 
 
 
 
 
 
 
 
26
 
27
- chatbot = gr.Chatbot()
28
- state = gr
 
1
# app.py
import gradio as gr
from typing import Any, Dict, List, Tuple

MODEL_ID = "facebook/blenderbot-400M-distill"

# Populated lazily on first request (see ensure_model_loaded) so the
# app process starts without downloading the model up front.
model = None
tokenizer = None
 
9
def ensure_model_loaded():
    """Load the Blenderbot tokenizer and model once, on first use."""
    global model, tokenizer
    if tokenizer is not None and model is not None:
        return
    # Imported here (not at module top) so transformers is only pulled in
    # when a reply is actually generated.
    from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration
    tokenizer = BlenderbotTokenizer.from_pretrained(MODEL_ID)
    model = BlenderbotForConditionalGeneration.from_pretrained(MODEL_ID)
 
 
 
15
 
16
def generate_reply(context: str) -> str:
    """Generate one bot reply for the given plain-text dialogue context."""
    ensure_model_loaded()
    encoded = tokenizer(context, return_tensors="pt")
    generated = model.generate(**encoded, max_length=120, no_repeat_ngram_size=2)
    return tokenizer.decode(generated[0], skip_special_tokens=True)
22
 
23
def history_to_context_from_tuples(history: List[Tuple[str, str]]) -> str:
    """Render (user, bot) turn pairs as a plain-text dialogue transcript."""
    turns = [f"User: {user}\nBot: {bot}\n" for user, bot in history]
    return "".join(turns)
28
+
29
def history_to_context_from_messages(history: List[Dict[str, str]]) -> str:
    """Render message dicts ({"role": ..., "content": ...}) as a transcript.

    Entries missing a role or content are skipped. Any role that does not
    start with "user" (e.g. "assistant") is labelled as the bot.
    """
    parts = []
    for entry in history:
        role = entry.get("role", "")
        content = entry.get("content", "")
        if not role or not content:
            continue
        speaker = "User" if role.lower().startswith("user") else "Bot"
        parts.append(f"{speaker}: {content}\n")
    return "".join(parts)
41
+
42
def chat(state: List[Any], message: str):
    """Handle one chat turn against the stored history.

    Supports both Gradio history formats:
      - tuples:   [("hi", "hello"), ...]
      - messages: [{"role": "user", "content": "hi"},
                   {"role": "assistant", "content": "hello"}, ...]

    Returns (state, state): the same list feeds both the gr.State and the
    gr.Chatbot output component.
    """
    uses_tuples = bool(state) and isinstance(state[0], (list, tuple))

    # Build the plain-text context from whichever format the state is in.
    if state and isinstance(state[0], dict):
        context = history_to_context_from_messages(state)
    elif uses_tuples:
        context = history_to_context_from_tuples(state)
    else:
        # Empty or unrecognized history -> start from a blank context.
        context = ""

    context += f"User: {message}\nBot:"
    reply = generate_reply(context)

    if uses_tuples:
        # Only keep the legacy tuple format when the history already uses it.
        state.append((message, reply))
    else:
        # Default to messages format. This matches gr.Chatbot(type="messages");
        # the old code fell through to tuple format on an EMPTY state, so the
        # very first turn produced history the messages-mode Chatbot rejects.
        state.append({"role": "user", "content": message})
        state.append({"role": "assistant", "content": reply})
    return state, state
77
 
78
with gr.Blocks() as demo:
    # Heading text was mojibake-garbled (UTF-8 shown as cp1252); restored
    # the intended emoji and em dash.
    gr.Markdown("## 🤖 Anuj's Chatbot — stable (messages format)")
    # 'messages' type avoids the deprecated tuple-format Chatbot history.
    chatbot = gr.Chatbot(type="messages")
    # Shared chat history; chat() reads and appends to this list.
    state = gr.State([])

    with gr.Row():
        msg = gr.Textbox(show_label=False, placeholder="Type a message and press Enter...")

    # Enter in the textbox runs one chat turn and refreshes the history view.
    msg.submit(chat, [state, msg], [state, chatbot])
88
 
89
if __name__ == "__main__":
    # Bind on all interfaces so the app is reachable from outside the host
    # (e.g. inside a Docker/Spaces container) on the standard port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)