| # -------------------------------------- | |
| # Chat with Documents | |
| # Kikagaku April 2023 term - final project app | |
| # Copyright. cawacci | |
| # -------------------------------------- | |
| # -------------------------------------- | |
| # Libraries | |
| # -------------------------------------- | |
| import os | |
| import time | |
| import gc # free up memory | |
| import re # clean up text with regular expressions | |
| import regex # used for kanji extraction | |
| # HuggingFace | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| # OpenAI | |
| import openai | |
| from langchain.embeddings.openai import OpenAIEmbeddings | |
| from langchain.chat_models import ChatOpenAI | |
| # LangChain | |
| import langchain | |
| from langchain.llms import HuggingFacePipeline | |
| from transformers import pipeline | |
| from langchain.embeddings import HuggingFaceEmbeddings | |
| from langchain.chains import LLMChain, VectorDBQA | |
| from langchain.vectorstores import Chroma | |
| from langchain import PromptTemplate, ConversationChain | |
| from langchain.chains.question_answering import load_qa_chain # QA Chat | |
| from langchain.document_loaders import SeleniumURLLoader # fetch web pages from URLs | |
| from langchain.docstore.document import Document # wrap raw text as Documents | |
| from langchain.memory import ConversationSummaryBufferMemory # chat history | |
| from typing import Any | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain.tools import DuckDuckGoSearchRun | |
| # Gradio | |
| import gradio as gr | |
| from pypdf import PdfReader | |
| import requests # DeepL API request | |
| # Mecab | |
| import MeCab | |
| # -------------------------------------- | |
| # Class that keeps per-user session state | |
| # (Reference) https://blog.shikoan.com/gradio-state/ | |
| # -------------------------------------- | |
| class SessionState: | |
| def __init__(self): | |
| # Hugging Face | |
| self.tokenizer = None | |
| self.pipe = None | |
| self.model = None | |
| # LangChain | |
| self.llm = None | |
| self.embeddings = None | |
| self.current_model = "" | |
| self.current_embedding = "" | |
| self.db = None # Vector DB | |
| self.memory = None # Langchain Chat Memory | |
| self.conversation_chain = None # ConversationChain | |
| self.query_generator = None # Query Refiner with Chat history | |
| self.qa_chain = None # load_qa_chain | |
| self.web_summary_chain = None # Summarize web search result | |
| self.embedded_urls = [] | |
| self.similarity_search_k = None # No. of similarity search documents to find. | |
| self.summarization_mode = None # Stuff / Map Reduce / Refine | |
| # Apps | |
| self.dialogue = [] # Recent Chat History for display | |
| # -------------------------------------- | |
| # Empty Cache | |
| # -------------------------------------- | |
| def cache_clear(self): | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() # GPU Memory Clear | |
| gc.collect() # CPU Memory Clear | |
| # -------------------------------------- | |
| # Clear Models (llm: llm model, embd: embeddings, db: vectordb) | |
| # -------------------------------------- | |
| def clear_memory(self, llm=False, embd=False, db=False): | |
| # DB | |
| if db and self.db: | |
| self.db.delete_collection() | |
| self.db = None | |
| self.embedded_urls = [] | |
| # Embeddings model | |
| if llm or embd: | |
| self.embeddings = None | |
| self.current_embedding = "" | |
| self.qa_chain = None | |
| # LLM model | |
| if llm: | |
| self.llm = None | |
| self.pipe = None | |
| self.model = None | |
| self.current_model = "" | |
| self.tokenizer = None | |
| self.memory = None | |
| self.chat_history = [] # TODO: verify whether this attribute is still needed | |
| self.cache_clear() | |
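| # -------------------------------------- | |
| # Note: one SessionState instance is held in gr.State (see the Gradio Blocks | |
| # section below), so each browser session keeps its own models, memory and | |
| # vector DB, and everything is reset when the page is reloaded. | |
| # -------------------------------------- | |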
| # -------------------------------------- | |
| # Custom ConversationChain that does not use memory | |
| # -------------------------------------- | |
| from typing import Dict, List | |
| from langchain.chains.conversation.prompt import PROMPT | |
| from langchain.chains.llm import LLMChain | |
| from langchain.pydantic_v1 import Extra, Field, root_validator | |
| from langchain.schema import BasePromptTemplate | |
| class ConversationChain(LLMChain): | |
| """Chain to have a conversation without loading context from memory. | |
| Example: | |
| .. code-block:: python | |
| from langchain.llms import OpenAI | |
| conversation = ConversationChain(llm=OpenAI()) | |
| """ | |
| prompt: BasePromptTemplate = PROMPT | |
| """Default conversation prompt to use.""" | |
| input_key: str = "input" #: :meta private: | |
| output_key: str = "response" #: :meta private: | |
| class Config: | |
| """Configuration for this pydantic object.""" | |
| extra = Extra.forbid | |
| arbitrary_types_allowed = True | |
| @property | |
| def input_keys(self) -> List[str]: | |
| """Return the single input key; no prompt variables come from memory.""" | |
| return [self.input_key] | |
| @root_validator() | |
| def validate_prompt_input_variables(cls, values: Dict) -> Dict: | |
| """Validate that prompt input variables are consistent without memory.""" | |
| input_key = values["input_key"] | |
| prompt_variables = values["prompt"].input_variables | |
| if input_key not in prompt_variables: | |
| raise ValueError( | |
| f"The prompt expects {prompt_variables}, but {input_key} is not found." | |
| ) | |
| return values | |
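| # Note: unlike the stock langchain ConversationChain, this subclass declares no | |
| # memory, so chat history is not injected automatically; the app manages history | |
| # itself via ss.memory and the query_generator chain (see set_chains / bot below). | |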
| # -------------------------------------- | |
| # Custom TextSplitter (splits text so it fits within the LLM's token limit) | |
| # (Reference) https://www.sato-susumu.com/entry/2023/04/30/131338 | |
| # → added separators such as "!", "?", ")", ".", "!", "?", "," | |
| # -------------------------------------- | |
| class JPTextSplitter(RecursiveCharacterTextSplitter): | |
| def __init__(self, **kwargs: Any): | |
| separators = ["\n\n", "\n", "。", "!", "?", ")","、", ".", "!", "?", ",", " ", ""] | |
| super().__init__(separators=separators, **kwargs) | |
| # Chunk splitting settings | |
| chunk_size = 512 | |
| chunk_overlap = 35 | |
| text_splitter = JPTextSplitter( | |
| chunk_size = chunk_size, # maximum characters per chunk | |
| chunk_overlap = chunk_overlap, # maximum overlap in characters | |
| ) | |
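| # Usage sketch (hypothetical input): text_splitter.split_documents(ref_documents) | |
| # returns Document chunks of at most 512 characters with a 35-character overlap, | |
| # splitting preferentially at paragraph breaks, then at Japanese/English sentence punctuation. | |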
| # -------------------------------------- | |
| # Extract person names from text | |
| # -------------------------------------- | |
| def name_detector(text: str) -> list: | |
| mecab = MeCab.Tagger() | |
| mecab.parse('') # workaround for a known MeCab bug | |
| node = mecab.parseToNode(text).next | |
| names = [] | |
| while node: | |
| if node.feature.split(',')[3] == "姓": | |
| if node.next and node.next.feature.split(',')[3] == "名": | |
| names.append(str(node.surface) + str(node.next.surface)) | |
| else: | |
| names.append(node.surface) | |
| if node.feature.split(',')[3] == "名": | |
| if node.prev and node.prev.feature.split(',')[3] == "姓": | |
| pass | |
| else: | |
| names.append(str(node.surface)) | |
| node = node.next | |
| # Deduplicate, then keep only values that contain kanji | |
| names = filter_kanji(list(set(names))) | |
| return names | |
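| # Example (hypothetical input): name_detector("山田太郎は佐藤と会った") would return | |
| # something like ["山田太郎", "佐藤"], assuming the MeCab dictionary tags 姓/名 features. | |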
| # -------------------------------------- | |
| # Keep only the list items that contain kanji | |
| # -------------------------------------- | |
| def filter_kanji(lst) -> list: | |
| def contains_kanji(s): | |
| p = regex.compile(r'\p{Script=Han}+') | |
| return bool(p.search(s)) | |
| return [item for item in lst if contains_kanji(item)] | |
| # -------------------------------------- | |
| # Translate memory with DeepL to cut token usage (when an OpenAI model is used) | |
| # -------------------------------------- | |
| DEEPL_API_ENDPOINT = "https://api-free.deepl.com/v2/translate" | |
| DEEPL_API_KEY = os.getenv("DEEPL_API_KEY") | |
| def deepl_memory(ss: SessionState) -> (SessionState): | |
| if ss.current_model == "gpt-3.5-turbo": | |
| # Get the latest exchange from memory | |
| user_message = ss.memory.chat_memory.messages[-2].content | |
| ai_message = ss.memory.chat_memory.messages[-1].content | |
| text = [user_message, ai_message] | |
| # DeepL request parameters | |
| params = { | |
| "auth_key": DEEPL_API_KEY, | |
| "text": text, | |
| "target_lang": "EN", | |
| "source_lang": "JA", | |
| "tag_handling": "xml", | |
| "igonere_tags": "x", | |
| } | |
| request = requests.post(DEEPL_API_ENDPOINT, data=params) | |
| request.raise_for_status() # raise an exception if the response status code indicates an error | |
| response = request.json() | |
| # Extract the translations from the JSON response | |
| user_message = response["translations"][0]["text"] | |
| ai_message = response["translations"][1]["text"] | |
| # Remove the last exchange from memory and add the translated version | |
| ss.memory.chat_memory.messages = ss.memory.chat_memory.messages[:-2] | |
| ss.memory.chat_memory.add_user_message(user_message) | |
| ss.memory.chat_memory.add_ai_message(ai_message) | |
| return ss | |
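| # Note: DeepL accepts a list for "text" and returns one entry per input under | |
| # "translations", which is why the user and AI messages go out in a single request. | |
| # Keeping the stored history in English keeps the ConversationSummaryBufferMemory | |
| # token count lower than the Japanese original when gpt-3.5-turbo is active. | |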
| # -------------------------------------- | |
| # Add DuckDuckGo web search results to the input prompt | |
| # -------------------------------------- | |
| # DEEPL_API_ENDPOINT = "https://api-free.deepl.com/v2/translate" | |
| # DEEPL_API_KEY = os.getenv("DEEPL_API_KEY") | |
| def web_search(ss: SessionState, query) -> (SessionState, str): | |
| search = DuckDuckGoSearchRun(verbose=True) | |
| names = [] | |
| names.extend(name_detector(query)) | |
| for i in range(3): | |
| web_result = search(query) | |
| if ss.current_model == "gpt-3.5-turbo": | |
| text = [query, web_result] | |
| params = { | |
| "auth_key": DEEPL_API_KEY, | |
| "text": text, | |
| "target_lang": "EN", | |
| "source_lang": "JA", | |
| "tag_handling": "xml", | |
| "ignore_tags": "x", | |
| } | |
| request = requests.post(DEEPL_API_ENDPOINT, data=params) | |
| response = request.json() | |
| query_eng = response["translations"][0]["text"] | |
| web_result_eng = response["translations"][1]["text"] | |
| web_result_eng = ss.web_summary_chain({'query': query_eng, 'context': web_result_eng})['text'] | |
| if "$$NO INFO$$" in web_result_eng: | |
| web_result_eng = ss.web_summary_chain({'query': query_eng, 'context': web_result_eng})['text'] | |
| if "$$NO INFO$$" not in web_result_eng: | |
| break | |
| # Extract person names from the search result and turn them into text | |
| names.extend(name_detector(web_result)) | |
| if len(names)==0: | |
| names = "" | |
| elif len(names)==1: | |
| names = names[0] | |
| else: | |
| names = ", ".join(names) | |
| # Pass a query that includes the web search result. | |
| if names != "": | |
| web_query = f""" | |
| {query_eng} | |
| Use the following Suggested Answer as a reference to answer the question above in Japanese. When translating names of people, refer to Names as a translation guide. | |
| Suggested Answer: {web_result_eng} | |
| Names: {names} | |
| """.strip() | |
| else: | |
| web_query = query_eng + "\nUse the following Suggested Answer as a reference to answer the question above in Japanese.\n===\nSuggested Answer: " + web_result_eng + "\n" | |
| return ss, web_query | |
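| # Flow sketch: run a DuckDuckGo search -> (for gpt-3.5-turbo) translate the query | |
| # and result to English with DeepL -> condense the result with web_summary_chain, | |
| # retrying up to 3 times while it reports $$NO INFO$$ -> append any extracted | |
| # person names as a translation guide and return the augmented prompt. | |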
| # -------------------------------------- | |
| # Assorted custom prompts for LangChain | |
| # llama tokenizer | |
| # https://belladoreai.github.io/llama-tokenizer-js/example-demo/build/ | |
| # OpenAI tokenizer | |
| # https://platform.openai.com/tokenizer | |
| # -------------------------------------- | |
| # -------------------------------------- | |
| # Conversation Chain Template | |
| # -------------------------------------- | |
| # Tokens: OpenAI 104/ Llama 105 <- In Japanese: Tokens: OpenAI 191/ Llama 162 | |
| sys_chat_message = """ | |
| You are an AI concierge who carefully answers questions from customers based on references. | |
| You understand what the customer wants to know, and give many specific details in Japanese | |
| using sentences extracted from the following references when available. If you do not know | |
| the answer, do not make up an answer and reply, "誠に申し訳ございませんが、その点についてはわかりかねます". | |
| """.replace("\n", "") | |
| chat_common_format = """ | |
| === | |
| Question: {query} | |
| 日本語の回答: """ | |
| chat_template_std = f"{sys_chat_message}{chat_common_format}" | |
| chat_template_llama2 = f"<s>[INST] <<SYS>>{sys_chat_message}<</SYS>>{chat_common_format}[/INST]" | |
| # -------------------------------------- | |
| # QA Chain Template (Stuff) | |
| # -------------------------------------- | |
| # Tokens: OpenAI 113/ Llama 111 <- In Japanese: Tokens: OpenAI 256/ Llama 225 | |
| # sys_qa_message = """ | |
| # You are an AI concierge who carefully answers questions from customers based on references. | |
| # You understand what the customer wants to know from the Conversation History and Question, | |
| # and give a specific answer in Japanese using sentences extracted from the following references. | |
| # If you do not know the answer, do not make up an answer and reply, | |
| # "誠に申し訳ございませんが、その点についてはわかりかねます". | |
| # """.replace("\n", "") | |
| # qa_common_format = """ | |
| # === | |
| # Question: {query} | |
| # References: {context} | |
| # === | |
| # Conversation History: | |
| # {chat_history} | |
| # === | |
| # 日本語の回答: """ | |
| qa_common_format = """ | |
| === | |
| Question: {query} | |
| References: {context} | |
| 日本語の回答: """ | |
| qa_template_std = f"{sys_chat_message}{qa_common_format}" | |
| qa_template_llama2 = f"<s>[INST] <<SYS>>{sys_chat_message}<</SYS>>{qa_common_format}[/INST]" | |
| # qa_template_std = f"{sys_qa_message}{qa_common_format}" | |
| # qa_template_llama2 = f"<s>[INST] <<SYS>>{sys_qa_message}<</SYS>>{qa_common_format}[/INST]" | |
| # -------------------------------------- | |
| # QA Chain Template (Map Reduce) | |
| # -------------------------------------- | |
| # 1. Prompt for the chain that turns the chat history and the latest question into a standalone question | |
| query_generator_message = """ | |
| Referring to the "Conversation History", especially to the most recent conversation, | |
| reformat the user's "Additional Question" into a specific question in Japanese by | |
| filling in the missing subject, verb, objects, complements,and other necessary | |
| information to get a better search result. Answer in 日本語(Japanese). | |
| """.replace("\n", "") | |
| query_generator_common_format = """ | |
| === | |
| [Conversation History] | |
| {chat_history} | |
| [Additional Question] {query} | |
| 明確な質問文: """ | |
| query_generator_template_std = f"{query_generator_message}{query_generator_common_format}" | |
| query_generator_template_llama2 = f"<s>[INST] <<SYS>>{query_generator_message}<</SYS>>{query_generator_common_format}[/INST]" | |
| # 2. Prompt for the chain that summarizes the references using the generated question | |
| question_prompt_message = """ | |
| From the following references, extract key information relevant to the question | |
| and summarize it in a natural English sentence with clear subject, verb, object, | |
| and complement. | |
| """.replace("\n", "") | |
| question_prompt_common_format = """ | |
| === | |
| [Question] {query} | |
| [References] {context} | |
| [Key Information] """ | |
| question_prompt_template_std = f"{question_prompt_message}{question_prompt_common_format}" | |
| question_prompt_template_llama2 = f"<s>[INST] <<SYS>>{question_prompt_message}<</SYS>>{question_prompt_common_format}[/INST]" | |
| # 3. Prompt for the chain that answers from the generated question and the summarized references | |
| combine_prompt_message = """ | |
| You are an AI concierge who carefully answers questions from customers based on references. | |
| Provide a specific answer in Japanese using sentences extracted from the following references. | |
| If you do not know the answer, do not make up an answer and reply, | |
| "誠に申し訳ございませんが、その点についてはわかりかねます". | |
| """.replace("\n", "") | |
| combine_prompt_common_format = """ | |
| === | |
| Question: {query} | |
| Reference: {summaries} | |
| 日本語の回答: """ | |
| combine_prompt_template_std = f"{combine_prompt_message}{combine_prompt_common_format}" | |
| combine_prompt_template_llama2 = f"<s>[INST] <<SYS>>{combine_prompt_message}<</SYS>>{combine_prompt_common_format}[/INST]" | |
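| # In map_reduce mode, question_prompt is applied to each retrieved chunk to pull | |
| # out key information, and combine_prompt then merges those partial summaries | |
| # ({summaries}) into the final Japanese answer. | |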
| # -------------------------------------- | |
| # Summary prompt for ConversationSummaryBufferMemory | |
| # Source → https://github.com/langchain-ai/langchain/blob/894c272a562471aadc1eb48e4a2992923533dea0/langchain/memory/prompt.py#L26-L49 | |
| # -------------------------------------- | |
| # Tokens: OpenAI 212/ Llama 214 <- In Japanese: Tokens: OpenAI 397/ Llama 297 | |
| conversation_summary_template = """ | |
| Using the example as a guide, compose a summary in English that gives an overview of the conversation by summarizing the "current summary" and the "new conversation". | |
| === | |
| Example | |
| [Current Summary] Customer asks AI what it thinks about Artificial Intelligence, AI says Artificial Intelligence is a good tool. | |
| [New Conversation] | |
| Human: なぜ人工知能が良いツールだと思いますか? | |
| AI: 人工知能は「人間の可能性を最大限に引き出すことを助ける」からです。 | |
| [New Summary] Customer asks what you think about Artificial Intelligence, and AI responds that it is a good force that helps humans reach their full potential. | |
| === | |
| [Current Summary] {summary} | |
| [New Conversation] | |
| {new_lines} | |
| [New Summary] | |
| """.strip() | |
| # Load models | |
| def load_models( | |
| ss: SessionState, | |
| model_id: str, | |
| embedding_id: str, | |
| openai_api_key: str, | |
| load_in_8bit: bool, | |
| verbose: bool, | |
| temperature: float, | |
| similarity_search_k: int, | |
| summarization_mode: str, | |
| min_length: int, | |
| max_new_tokens: int, | |
| top_k: int, | |
| top_p: float, | |
| repetition_penalty: float, | |
| num_return_sequences: int, | |
| ) -> (SessionState, str): | |
| # -------------------------------------- | |
| # Save settings | |
| # -------------------------------------- | |
| ss.similarity_search_k = similarity_search_k | |
| ss.summarization_mode = summarization_mode | |
| # -------------------------------------- | |
| # Check the OpenAI API key | |
| # -------------------------------------- | |
| if (model_id == "gpt-3.5-turbo" or embedding_id == "text-embedding-ada-002"): | |
| # Preliminary check | |
| if not os.environ.get("OPENAI_API_KEY"): | |
| status_message = "❌ OpenAI API KEY を設定してください" | |
| return ss, status_message | |
| # -------------------------------------- | |
| # LLM setup | |
| # -------------------------------------- | |
| # OpenAI Model | |
| if model_id == "gpt-3.5-turbo": | |
| ss.clear_memory(llm=True, db=True) | |
| ss.llm = ChatOpenAI( | |
| model_name = model_id, | |
| temperature = temperature, | |
| verbose = verbose, | |
| max_tokens = max_new_tokens, | |
| ) | |
| # Hugging Face GPT Model | |
| else: | |
| ss.clear_memory(llm=True, db=True) | |
| if model_id == "rinna/bilingual-gpt-neox-4b-instruction-sft": | |
| ss.tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False) | |
| else: | |
| ss.tokenizer = AutoTokenizer.from_pretrained(model_id) | |
| ss.model = AutoModelForCausalLM.from_pretrained( | |
| model_id, | |
| load_in_8bit = load_in_8bit, | |
| torch_dtype = torch.float16, | |
| device_map = "auto", | |
| ) | |
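| # Note (assumed environment): load_in_8bit=True relies on the bitsandbytes | |
| # package and a CUDA GPU, and device_map="auto" requires accelerate. | |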
| ss.pipe = pipeline( | |
| "text-generation", | |
| model = ss.model, | |
| tokenizer = ss.tokenizer, | |
| min_length = min_length, | |
| max_new_tokens = max_new_tokens, | |
| do_sample = True, | |
| top_k = top_k, | |
| top_p = top_p, | |
| repetition_penalty = repetition_penalty, | |
| num_return_sequences = num_return_sequences, | |
| temperature = temperature, | |
| ) | |
| ss.llm = HuggingFacePipeline(pipeline=ss.pipe) | |
| # -------------------------------------- | |
| # Embedding model setup | |
| # -------------------------------------- | |
| if ss.current_embedding == embedding_id: | |
| pass | |
| else: | |
| # Reset embeddings and vectordb | |
| ss.clear_memory(embd=True, db=True) | |
| if embedding_id == "None": | |
| pass | |
| # OpenAI | |
| elif embedding_id == "text-embedding-ada-002": | |
| ss.embeddings = OpenAIEmbeddings() | |
| # Hugging Face | |
| else: | |
| ss.embeddings = HuggingFaceEmbeddings(model_name=embedding_id) | |
| # -------------------------------------- | |
| # Chain setup | |
| #--------------------------------------- | |
| ss = set_chains(ss, summarization_mode) | |
| # -------------------------------------- | |
| # Save the current model names to the SessionState object | |
| #--------------------------------------- | |
| ss.current_model = model_id | |
| ss.current_embedding = embedding_id | |
| # Status Message | |
| status_message = "✅ LLM: " + ss.current_model + ", embeddings: " + ss.current_embedding | |
| return ss, status_message | |
| # -------------------------------------- | |
| # Conversation/QA Chain 呼び出し統合 | |
| # -------------------------------------- | |
| def set_chains(ss: SessionState, summarization_mode) -> SessionState: | |
| # Pick templates to match the selected model | |
| human_prefix = "Human: " | |
| ai_prefix = "AI: " | |
| chat_template = chat_template_std | |
| qa_template = qa_template_std | |
| query_generator_template = query_generator_template_std | |
| question_template = question_prompt_template_std | |
| combine_template = combine_prompt_template_std | |
| if ss.current_model == "rinna/bilingual-gpt-neox-4b-instruction-sft": | |
| # Settings for the Rinna model (newline conversion and memory prefixes; see the official model card) | |
| chat_template = chat_template.replace("\n", "<NL>") | |
| qa_template = qa_template.replace("\n", "<NL>") | |
| query_generator_template = query_generator_template_std.replace("\n", "<NL>") | |
| question_template = question_prompt_template_std.replace("\n", "<NL>") | |
| combine_template = combine_prompt_template_std.replace("\n", "<NL>") | |
| human_prefix = "ユーザー: " | |
| ai_prefix = "システム: " | |
| elif ss.current_model.startswith("elyza/ELYZA-japanese-Llama-2-7b"): | |
| # Template settings for the ELYZA model | |
| chat_template = chat_template_llama2 | |
| qa_template = qa_template_llama2 | |
| query_generator_template = query_generator_template_llama2 | |
| question_template = question_prompt_template_llama2 | |
| combine_template = combine_prompt_template_llama2 | |
| # -------------------------------------- | |
| # Memory setup | |
| # -------------------------------------- | |
| if ss.memory is None: | |
| conversation_summary_prompt = PromptTemplate(input_variables=['summary', 'new_lines'], template=conversation_summary_template) | |
| ss.memory = ConversationSummaryBufferMemory( | |
| llm = ss.llm, | |
| memory_key = "chat_history", | |
| input_key = "query", | |
| output_key = "output_text", | |
| return_messages = False, | |
| human_prefix = human_prefix, | |
| ai_prefix = ai_prefix, | |
| max_token_limit = 1024, | |
| prompt = conversation_summary_prompt, | |
| ) | |
| # -------------------------------------- | |
| # Conversation/QA chain setup | |
| # -------------------------------------- | |
| if ss.query_generator is None: | |
| query_generator_prompt = PromptTemplate(template=query_generator_template, input_variables = ["chat_history", "query"]) | |
| ss.query_generator = LLMChain(llm=ss.llm, prompt=query_generator_prompt, verbose=True) | |
| if ss.conversation_chain is None: | |
| # chat_prompt = PromptTemplate(input_variables=['query', 'chat_history'], template=chat_template) | |
| chat_prompt = PromptTemplate(input_variables=['query'], template=chat_template) | |
| ss.conversation_chain = ConversationChain( | |
| llm = ss.llm, | |
| prompt = chat_prompt, | |
| # memory = ss.memory, | |
| input_key = "query", | |
| output_key = "output_text", | |
| verbose = True, | |
| ) | |
| if ss.qa_chain is None: | |
| if summarization_mode == "stuff": | |
| # qa_prompt = PromptTemplate(input_variables=['context', 'query', 'chat_history'], template=qa_template) | |
| qa_prompt = PromptTemplate(input_variables=['context', 'query'], template=qa_template) | |
| # ss.qa_chain = load_qa_chain(ss.llm, chain_type="stuff", memory=ss.memory, prompt=qa_prompt) | |
| ss.qa_chain = load_qa_chain(ss.llm, chain_type="stuff", prompt=qa_prompt, verbose=True) | |
| elif summarization_mode == "map_reduce": | |
| question_prompt = PromptTemplate(template=question_template, input_variables=["context", "query"]) | |
| combine_prompt = PromptTemplate(template=combine_template, input_variables=["summaries", "query"]) | |
| # ss.qa_chain = load_qa_chain(ss.llm, chain_type="map_reduce", return_map_steps=True, memory=ss.memory, question_prompt=question_prompt, combine_prompt=combine_prompt) | |
| ss.qa_chain = load_qa_chain(ss.llm, chain_type="map_reduce", return_map_steps=True, question_prompt=question_prompt, combine_prompt=combine_prompt, verbose=True) | |
| if ss.web_summary_chain is None: | |
| question_prompt = PromptTemplate(template=question_template, input_variables=["context", "query"]) | |
| ss.web_summary_chain = LLMChain(llm=ss.llm, prompt=question_prompt, verbose=True) | |
| return ss | |
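| # Chains created here: query_generator rewrites the user's question using the chat | |
| # history, conversation_chain answers without retrieval, qa_chain answers over | |
| # retrieved documents (stuff or map_reduce), and web_summary_chain condenses | |
| # DuckDuckGo results before they are added to the prompt. | |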
| def initialize_db(ss: SessionState) -> SessionState: | |
| # client = chromadb.PersistentClient(path="./db") | |
| ss.db = Chroma( | |
| collection_name = "user_reference", | |
| embedding_function = ss.embeddings, | |
| # client = client | |
| ) | |
| return ss | |
| def embedding_process(ss: SessionState, ref_documents: Document) -> SessionState: | |
| # -------------------------------------- | |
| # Clean up document text and remove unwanted strings | |
| # -------------------------------------- | |
| for i in range(len(ref_documents)): | |
| content = ref_documents[i].page_content.strip() | |
| # -------------------------------------- | |
| # For PDFs, apply stronger fixes to work around text-extraction errors | |
| # -------------------------------------- | |
| if ".pdf" in ref_documents[i].metadata['source']: | |
| pdf_replacement_sets = [ | |
| ('\n ', '**PLACEHOLDER+SPACE**'), | |
| ('\n\u3000', '**PLACEHOLDER+SPACE**'), | |
| ('.\n', '。**PLACEHOLDER**'), | |
| (',\n', '。**PLACEHOLDER**'), | |
| ('?\n', '。**PLACEHOLDER**'), | |
| ('!\n', '。**PLACEHOLDER**'), | |
| ('!\n', '。**PLACEHOLDER**'), | |
| ('。\n', '。**PLACEHOLDER**'), | |
| ('!\n', '!**PLACEHOLDER**'), | |
| (')\n', ')**PLACEHOLDER**'), | |
| (']\n', ']**PLACEHOLDER**'), | |
| ('?\n', '?**PLACEHOLDER**'), | |
| (')\n', ')**PLACEHOLDER**'), | |
| ('】\n', '】**PLACEHOLDER**'), | |
| ] | |
| for original, replacement in pdf_replacement_sets: | |
| content = content.replace(original, replacement) | |
| content = content.replace(" ", "") | |
| # -------------------------------------- | |
| # Remove unwanted strings and whitespace | |
| remove_texts = ["\n", "\r", " "] | |
| for remove_text in remove_texts: | |
| content = content.replace(remove_text, "") | |
| # Convert tabs and ideographic spaces to single spaces | |
| replace_texts = ["\t", "\u3000"] | |
| for replace_text in replace_texts: | |
| content = content.replace(replace_text, " ") | |
| # Restore legitimate PDF line breaks. | |
| if ".pdf" in ref_documents[i].metadata['source']: | |
| content = content.replace('**PLACEHOLDER**', '\n').replace('**PLACEHOLDER+SPACE**', '\n ') | |
| ref_documents[i].page_content = content | |
| # -------------------------------------- | |
| # Split into chunks | |
| texts = text_splitter.split_documents(ref_documents) | |
| # -------------------------------------- | |
| # Add the prefix that the multilingual-e5 model was trained with | |
| # https://hironsan.hatenablog.com/entry/2023/07/05/073150 | |
| # -------------------------------------- | |
| if ss.current_embedding == "intfloat/multilingual-e5-large": | |
| for i in range(len(texts)): | |
| texts[i].page_content = "passage:" + texts[i].page_content | |
| # Initialize the vector DB | |
| if ss.db is None: | |
| ss = initialize_db(ss) | |
| # Embed into the DB | |
| # ss.db = Chroma.from_documents(texts, ss.embeddings) | |
| ss.db.add_documents(documents=texts, embedding=ss.embeddings) | |
| return ss | |
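| # Pipeline summary: clean the page text -> split into chunks with JPTextSplitter -> | |
| # add the "passage: " prefix when multilingual-e5 is the embedding model (its | |
| # training format) -> add the chunks to the in-memory Chroma collection. | |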
| def embed_ref(ss: SessionState, urls: str, fileobj: list, header_lim: int, footer_lim: int) -> (SessionState, str): | |
| # -------------------------------------- | |
| # Make sure models are loaded | |
| # -------------------------------------- | |
| if ss.llm is None or ss.embeddings is None: | |
| status_message = "❌ LLM/Embeddingモデルが登録されていません。" | |
| return ss, status_message | |
| url_flag = "-" | |
| pdf_flag = "-" | |
| # -------------------------------------- | |
| # Load URLs and register them in the vector DB | |
| # -------------------------------------- | |
| # Preprocess the URL list (split into a list, deduplicate, drop non-URLs) | |
| urls = list({url for url in urls.split("\n") if url and "://" in url}) | |
| if urls: | |
| # Drop URLs already registered (ss.embedded_urls), then record the new ones | |
| urls = [url for url in urls if url not in ss.embedded_urls] | |
| ss.embedded_urls.extend(urls) | |
| # Load the web pages | |
| loader = SeleniumURLLoader(urls=urls) | |
| ref_documents = loader.load() | |
| # Run the embedding process | |
| ss = embedding_process(ss, ref_documents) | |
| url_flag = "✅ 登録済" | |
| # -------------------------------------- | |
| # Strip PDF headers/footers and register the text in the vector DB | |
| # https://pypdf.readthedocs.io/en/stable/user/extract-text.html | |
| # -------------------------------------- | |
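| # The visitor_body callback below uses pypdf's visitor_text hook: tm[5] is the | |
| # text's y-coordinate, so only text drawn between footer_lim and header_lim | |
| # (points from the bottom of the page) is kept, dropping headers and footers. | |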
| if fileobj is None: | |
| pass | |
| else: | |
| # Collect the uploaded file paths | |
| pdf_paths = [] | |
| for path in fileobj: | |
| pdf_paths.append(path.name) | |
| # Initialize the document list | |
| ref_documents = [] | |
| # Read each PDF file | |
| for pdf_path in pdf_paths: | |
| pdf = PdfReader(pdf_path) | |
| body = [] | |
| def visitor_body(text, cm, tm, font_dict, font_size): | |
| y = tm[5] | |
| if y > footer_lim and y < header_lim: # keep text whose y-coordinate lies between footer and header | |
| parts.append(text) | |
| for page in pdf.pages: | |
| parts = [] | |
| page.extract_text(visitor_text=visitor_body) | |
| body.append("".join(parts)) | |
| body = "\n".join(body) | |
| # Get just the file name from the path | |
| filename = os.path.basename(pdf_path) | |
| # Convert the extracted text into a LangChain Document | |
| ref_documents.append(Document(page_content=body, metadata={"source": filename})) | |
| # Run the embedding process | |
| ss = embedding_process(ss, ref_documents) | |
| pdf_flag = "✅ 登録済" | |
| langchain.debug=True | |
| status_message = "URL: " + url_flag + " / PDF: " + pdf_flag | |
| return ss, status_message | |
| def clear_db(ss: SessionState) -> (SessionState, str): | |
| if ss.db is None: | |
| status_message = "❌ 参照データが登録されていません。" | |
| return ss, status_message | |
| try: | |
| ss.db.delete_collection() | |
| status_message = "✅ 参照データを削除しました。" | |
| except NameError: | |
| status_message = "❌ 参照データが登録されていません。" | |
| return ss, status_message | |
| # ---------------------------------------------------------------------------- | |
| # query input ▶ [def user] ▶ [ def bot ] ▶ [def show_response] ▶ chatbot screen | |
| # ⬇ ⬇ ⬆ | |
| # chatbot screen [qa_predict / chat_predict] | |
| # ---------------------------------------------------------------------------- | |
| def user(ss: SessionState, query) -> (SessionState, list): | |
| # If the chat history grows beyond a set length, drop the oldest entry | |
| if len(ss.dialogue) > 20: | |
| ss.dialogue.pop(0) | |
| ss.dialogue = ss.dialogue + [(query, None)] # chat history (None is the bot's still-empty reply slot) | |
| chat_history = ss.dialogue | |
| # chat screen = chat_history | |
| return ss, chat_history | |
| def bot(ss: SessionState, query, qa_flag, web_flag, summarization_mode) -> (SessionState, str): | |
| original_query = query | |
| if ss.llm is None: | |
| if ss.dialogue: | |
| response = "LLMが設定されていません。設定画面で任意のモデルを選択してください。" | |
| ss.dialogue[-1] = (ss.dialogue[-1][0], response) | |
| return ss, "" | |
| elif qa_flag is True and ss.embeddings is None: | |
| if ss.dialogue: | |
| response = "Embeddingモデルが設定されていません。設定画面で任意のモデルを選択してください。" | |
| ss.dialogue[-1] = (ss.dialogue[-1][0], response) | |
| return ss, "" | |
| elif qa_flag is True and ss.db is None: | |
| if ss.dialogue: | |
| response = "参照データが登録されていません。" | |
| ss.dialogue[-1] = (ss.dialogue[-1][0], response) | |
| return ss, "" | |
| # Refine query | |
| history = ss.memory.load_memory_variables({}) | |
| if history['chat_history'] != "": | |
| # Refine the query using the chat history | |
| query = ss.query_generator({"query": query, "chat_history": history})['text'] | |
| # QA Model | |
| if qa_flag is True and ss.embeddings is not None and ss.db is not None: | |
| if web_flag: | |
| ss, web_query = web_search(ss, query) | |
| ss = qa_predict(ss, web_query) | |
| ss.memory.chat_memory.messages[-2].content = query | |
| else: | |
| ss = qa_predict(ss, query) | |
| # Chat Model | |
| else: | |
| if web_flag: | |
| ss, web_query = web_search(ss, query) | |
| ss = chat_predict(ss, web_query) | |
| ss.memory.chat_memory.messages[-2].content = query | |
| else: | |
| ss = chat_predict(ss, query) | |
| # When a GPT model is in use, translate memory into English with DeepL | |
| ss = deepl_memory(ss) | |
| return ss, "" # ssとquery欄(空欄) | |
| def chat_predict(ss: SessionState, query) -> SessionState: | |
| response = ss.conversation_chain.predict(query=query) | |
| ss.memory.chat_memory.add_user_message(query) | |
| ss.memory.chat_memory.add_ai_message(response) | |
| ss.dialogue[-1] = (ss.dialogue[-1][0], response) | |
| return ss | |
| def qa_predict(ss: SessionState, query) -> SessionState: | |
| original_query = query | |
| # Rinna model setting (convert newlines in the query) | |
| if ss.current_model == "rinna/bilingual-gpt-neox-4b-instruction-sft": | |
| query = query.strip().replace("\n", "<NL>") | |
| else: | |
| query = query.strip() | |
| # Query prefix for multilingual-e5 | |
| if ss.current_embedding == "intfloat/multilingual-e5-large": | |
| db_query_str = "query: " + query | |
| else: | |
| db_query_str = query | |
| # Retrieve relevant documents and their sources from the DB | |
| docs = ss.db.similarity_search(db_query_str, k=ss.similarity_search_k) | |
| sources= "\n\n[Sources]\n" + '\n - '.join(list(set(doc.metadata['source'] for doc in docs if 'source' in doc.metadata))) | |
| # Rinna model setting (convert newlines in the retrieved documents) | |
| if ss.current_model == "rinna/bilingual-gpt-neox-4b-instruction-sft": | |
| for i in range(len(docs)): | |
| docs[i].page_content = docs[i].page_content.strip().replace("\n", "<NL>") | |
| # Generate the answer (up to 3 attempts) | |
| for _ in range(3): | |
| result = ss.qa_chain({"input_documents": docs, "query": query}) | |
| result["output_text"] = result["output_text"].replace("<NL>", "\n").strip("...").strip("回答:").strip() | |
| # result["output_text"]が空欄でない場合、メモリーを更新して返す | |
| if result["output_text"] != "": | |
| response = result["output_text"] + sources | |
| ss.dialogue[-1] = (ss.dialogue[-1][0], response) | |
| ss.memory.chat_memory.add_user_message(original_query) | |
| ss.memory.chat_memory.add_ai_message(response) | |
| return ss | |
| # else: | |
| # If empty, delete the most recent history and retry | |
| # ss.memory.chat_memory.messages = ss.memory.chat_memory.messages[:-2] | |
| # Still empty after 3 attempts | |
| response = "3回試行しましたが、情報製生成できませんでした。" | |
| if sources != "": | |
| response += "参考文献の抽出には成功していますので、言語モデルを変えてお試しください。" | |
| # Add the user message and the AI message | |
| ss.memory.chat_memory.add_user_message(original_query.replace("<NL>", "\n")) | |
| ss.memory.chat_memory.add_ai_message(response) | |
| ss.dialogue[-1] = (ss.dialogue[-1][0], response) # chat history | |
| return ss | |
| # Show the answer in the chat screen one character at a time | |
| def show_response(ss: SessionState) -> str: | |
| chat_history = [list(item) for item in ss.dialogue] # convert tuples to lists to get the chat history | |
| if chat_history: | |
| response = chat_history[-1][1] # take the latest exchange [-1] and stash the bot's reply [1] | |
| chat_history[-1][1] = "" # clear the bot's reply [1] so it can be streamed back | |
| if response is None: | |
| response = "回答を生成できませんでした。" | |
| for character in response: | |
| chat_history[-1][1] += character | |
| time.sleep(0.05) | |
| yield chat_history | |
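| # Yielding the growing history one character at a time lets gr.Chatbot render | |
| # the reply with a simple streaming effect (about 20 characters per second). | |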
| with gr.Blocks() as demo: | |
| # Instantiate per-user session state (reset on page reload) | |
| ss = gr.State(SessionState()) | |
| # -------------------------------------- | |
| # Functions to set/clear the API key | |
| # -------------------------------------- | |
| def openai_api_setfn(openai_api_key) -> str: | |
| if openai_api_key == "kikagaku": | |
| os.environ["OPENAI_API_KEY"] = os.getenv("kikagaku_demo") | |
| status_message = "✅ キカガク専用DEMOへようこそ!APIキーを設定しました" | |
| return status_message | |
| elif not openai_api_key or not openai_api_key.startswith("sk-") or len(openai_api_key) < 50: | |
| os.environ["OPENAI_API_KEY"] = "" | |
| status_message = "❌ 有効なAPIキーを入力してください" | |
| return status_message | |
| else: | |
| os.environ["OPENAI_API_KEY"] = openai_api_key | |
| status_message = "✅ APIキーを設定しました" | |
| return status_message | |
| def openai_api_clsfn(ss) -> (str, str): | |
| openai_api_key = "" | |
| os.environ["OPENAI_API_KEY"] = "" | |
| status_message = "✅ APIキーの削除が完了しました" | |
| return status_message, "" | |
| with gr.Tabs(): | |
| # -------------------------------------- | |
| # Setting Tab | |
| # -------------------------------------- | |
| with gr.TabItem("1. LLM設定"): | |
| with gr.Row(): | |
| model_id = gr.Dropdown( | |
| choices=[ | |
| 'elyza/ELYZA-japanese-Llama-2-7b-fast-instruct', | |
| 'rinna/bilingual-gpt-neox-4b-instruction-sft', | |
| 'gpt-3.5-turbo', | |
| ], | |
| value="gpt-3.5-turbo", | |
| label='LLM model', | |
| interactive=True, | |
| ) | |
| with gr.Row(): | |
| embedding_id = gr.Dropdown( | |
| choices=[ | |
| 'intfloat/multilingual-e5-large', | |
| 'sonoisa/sentence-bert-base-ja-mean-tokens-v2', | |
| 'oshizo/sbert-jsnli-luke-japanese-base-lite', | |
| 'text-embedding-ada-002', | |
| "None" | |
| ], | |
| value="text-embedding-ada-002", | |
| label = 'Embedding model', | |
| interactive=True, | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=19): | |
| openai_api_key = gr.Textbox(label="OpenAI API Key (Optional)", interactive=True, type="password", value="", placeholder="Your OpenAI API Key for OpenAI models.", max_lines=1) | |
| with gr.Column(scale=1): | |
| openai_api_set = gr.Button(value="Set API KEY", size="sm") | |
| openai_api_cls = gr.Button(value="Delete API KEY", size="sm") | |
| # with gr.Row(): | |
| # reference_libs = gr.CheckboxGroup(choices=['LangChain', 'Gradio'], label="Reference Libraries", interactive=False) | |
| # Advanced settings (collapsed by default) | |
| with gr.Accordion(label="Advanced Setting", open=False): | |
| with gr.Row(): | |
| with gr.Column(): | |
| load_in_8bit = gr.Checkbox(label="8bit Quantize (HF)", value=True, interactive=True) | |
| verbose = gr.Checkbox(label="Verbose (OpenAI, HF)", value=True, interactive=True) | |
| with gr.Column(): | |
| temperature = gr.Slider(label='Temperature (OpenAI, HF)', minimum=0.0, maximum=1.0, step=0.1, value=0.2, interactive=True) | |
| with gr.Column(): | |
| similarity_search_k = gr.Slider(label="similarity_search_k (OpenAI, HF)", minimum=1, maximum=10, step=1, value=3, interactive=True) | |
| with gr.Column(): | |
| summarization_mode = gr.Radio(choices=['stuff', 'map_reduce'], label="Summarization mode", value='stuff', interactive=True) | |
| with gr.Column(): | |
| min_length = gr.Slider(label="min_length (HF)", minimum=1, maximum=100, step=1, value=10, interactive=True) | |
| with gr.Column(): | |
| max_new_tokens = gr.Slider(label="max_tokens(OpenAI), max_new_tokens(HF)", minimum=1, maximum=1024, step=1, value=256, interactive=True) | |
| with gr.Column(): | |
| top_k = gr.Slider(label='top_k (HF)', minimum=1, maximum=100, step=1, value=40, interactive=True) | |
| with gr.Column(): | |
| top_p = gr.Slider(label='top_p (HF)', minimum=0.01, maximum=0.99, step=0.01, value=0.92, interactive=True) | |
| with gr.Column(): | |
| repetition_penalty = gr.Slider(label='repetition_penalty (HF)', minimum=0.5, maximum=2, step=0.1, value=1.2, interactive=True) | |
| with gr.Column(): | |
| num_return_sequences = gr.Slider(label='num_return_sequences (HF)', minimum=1, maximum=20, step=1, value=3, interactive=True) | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| config_btn = gr.Button(value="Configure") | |
| with gr.Column(scale=13): | |
| status_cfg = gr.Textbox(show_label=False, interactive=False, value="モデルを設定してください", container=False, max_lines=1) | |
| # Wire up button and input actions | |
| openai_api_set.click(openai_api_setfn, inputs=[openai_api_key], outputs=[status_cfg], show_progress="full") | |
| openai_api_cls.click(openai_api_clsfn, inputs=[openai_api_key], outputs=[status_cfg, openai_api_key], show_progress="full") | |
| openai_api_key.submit(openai_api_setfn, inputs=[openai_api_key], outputs=[status_cfg], show_progress="full") | |
| config_btn.click( | |
| fn = load_models, | |
| inputs = [ss, model_id, embedding_id, openai_api_key, load_in_8bit, verbose, temperature, \ | |
| similarity_search_k, summarization_mode, \ | |
| min_length, max_new_tokens, top_k, top_p, repetition_penalty, num_return_sequences], | |
| outputs = [ss, status_cfg], | |
| queue = True, | |
| show_progress = "full" | |
| ) | |
| # -------------------------------------- | |
| # Reference Tab | |
| # -------------------------------------- | |
| with gr.TabItem("2. References"): | |
| urls = gr.TextArea( | |
| max_lines = 60, | |
| show_label=False, | |
| info = "List any reference URLs for Q&A retrieval.", | |
| placeholder = "https://blog.kikagaku.co.jp/deep-learning-transformer\nhttps://note.com/elyza/n/na405acaca130", | |
| interactive=True, | |
| ) | |
| with gr.Row(): | |
| pdf_paths = gr.File(label="PDFs", height=150, min_width=60, scale=7, file_types=[".pdf"], file_count="multiple", interactive=True) | |
| header_lim = gr.Number(label="Header (pt)", step=1, value=792, precision=0, min_width=70, scale=1, interactive=True) | |
| footer_lim = gr.Number(label="Footer (pt)", step=1, value=0, precision=0, min_width=70, scale=1, interactive=True) | |
| pdf_ref = gr.Textbox(show_label=False, value="A4 Size:\n(下)0-792pt(上)\n *28.35pt/cm", container=False, scale=1, interactive=False) | |
| with gr.Row(): | |
| ref_set_btn = gr.Button(value="コンテンツ登録", scale=1) | |
| ref_clear_btn = gr.Button(value="登録データ削除", scale=1) | |
| status_ref = gr.Textbox(show_label=False, interactive=False, value="参照データ未登録", container=False, max_lines=1, scale=18) | |
| ref_set_btn.click(fn=embed_ref, inputs=[ss, urls, pdf_paths, header_lim, footer_lim], outputs=[ss, status_ref], queue=True, show_progress="full") | |
| ref_clear_btn.click(fn=clear_db, inputs=[ss], outputs=[ss, status_ref], show_progress="full") | |
| # -------------------------------------- | |
| # Chatbot Tab | |
| # -------------------------------------- | |
| with gr.TabItem("3. Q&A Chat"): | |
| chat_history = gr.Chatbot([], elem_id="chatbot", avatar_images=["bear.png", "penguin.png"],) | |
| with gr.Row(): | |
| with gr.Column(scale=95): | |
| query = gr.Textbox( | |
| show_label=False, | |
| placeholder="Send a message with [Shift]+[Enter] key.", | |
| lines=4, | |
| container=False, | |
| autofocus=True, | |
| interactive=True, | |
| ) | |
| with gr.Column(scale=5): | |
| with gr.Row(): | |
| qa_flag = gr.Checkbox(label="QA mode", value=False, min_width=60, interactive=True) | |
| web_flag = gr.Checkbox(label="Web Search", value=True, min_width=60, interactive=True) | |
| with gr.Row(): | |
| query_send_btn = gr.Button(value="▶") | |
| # gr.Examples(["機械学習について説明してください"], inputs=[query]) | |
| query.submit( | |
| user, [ss, query], [ss, chat_history] | |
| ).then( | |
| bot, [ss, query, qa_flag, web_flag, summarization_mode], [ss, query] | |
| ).then( | |
| show_response, [ss], [chat_history] | |
| ) | |
| query_send_btn.click( | |
| user, [ss, query], [ss, chat_history] | |
| ).then( | |
| bot, [ss, query, qa_flag, web_flag, summarization_mode], [ss, query] | |
| ).then( | |
| show_response, [ss], [chat_history] | |
| ) | |
| if __name__ == "__main__": | |
| demo.queue(concurrency_count=5) | |
| demo.launch(debug=True) | |