Spaces:
Sleeping
Sleeping
Upload app.py
Browse files
app.py
CHANGED
|
@@ -11,6 +11,7 @@ import os
|
|
| 11 |
import time
|
| 12 |
import gc # メモリ解放
|
| 13 |
import re # 正規表現で文章をクリーンアップ
|
|
|
|
| 14 |
|
| 15 |
# HuggingFace
|
| 16 |
import torch
|
|
@@ -115,6 +116,55 @@ class SessionState:
|
|
| 115 |
|
| 116 |
self.cache_clear()
|
| 117 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
# --------------------------------------
|
| 119 |
# 自作TextSplitter(テキストをLLMのトークン数内に分割)
|
| 120 |
# (参考)https://www.sato-susumu.com/entry/2023/04/30/131338
|
|
@@ -157,10 +207,21 @@ def name_detector(text: str) -> list:
|
|
| 157 |
|
| 158 |
node = node.next
|
| 159 |
|
| 160 |
-
|
|
|
|
| 161 |
|
| 162 |
return names
|
| 163 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
# --------------------------------------
|
| 165 |
# DeepL でメモリを翻訳しトークン数を削減(OpenAIモデル利用時)
|
| 166 |
# --------------------------------------
|
|
@@ -207,21 +268,12 @@ def deepl_memory(ss: SessionState) -> (SessionState):
|
|
| 207 |
def web_search(ss: SessionState, query) -> (SessionState, str):
|
| 208 |
|
| 209 |
search = DuckDuckGoSearchRun(verbose=True)
|
|
|
|
|
|
|
| 210 |
|
| 211 |
for i in range(3):
|
| 212 |
web_result = search(query)
|
| 213 |
|
| 214 |
-
# 人名の抽出
|
| 215 |
-
names = []
|
| 216 |
-
names.extend(name_detector(query))
|
| 217 |
-
names.extend(name_detector(web_result))
|
| 218 |
-
if len(names)==0:
|
| 219 |
-
names = ""
|
| 220 |
-
elif len(names)==1:
|
| 221 |
-
names = names[0]
|
| 222 |
-
else:
|
| 223 |
-
names = ", ".join(names)
|
| 224 |
-
|
| 225 |
if ss.current_model == "gpt-3.5-turbo":
|
| 226 |
text = [query, web_result]
|
| 227 |
params = {
|
|
@@ -235,21 +287,33 @@ def web_search(ss: SessionState, query) -> (SessionState, str):
|
|
| 235 |
request = requests.post(DEEPL_API_ENDPOINT, data=params)
|
| 236 |
response = request.json()
|
| 237 |
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
if
|
|
|
|
|
|
|
| 242 |
break
|
| 243 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
if names != "":
|
| 245 |
web_query = f"""
|
| 246 |
-
{
|
| 247 |
-
Use the following Suggested Answer
|
| 248 |
-
Suggested Answer
|
| 249 |
Names: {names}
|
| 250 |
""".strip()
|
| 251 |
else:
|
| 252 |
-
web_query =
|
| 253 |
|
| 254 |
|
| 255 |
return ss, web_query
|
|
@@ -265,29 +329,19 @@ def web_search(ss: SessionState, query) -> (SessionState, str):
|
|
| 265 |
# --------------------------------------
|
| 266 |
# Conversation Chain Template
|
| 267 |
# --------------------------------------
|
| 268 |
-
|
| 269 |
# Tokens: OpenAI 104/ Llama 105 <- In Japanese: Tokens: OpenAI 191/ Llama 162
|
| 270 |
-
# sys_chat_message = """
|
| 271 |
-
# You are an outstanding AI concierge. Understand the intent of the customer's questions based on
|
| 272 |
-
# the conversation history. Then, answer them with many specific and detailed information in Japanese.
|
| 273 |
-
# If you do not know the answer to a question, do make up an answer and says
|
| 274 |
-
# "誠に申し訳ございませんが、その点についてはわかりかねます".
|
| 275 |
-
# """.replace("\n", "")
|
| 276 |
|
| 277 |
sys_chat_message = """
|
| 278 |
-
You are an
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
""".
|
| 283 |
|
| 284 |
chat_common_format = """
|
| 285 |
===
|
| 286 |
Question: {query}
|
| 287 |
-
|
| 288 |
-
Conversation History:
|
| 289 |
-
{chat_history}
|
| 290 |
-
===
|
| 291 |
日本語の回答: """
|
| 292 |
|
| 293 |
chat_template_std = f"{sys_chat_message}{chat_common_format}"
|
|
@@ -297,35 +351,46 @@ chat_template_llama2 = f"<s>[INST] <<SYS>>{sys_chat_message}<</SYS>>{chat_common
|
|
| 297 |
# QA Chain Template (Stuff)
|
| 298 |
# --------------------------------------
|
| 299 |
# Tokens: OpenAI 113/ Llama 111 <- In Japanese: Tokens: OpenAI 256/ Llama 225
|
| 300 |
-
sys_qa_message = """
|
| 301 |
-
You are an AI concierge who carefully answers questions from customers based on references.
|
| 302 |
-
|
| 303 |
-
a specific answer in Japanese using sentences extracted from the following references.
|
| 304 |
-
not know the answer, do not make up an answer and reply,
|
| 305 |
-
""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
|
| 307 |
qa_common_format = """
|
| 308 |
===
|
| 309 |
Question: {query}
|
| 310 |
References: {context}
|
| 311 |
-
|
| 312 |
-
Conversation History:
|
| 313 |
-
{chat_history}
|
| 314 |
-
===
|
| 315 |
日本語の回答: """
|
| 316 |
|
|
|
|
|
|
|
| 317 |
|
| 318 |
-
qa_template_std = f"{sys_qa_message}{qa_common_format}"
|
| 319 |
-
qa_template_llama2 = f"<s>[INST] <<SYS>>{sys_qa_message}<</SYS>>{qa_common_format}[/INST]"
|
| 320 |
|
| 321 |
# --------------------------------------
|
| 322 |
# QA Chain Template (Map Reduce)
|
| 323 |
# --------------------------------------
|
| 324 |
# 1. 会話履歴と最新の質問から、質問文を生成するchain のプロンプト
|
| 325 |
query_generator_message = """
|
| 326 |
-
Referring to the "Conversation History",
|
| 327 |
-
|
| 328 |
-
|
|
|
|
| 329 |
""".replace("\n", "")
|
| 330 |
|
| 331 |
query_generator_common_format = """
|
|
@@ -334,30 +399,25 @@ query_generator_common_format = """
|
|
| 334 |
{chat_history}
|
| 335 |
|
| 336 |
[Additional Question] {query}
|
| 337 |
-
|
| 338 |
|
| 339 |
query_generator_template_std = f"{query_generator_message}{query_generator_common_format}"
|
| 340 |
query_generator_template_llama2 = f"<s>[INST] <<SYS>>{query_generator_message}<</SYS>>{query_generator_common_format}[/INST]"
|
| 341 |
|
| 342 |
|
| 343 |
# 2. 生成された質問文を用いて、参考文献を要約するchain のプロンプト
|
| 344 |
-
# question_prompt_message = """
|
| 345 |
-
# From the following references, extract key information relevant to the question
|
| 346 |
-
# and summarize it in a natural English sentence with clear subject, verb, object,
|
| 347 |
-
# and complement. If there is no information in the reference that answers the question,
|
| 348 |
-
# do not summarize and simply answer "NO INFO"
|
| 349 |
-
# """.replace("\n", "")
|
| 350 |
|
| 351 |
question_prompt_message = """
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
|
|
|
| 355 |
|
| 356 |
question_prompt_common_format = """
|
| 357 |
===
|
| 358 |
[Question] {query}
|
| 359 |
-
[
|
| 360 |
-
[
|
| 361 |
|
| 362 |
question_prompt_template_std = f"{question_prompt_message}{question_prompt_common_format}"
|
| 363 |
question_prompt_template_llama2 = f"<s>[INST] <<SYS>>{question_prompt_message}<</SYS>>{question_prompt_common_format}[/INST]"
|
|
@@ -578,11 +638,12 @@ def set_chains(ss: SessionState, summarization_mode) -> SessionState:
|
|
| 578 |
ss.query_generator = LLMChain(llm=ss.llm, prompt=query_generator_prompt, verbose=True)
|
| 579 |
|
| 580 |
if ss.conversation_chain is None:
|
| 581 |
-
chat_prompt = PromptTemplate(input_variables=['query', 'chat_history'], template=chat_template)
|
|
|
|
| 582 |
ss.conversation_chain = ConversationChain(
|
| 583 |
llm = ss.llm,
|
| 584 |
prompt = chat_prompt,
|
| 585 |
-
memory = ss.memory,
|
| 586 |
input_key = "query",
|
| 587 |
output_key = "output_text",
|
| 588 |
verbose = True,
|
|
@@ -590,13 +651,16 @@ def set_chains(ss: SessionState, summarization_mode) -> SessionState:
|
|
| 590 |
|
| 591 |
if ss.qa_chain is None:
|
| 592 |
if summarization_mode == "stuff":
|
| 593 |
-
qa_prompt = PromptTemplate(input_variables=['context', 'query', 'chat_history'], template=qa_template)
|
| 594 |
-
|
|
|
|
|
|
|
| 595 |
|
| 596 |
elif summarization_mode == "map_reduce":
|
| 597 |
question_prompt = PromptTemplate(template=question_template, input_variables=["context", "query"])
|
| 598 |
combine_prompt = PromptTemplate(template=combine_template, input_variables=["summaries", "query"])
|
| 599 |
-
ss.qa_chain = load_qa_chain(ss.llm, chain_type="map_reduce", return_map_steps=True, memory=ss.memory, question_prompt=question_prompt, combine_prompt=combine_prompt)
|
|
|
|
| 600 |
|
| 601 |
if ss.web_summary_chain is None:
|
| 602 |
question_prompt = PromptTemplate(template=question_template, input_variables=["context", "query"])
|
|
@@ -853,6 +917,8 @@ def bot(ss: SessionState, query, qa_flag, web_flag, summarization_mode) -> (Sess
|
|
| 853 |
|
| 854 |
def chat_predict(ss: SessionState, query) -> SessionState:
|
| 855 |
response = ss.conversation_chain.predict(query=query)
|
|
|
|
|
|
|
| 856 |
ss.dialogue[-1] = (ss.dialogue[-1][0], response)
|
| 857 |
return ss
|
| 858 |
|
|
@@ -890,10 +956,12 @@ def qa_predict(ss: SessionState, query) -> SessionState:
|
|
| 890 |
if result["output_text"] != "":
|
| 891 |
response = result["output_text"] + sources
|
| 892 |
ss.dialogue[-1] = (ss.dialogue[-1][0], response)
|
|
|
|
|
|
|
| 893 |
return ss
|
| 894 |
-
else:
|
| 895 |
# 空欄の場合は直近の履歴を削除してやり直し
|
| 896 |
-
ss.memory.chat_memory.messages = ss.memory.chat_memory.messages[:-2]
|
| 897 |
|
| 898 |
# 3回の試行後も空欄の場合
|
| 899 |
response = "3回試行しましたが、情報製生成できませんでした。"
|
|
|
|
| 11 |
import time
|
| 12 |
import gc # メモリ解放
|
| 13 |
import re # 正規表現で文章をクリーンアップ
|
| 14 |
+
import regex # 漢字抽出で利用
|
| 15 |
|
| 16 |
# HuggingFace
|
| 17 |
import torch
|
|
|
|
| 116 |
|
| 117 |
self.cache_clear()
|
| 118 |
|
| 119 |
+
# --------------------------------------
|
| 120 |
+
# メモリを使用しない ConversationChainを自作
|
| 121 |
+
# --------------------------------------
|
| 122 |
+
from typing import Dict, List
|
| 123 |
+
|
| 124 |
+
from langchain.chains.conversation.prompt import PROMPT
|
| 125 |
+
from langchain.chains.llm import LLMChain
|
| 126 |
+
from langchain.pydantic_v1 import Extra, Field, root_validator
|
| 127 |
+
from langchain.schema import BasePromptTemplate
|
| 128 |
+
|
| 129 |
+
class ConversationChain(LLMChain):
|
| 130 |
+
"""Chain to have a conversation without loading context from memory.
|
| 131 |
+
|
| 132 |
+
Example:
|
| 133 |
+
.. code-block:: python
|
| 134 |
+
|
| 135 |
+
from langchain import ConversationChainWithoutMemory, OpenAI
|
| 136 |
+
|
| 137 |
+
conversation = ConversationChainWithoutMemory(llm=OpenAI())
|
| 138 |
+
"""
|
| 139 |
+
|
| 140 |
+
prompt: BasePromptTemplate = PROMPT
|
| 141 |
+
"""Default conversation prompt to use."""
|
| 142 |
+
|
| 143 |
+
input_key: str = "input" #: :meta private:
|
| 144 |
+
output_key: str = "response" #: :meta private:
|
| 145 |
+
|
| 146 |
+
class Config:
|
| 147 |
+
"""Configuration for this pydantic object."""
|
| 148 |
+
|
| 149 |
+
extra = Extra.forbid
|
| 150 |
+
arbitrary_types_allowed = True
|
| 151 |
+
|
| 152 |
+
@property
|
| 153 |
+
def input_keys(self) -> List[str]:
|
| 154 |
+
"""Use this since so some prompt vars come from history."""
|
| 155 |
+
return [self.input_key]
|
| 156 |
+
|
| 157 |
+
@root_validator()
|
| 158 |
+
def validate_prompt_input_variables(cls, values: Dict) -> Dict:
|
| 159 |
+
"""Validate that prompt input variables are consistent without memory."""
|
| 160 |
+
input_key = values["input_key"]
|
| 161 |
+
prompt_variables = values["prompt"].input_variables
|
| 162 |
+
if input_key not in prompt_variables:
|
| 163 |
+
raise ValueError(
|
| 164 |
+
f"The prompt expects {prompt_variables}, but {input_key} is not found."
|
| 165 |
+
)
|
| 166 |
+
return values
|
| 167 |
+
|
| 168 |
# --------------------------------------
|
| 169 |
# 自作TextSplitter(テキストをLLMのトークン数内に分割)
|
| 170 |
# (参考)https://www.sato-susumu.com/entry/2023/04/30/131338
|
|
|
|
| 207 |
|
| 208 |
node = node.next
|
| 209 |
|
| 210 |
+
# ユニークな値を抽出し、その後漢字を含む値のみとする
|
| 211 |
+
names = filter_kanji(list(set(names)))
|
| 212 |
|
| 213 |
return names
|
| 214 |
|
| 215 |
+
# --------------------------------------
|
| 216 |
+
# リストから漢字を含む値だけを抽出する
|
| 217 |
+
# --------------------------------------
|
| 218 |
+
def filter_kanji(lst) -> list:
|
| 219 |
+
def contains_kanji(s):
|
| 220 |
+
p = regex.compile(r'\p{Script=Han}+')
|
| 221 |
+
return bool(p.search(s))
|
| 222 |
+
|
| 223 |
+
return [item for item in lst if contains_kanji(item)]
|
| 224 |
+
|
| 225 |
# --------------------------------------
|
| 226 |
# DeepL でメモリを翻訳しトークン数を削減(OpenAIモデル利用時)
|
| 227 |
# --------------------------------------
|
|
|
|
| 268 |
def web_search(ss: SessionState, query) -> (SessionState, str):
|
| 269 |
|
| 270 |
search = DuckDuckGoSearchRun(verbose=True)
|
| 271 |
+
names = []
|
| 272 |
+
names.extend(name_detector(query))
|
| 273 |
|
| 274 |
for i in range(3):
|
| 275 |
web_result = search(query)
|
| 276 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
if ss.current_model == "gpt-3.5-turbo":
|
| 278 |
text = [query, web_result]
|
| 279 |
params = {
|
|
|
|
| 287 |
request = requests.post(DEEPL_API_ENDPOINT, data=params)
|
| 288 |
response = request.json()
|
| 289 |
|
| 290 |
+
query_eng = response["translations"][0]["text"]
|
| 291 |
+
web_result_eng = response["translations"][1]["text"]
|
| 292 |
+
web_result_eng = ss.web_summary_chain({'query': query_eng, 'context': web_result_eng})['text']
|
| 293 |
+
if "$$NO INFO$$" in web_result_eng:
|
| 294 |
+
web_result_eng = ss.web_summary_chain({'query': query_eng, 'context': web_result_eng})['text']
|
| 295 |
+
if "$$NO INFO$$" not in web_result_eng:
|
| 296 |
break
|
| 297 |
|
| 298 |
+
# 検索結果から人名を抽出し、テキスト化
|
| 299 |
+
names.extend(name_detector(web_result))
|
| 300 |
+
if len(names)==0:
|
| 301 |
+
names = ""
|
| 302 |
+
elif len(names)==1:
|
| 303 |
+
names = names[0]
|
| 304 |
+
else:
|
| 305 |
+
names = ", ".join(names)
|
| 306 |
+
|
| 307 |
+
# Web検索結果を含むQueryを渡す。
|
| 308 |
if names != "":
|
| 309 |
web_query = f"""
|
| 310 |
+
{query_eng}
|
| 311 |
+
Use the following Suggested Answer as a reference to answer the question above in Japanese. When translating names of people, refer to Names as a translation guide.
|
| 312 |
+
Suggested Answer: {web_result_eng}
|
| 313 |
Names: {names}
|
| 314 |
""".strip()
|
| 315 |
else:
|
| 316 |
+
web_query = query_eng + "\nUse the following Suggested Answer as a reference to answer the question above in the Japanese.\n===\nSuggested Answer: " + web_result_eng + "\n"
|
| 317 |
|
| 318 |
|
| 319 |
return ss, web_query
|
|
|
|
| 329 |
# --------------------------------------
|
| 330 |
# Conversation Chain Template
|
| 331 |
# --------------------------------------
|
|
|
|
| 332 |
# Tokens: OpenAI 104/ Llama 105 <- In Japanese: Tokens: OpenAI 191/ Llama 162
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
|
| 334 |
sys_chat_message = """
|
| 335 |
+
You are an AI concierge who carefully answers questions from customers based on references.
|
| 336 |
+
You understand what the customer wants to know, and give many specific details in Japanese
|
| 337 |
+
using sentences extracted from the following references when available. If you do not know
|
| 338 |
+
the answer, do not make up an answer and reply, "誠に申し訳ございませんが、その点についてはわかりかねます".
|
| 339 |
+
""".replace("\n", "")
|
| 340 |
|
| 341 |
chat_common_format = """
|
| 342 |
===
|
| 343 |
Question: {query}
|
| 344 |
+
|
|
|
|
|
|
|
|
|
|
| 345 |
日本語の回答: """
|
| 346 |
|
| 347 |
chat_template_std = f"{sys_chat_message}{chat_common_format}"
|
|
|
|
| 351 |
# QA Chain Template (Stuff)
|
| 352 |
# --------------------------------------
|
| 353 |
# Tokens: OpenAI 113/ Llama 111 <- In Japanese: Tokens: OpenAI 256/ Llama 225
|
| 354 |
+
# sys_qa_message = """
|
| 355 |
+
# You are an AI concierge who carefully answers questions from customers based on references.
|
| 356 |
+
# You understand what the customer wants to know from the Conversation History and Question,
|
| 357 |
+
# and give a specific answer in Japanese using sentences extracted from the following references.
|
| 358 |
+
# If you do not know the answer, do not make up an answer and reply,
|
| 359 |
+
# "誠に申し訳ございませんが、その点についてはわかりかねます".
|
| 360 |
+
# """.replace("\n", "")
|
| 361 |
+
|
| 362 |
+
# qa_common_format = """
|
| 363 |
+
# ===
|
| 364 |
+
# Question: {query}
|
| 365 |
+
# References: {context}
|
| 366 |
+
# ===
|
| 367 |
+
# Conversation History:
|
| 368 |
+
# {chat_history}
|
| 369 |
+
# ===
|
| 370 |
+
# 日本語の回答: """
|
| 371 |
|
| 372 |
qa_common_format = """
|
| 373 |
===
|
| 374 |
Question: {query}
|
| 375 |
References: {context}
|
| 376 |
+
|
|
|
|
|
|
|
|
|
|
| 377 |
日本語の回答: """
|
| 378 |
|
| 379 |
+
qa_template_std = f"{sys_chat_message}{qa_common_format}"
|
| 380 |
+
qa_template_llama2 = f"<s>[INST] <<SYS>>{sys_chat_message}<</SYS>>{qa_common_format}[/INST]"
|
| 381 |
|
| 382 |
+
# qa_template_std = f"{sys_qa_message}{qa_common_format}"
|
| 383 |
+
# qa_template_llama2 = f"<s>[INST] <<SYS>>{sys_qa_message}<</SYS>>{qa_common_format}[/INST]"
|
| 384 |
|
| 385 |
# --------------------------------------
|
| 386 |
# QA Chain Template (Map Reduce)
|
| 387 |
# --------------------------------------
|
| 388 |
# 1. 会話履歴と最新の質問から、質問文を生成するchain のプロンプト
|
| 389 |
query_generator_message = """
|
| 390 |
+
Referring to the "Conversation History", especially to the most recent conversation,
|
| 391 |
+
reformat the user's "Additional Question" into a specific question in Japanese by
|
| 392 |
+
filling in the missing subject, verb, objects, complements,and other necessary
|
| 393 |
+
information to get a better search result. Answer in 日本語(Japanese).
|
| 394 |
""".replace("\n", "")
|
| 395 |
|
| 396 |
query_generator_common_format = """
|
|
|
|
| 399 |
{chat_history}
|
| 400 |
|
| 401 |
[Additional Question] {query}
|
| 402 |
+
明確な質問文: """
|
| 403 |
|
| 404 |
query_generator_template_std = f"{query_generator_message}{query_generator_common_format}"
|
| 405 |
query_generator_template_llama2 = f"<s>[INST] <<SYS>>{query_generator_message}<</SYS>>{query_generator_common_format}[/INST]"
|
| 406 |
|
| 407 |
|
| 408 |
# 2. 生成された質問文を用いて、参考文献を要約するchain のプロンプト
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 409 |
|
| 410 |
question_prompt_message = """
|
| 411 |
+
From the following references, extract key information relevant to the question
|
| 412 |
+
and summarize it in a natural English sentence with clear subject, verb, object,
|
| 413 |
+
and complement.
|
| 414 |
+
""".replace("\n", "")
|
| 415 |
|
| 416 |
question_prompt_common_format = """
|
| 417 |
===
|
| 418 |
[Question] {query}
|
| 419 |
+
[References] {context}
|
| 420 |
+
[Key Information] """
|
| 421 |
|
| 422 |
question_prompt_template_std = f"{question_prompt_message}{question_prompt_common_format}"
|
| 423 |
question_prompt_template_llama2 = f"<s>[INST] <<SYS>>{question_prompt_message}<</SYS>>{question_prompt_common_format}[/INST]"
|
|
|
|
| 638 |
ss.query_generator = LLMChain(llm=ss.llm, prompt=query_generator_prompt, verbose=True)
|
| 639 |
|
| 640 |
if ss.conversation_chain is None:
|
| 641 |
+
# chat_prompt = PromptTemplate(input_variables=['query', 'chat_history'], template=chat_template)
|
| 642 |
+
chat_prompt = PromptTemplate(input_variables=['query'], template=chat_template)
|
| 643 |
ss.conversation_chain = ConversationChain(
|
| 644 |
llm = ss.llm,
|
| 645 |
prompt = chat_prompt,
|
| 646 |
+
# memory = ss.memory,
|
| 647 |
input_key = "query",
|
| 648 |
output_key = "output_text",
|
| 649 |
verbose = True,
|
|
|
|
| 651 |
|
| 652 |
if ss.qa_chain is None:
|
| 653 |
if summarization_mode == "stuff":
|
| 654 |
+
# qa_prompt = PromptTemplate(input_variables=['context', 'query', 'chat_history'], template=qa_template)
|
| 655 |
+
qa_prompt = PromptTemplate(input_variables=['context', 'query'], template=qa_template)
|
| 656 |
+
# ss.qa_chain = load_qa_chain(ss.llm, chain_type="stuff", memory=ss.memory, prompt=qa_prompt)
|
| 657 |
+
ss.qa_chain = load_qa_chain(ss.llm, chain_type="stuff", prompt=qa_prompt, verbose=True)
|
| 658 |
|
| 659 |
elif summarization_mode == "map_reduce":
|
| 660 |
question_prompt = PromptTemplate(template=question_template, input_variables=["context", "query"])
|
| 661 |
combine_prompt = PromptTemplate(template=combine_template, input_variables=["summaries", "query"])
|
| 662 |
+
# ss.qa_chain = load_qa_chain(ss.llm, chain_type="map_reduce", return_map_steps=True, memory=ss.memory, question_prompt=question_prompt, combine_prompt=combine_prompt)
|
| 663 |
+
ss.qa_chain = load_qa_chain(ss.llm, chain_type="map_reduce", return_map_steps=True, question_prompt=question_prompt, combine_prompt=combine_prompt, verbose=True)
|
| 664 |
|
| 665 |
if ss.web_summary_chain is None:
|
| 666 |
question_prompt = PromptTemplate(template=question_template, input_variables=["context", "query"])
|
|
|
|
| 917 |
|
| 918 |
def chat_predict(ss: SessionState, query) -> SessionState:
|
| 919 |
response = ss.conversation_chain.predict(query=query)
|
| 920 |
+
ss.memory.chat_memory.add_user_message(query)
|
| 921 |
+
ss.memory.chat_memory.add_ai_message(response)
|
| 922 |
ss.dialogue[-1] = (ss.dialogue[-1][0], response)
|
| 923 |
return ss
|
| 924 |
|
|
|
|
| 956 |
if result["output_text"] != "":
|
| 957 |
response = result["output_text"] + sources
|
| 958 |
ss.dialogue[-1] = (ss.dialogue[-1][0], response)
|
| 959 |
+
ss.memory.chat_memory.add_user_message(original_query)
|
| 960 |
+
ss.memory.chat_memory.add_ai_message(response)
|
| 961 |
return ss
|
| 962 |
+
# else:
|
| 963 |
# 空欄の場合は直近の履歴を削除してやり直し
|
| 964 |
+
# ss.memory.chat_memory.messages = ss.memory.chat_memory.messages[:-2]
|
| 965 |
|
| 966 |
# 3回の試行後も空欄の場合
|
| 967 |
response = "3回試行しましたが、情報製生成できませんでした。"
|