Upload 2 files
- app.py +96 -41
- requirements.txt +3 -1
app.py
CHANGED
@@ -22,6 +22,7 @@ from langchain.embeddings.openai import OpenAIEmbeddings
 from langchain.chat_models import ChatOpenAI

 # LangChain
+import langchain
 from langchain.llms import HuggingFacePipeline
 from transformers import pipeline

@@ -45,8 +46,8 @@ import gradio as gr
 from pypdf import PdfReader
 import requests # DeepL API request

-#
-import
+# Mecab
+import MeCab

 # --------------------------------------
 # ユーザ別セッションの変数値を記録するクラス
@@ -69,6 +70,7 @@ class SessionState:
         self.conversation_chain = None # ConversationChain
         self.query_generator = None # Query Refiner with Chat history
         self.qa_chain = None # load_qa_chain
+        self.web_summary_chain = None # Summarize web search result
         self.embedded_urls = []
         self.similarity_search_k = None # No. of similarity search documents to find.
         self.summarization_mode = None # Stuff / Map Reduce / Refine
@@ -132,6 +134,33 @@ text_splitter = JPTextSplitter(
     chunk_overlap = chunk_overlap, # オーバーラップの最大文字数
 )

+# --------------------------------------
+# 文中から人名を抽出
+# --------------------------------------
+def name_detector(text: str) -> list:
+    mecab = MeCab.Tagger()
+    mecab.parse('') # ←バグ対応
+    node = mecab.parseToNode(text).next
+    names = []
+
+    while node:
+        if node.feature.split(',')[3] == "姓":
+            if node.next and node.next.feature.split(',')[3] == "名":
+                names.append(str(node.surface) + str(node.next.surface))
+            else:
+                names.append(node.surface)
+        if node.feature.split(',')[3] == "名":
+            if node.prev and node.prev.feature.split(',')[3] == "姓":
+                pass
+            else:
+                names.append(str(node.surface))
+
+        node = node.next
+
+    names = list(set(names))
+
+    return names
+
 # --------------------------------------
 # DeepL でメモリを翻訳しトークン数を削減(OpenAIモデル利用時)
 # --------------------------------------
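Note: the new name_detector walks the MeCab node list and keys on the fourth feature field, which UniDic-style dictionaries set to 姓 (surname) or 名 (given name) for person-name tokens, joining adjacent surname/given-name pairs. A minimal standalone sketch of that node walk, assuming mecab-python3 with the unidic-lite dictionary added in requirements.txt and a hypothetical sample sentence:

import MeCab  # mecab-python3; unidic-lite supplies the dictionary

tagger = MeCab.Tagger()
node = tagger.parseToNode("山田太郎さんと佐藤さんが会談した。")  # hypothetical sample text
while node:
    features = node.feature.split(',')
    # For person-name tokens the 4th field is "姓" (surname) or "名" (given name),
    # which is exactly what name_detector checks before joining adjacent tokens.
    if len(features) > 3 and features[3] in ("姓", "名"):
        print(node.surface, features[:4])
    node = node.next
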
@@ -175,11 +204,22 @@ def deepl_memory(ss: SessionState) -> (SessionState):
 # DEEPL_API_ENDPOINT = "https://api-free.deepl.com/v2/translate"
 # DEEPL_API_KEY = os.getenv("DEEPL_API_KEY")

-def web_search(query, current_model) -> str:
-    search = DuckDuckGoSearchRun()
+def web_search(ss: SessionState, query) -> (SessionState, str):
+    search = DuckDuckGoSearchRun(verbose=True)
     web_result = search(query)

-
+    # 人名の抽出
+    names = []
+    names.extend(name_detector(query))
+    names.extend(name_detector(web_result))
+    if len(names)==0:
+        names = ""
+    elif len(names)==1:
+        names = names[0]
+    else:
+        names = ", ".join(names)
+
+    if ss.current_model == "gpt-3.5-turbo":
         text = [query, web_result]
         params = {
             "auth_key": DEEPL_API_KEY,
@@ -193,19 +233,28 @@ def web_search(query, current_model) -> str:
         response = request.json()

         query = response["translations"][0]["text"]
-        web_result =
+        web_result = response["translations"][1]["text"]
+        web_result = ss.web_summary_chain({'query': query, 'context': web_result})['text']
+
+    if names != "":
+        web_query = f"""
+            {query}
+            Use the following information as a reference to answer the question above in Japanese. When translating names of Japanese people, refer to Japanese Names as a translation guide.
+            Reference: {web_result}
+            Japanese Names: {names}
+        """.strip()
+    else:
+        web_query = query + "\nUse the following information as a reference to answer the question above in the Japanese.\n===\nReference: " + web_result + "\n==="

-    web_query = query + "\nUse the following information as a reference to answer the question above in the Japanese.\n===\nReference: " + web_result + "\n==="

-    return web_query
+    return ss, web_query

 # --------------------------------------
 # LangChain カスタムプロンプト各種
 # llama tokenizer
-#
-
+# https://belladoreai.github.io/llama-tokenizer-js/example-demo/build/
 # OpenAI tokenizer
-#
+# https://platform.openai.com/tokenizer
 # --------------------------------------

 # --------------------------------------
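Note: the hunks above only show fragments of the DeepL request that produces translations[0] and translations[1]. A hedged sketch of that call pattern, assuming the api-free endpoint referenced in the commented-out config and an EN target language (the actual target_lang line is not visible in the diff):

import os
import requests

DEEPL_API_ENDPOINT = "https://api-free.deepl.com/v2/translate"  # endpoint named in the commented-out config
DEEPL_API_KEY = os.getenv("DEEPL_API_KEY")

def translate_for_prompt(query: str, web_result: str):
    # Both strings go out in a single request; DeepL returns one entry per "text"
    # value, which is why the app reads translations[0] and translations[1].
    params = {
        "auth_key": DEEPL_API_KEY,
        "text": [query, web_result],
        "target_lang": "EN",  # assumption: memory is translated to English to save tokens
    }
    request = requests.post(DEEPL_API_ENDPOINT, data=params)
    response = request.json()
    return response["translations"][0]["text"], response["translations"][1]["text"]
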
@@ -214,19 +263,18 @@ def web_search(query, current_model) -> str:

 # Tokens: OpenAI 104/ Llama 105 <- In Japanese: Tokens: OpenAI 191/ Llama 162
 sys_chat_message = """
-
-
-
-make up an answer and says "誠に申し訳ございませんが、その点についてはわかりかねます".
+You are an outstanding AI concierge. You understand your customers' needs from their questions and answer
+them with many specific and detailed information in Japanese. If you do not know the answer to a question,
+do make up an answer and says "誠に申し訳ございませんが、その点についてはわかりかねます". Ignore Conversation History.
 """.replace("\n", "")

 chat_common_format = """
 ===
 Question: {query}
-
-Conversation History:
+===
+Conversation History(Ignore):
 {chat_history}
-
+===
 日本語の回答: """

 chat_template_std = f"{sys_chat_message}{chat_common_format}"
@@ -238,21 +286,23 @@ chat_template_llama2 = f"<s>[INST] <<SYS>>{sys_chat_message}<</SYS>>{chat_common_format}[/INST]"
 # Tokens: OpenAI 113/ Llama 111 <- In Japanese: Tokens: OpenAI 256/ Llama 225
 sys_qa_message = """
 You are an AI concierge who carefully answers questions from customers based on references.
-You understand what the customer wants to know from
-
-
-
+You understand what the customer wants to know from Question, and give a specific answer in
+Japanese using sentences extracted from the following references. If you do not know the answer,
+do not make up an answer and reply, "誠に申し訳ございませんが、その点についてはわかりかねます".
+Ignore Conversation History.
 """.replace("\n", "")

 qa_common_format = """
 ===
 Question: {query}
 References: {context}
-
+===
+Conversation History(Ignore):
 {chat_history}
-
+===
 日本語の回答: """

+
 qa_template_std = f"{sys_qa_message}{qa_common_format}"
 qa_template_llama2 = f"<s>[INST] <<SYS>>{sys_qa_message}<</SYS>>{qa_common_format}[/INST]"

@@ -262,8 +312,8 @@ qa_template_llama2 = f"<s>[INST] <<SYS>>{sys_qa_message}<</SYS>>{qa_common_format}[/INST]"
 # 1. 会話履歴と最新の質問から、質問文を生成するchain のプロンプト
 query_generator_message = """
 Referring to the "Conversation History", reformat the user's "Additional Question"
-to a specific question
-
+to a specific question by filling in the missing subject, verb, objects, complements,
+and other necessary information to get a better search result. Answer in Japanese.
 """.replace("\n", "")

 query_generator_common_format = """
@@ -272,7 +322,7 @@ query_generator_common_format = """
 {chat_history}

 [Additional Question] {query}
-
+明確な日本語の質問文: """

 query_generator_template_std = f"{query_generator_message}{query_generator_common_format}"
 query_generator_template_llama2 = f"<s>[INST] <<SYS>>{query_generator_message}<</SYS>>{query_generator_common_format}[/INST]"
@@ -287,8 +337,8 @@ and complement.

 question_prompt_common_format = """
 ===
-[references] {context}
 [Question] {query}
+[references] {context}
 [Summary] """

 question_prompt_template_std = f"{question_prompt_message}{question_prompt_common_format}"
@@ -305,17 +355,14 @@ If you do not know the answer, do not make up an answer and reply,

 combine_prompt_common_format = """
 ===
-Question:
-{query}
-===
+Question: {query}
 Reference: {summaries}
-===
 日本語の回答: """

+
 combine_prompt_template_std = f"{combine_prompt_message}{combine_prompt_common_format}"
 combine_prompt_template_llama2 = f"<s>[INST] <<SYS>>{combine_prompt_message}<</SYS>>{combine_prompt_common_format}[/INST]"

-
 # --------------------------------------
 # ConversationSummaryBufferMemoryの要約プロンプト
 # ソース → https://github.com/langchain-ai/langchain/blob/894c272a562471aadc1eb48e4a2992923533dea0/langchain/memory/prompt.py#L26-L49
@@ -508,6 +555,10 @@ def set_chains(ss: SessionState, summarization_mode) -> SessionState:
     # --------------------------------------
     # Conversation/QAチェーンの設定
     # --------------------------------------
+    if ss.query_generator is None:
+        query_generator_prompt = PromptTemplate(template=query_generator_template, input_variables = ["chat_history", "query"])
+        ss.query_generator = LLMChain(llm=ss.llm, prompt=query_generator_prompt, verbose=True)
+
     if ss.conversation_chain is None:
         chat_prompt = PromptTemplate(input_variables=['query', 'chat_history'], template=chat_template)
         ss.conversation_chain = ConversationChain(
@@ -525,13 +576,14 @@ def set_chains(ss: SessionState, summarization_mode) -> SessionState:
         ss.qa_chain = load_qa_chain(ss.llm, chain_type="stuff", memory=ss.memory, prompt=qa_prompt)

     elif summarization_mode == "map_reduce":
-        query_generator_prompt = PromptTemplate(template=query_generator_template, input_variables = ["chat_history", "query"])
-        ss.query_generator = LLMChain(llm=ss.llm, prompt=query_generator_prompt)
-
         question_prompt = PromptTemplate(template=question_template, input_variables=["context", "query"])
         combine_prompt = PromptTemplate(template=combine_template, input_variables=["summaries", "query"])
         ss.qa_chain = load_qa_chain(ss.llm, chain_type="map_reduce", return_map_steps=True, memory=ss.memory, question_prompt=question_prompt, combine_prompt=combine_prompt)

+    if ss.web_summary_chain is None:
+        question_prompt = PromptTemplate(template=question_template, input_variables=["context", "query"])
+        ss.web_summary_chain = LLMChain(llm=ss.llm, prompt=question_prompt, verbose=True)
+
     return ss

 def initialize_db(ss: SessionState) -> SessionState:
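Note: set_chains now builds ss.web_summary_chain as a plain LLMChain over the map_reduce question prompt, and web_search calls it with a dict and reads the "text" key. A minimal sketch of that call pattern, assuming an OpenAI LLM as a stand-in for ss.llm and a template of the same shape as question_prompt_template_std:

from langchain.llms import OpenAI  # stand-in for ss.llm
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate

# Same shape as question_prompt_template_std: summarize the references for one question.
template = "===\n[Question] {query}\n[references] {context}\n[Summary] "
question_prompt = PromptTemplate(template=template, input_variables=["context", "query"])

web_summary_chain = LLMChain(llm=OpenAI(), prompt=question_prompt, verbose=True)

# Calling the chain returns a dict containing the inputs plus a "text" key with the
# completion, which is why web_search does ...({'query': q, 'context': c})['text'].
result = web_summary_chain({"query": "...", "context": "..."})
summary = result["text"]
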
@@ -761,16 +813,16 @@ def bot(ss: SessionState, query, qa_flag, web_flag, summarization_mode) -> (Sess
     # QA Model
     if qa_flag is True and ss.embeddings is not None and ss.db is not None:
         if web_flag:
-            web_query = web_search(
+            ss, web_query = web_search(ss, query)
             ss = qa_predict(ss, web_query)
             ss.memory.chat_memory.messages[-2].content = query
         else:
-            ss = qa_predict(ss, query)
+            ss = qa_predict(ss, query)

     # Chat Model
     else:
         if web_flag:
-            web_query = web_search(
+            ss, web_query = web_search(ss, query)
             ss = chat_predict(ss, web_query)
             ss.memory.chat_memory.messages[-2].content = query
         else:
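Note: in bot(), the web-augmented prompt is what goes to the model, but the stored user turn is then overwritten with the original question so the long reference text does not linger in the conversation history. A small sketch of that pattern on a bare ConversationBufferMemory (a stand-in for ss.memory):

from langchain.memory import ConversationBufferMemory

memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)

# The chain call appends the augmented prompt and the model's answer ...
memory.chat_memory.add_user_message("original question plus the web reference block ...")
memory.chat_memory.add_ai_message("model answer ...")

# ... and bot() then rewrites the second-to-last message (the user turn) back to
# the original query, exactly like ss.memory.chat_memory.messages[-2].content = query.
memory.chat_memory.messages[-2].content = "original user question"
print(memory.chat_memory.messages)
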
@@ -788,6 +840,8 @@ def chat_predict(ss: SessionState, query) -> SessionState:

 def qa_predict(ss: SessionState, query) -> SessionState:

+    original_query = query
+
     # Rinnaモデル向けの設定(クエリの改行コード修正)
     if ss.current_model == "rinna/bilingual-gpt-neox-4b-instruction-sft":
         query = query.strip().replace("\n", "<NL>")
@@ -829,7 +883,7 @@ def qa_predict(ss: SessionState, query) -> SessionState:
         response += "参考文献の抽出には成功していますので、言語モデルを変えてお試しください。"

     # ユーザーメッセージと AI メッセージの追加
-    ss.memory.chat_memory.add_user_message(
+    ss.memory.chat_memory.add_user_message(original_query.replace("<NL>", "\n"))
     ss.memory.chat_memory.add_ai_message(response)
     ss.dialogue[-1] = (ss.dialogue[-1][0], response) # 会話履歴
     return ss
@@ -1028,4 +1082,5 @@ with gr.Blocks() as demo:

 if __name__ == "__main__":
     demo.queue(concurrency_count=5)
-    demo.launch(debug=True)
+    demo.launch(debug=True,)
+
requirements.txt
CHANGED
@@ -21,4 +21,6 @@ numpy==1.23.5
 pandas==1.5.3
 chromedriver-autoinstaller
 chromedriver-binary
-duckduckgo-search==3.8.5
+duckduckgo-search==3.8.5
+mecab-python3==1.0.6
+unidic-lite==1.0.8