##########################################################################################
# Title: Gradio Interface to LLM-chatbot with dynamic RAG-functionality and ChromaDB
# Author: Andreas Fischer
# Date: October 10th, 2024
# Last update: October 10th, 2024
##########################################################################################
import os
import torch
import chromadb
from datetime import datetime
from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.utils import embedding_functions
from transformers import AutoTokenizer, AutoModel

# Load the German Jina embedding model in bfloat16 and move it to the GPU if available
jina = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-de', trust_remote_code=True, torch_dtype=torch.bfloat16)
#jina.save_pretrained("jinaai_jina-embeddings-v2-base-de")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
#device='cpu'  # uncomment to force CPU
jina.to(device)  # e.g. cuda:0
print(device)
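
# Custom ChromaDB embedding function that wraps the Jina model's encode() method,
# so that documents and queries are embedded with the same model.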
class JinaEmbeddingFunction(EmbeddingFunction):
    def __call__(self, input: Documents) -> Embeddings:
        embeddings = jina.encode(input)  # max_length=2048
        return embeddings.tolist()
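
# Persistent ChromaDB storage: use the local path when it exists (on-prem),
# otherwise fall back to the Hugging Face Space path /home/user/app/db.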
dbPath = "/home/af/Schreibtisch/Code/gradio/Chroma/db"
onPrem = os.path.exists(dbPath)
if not onPrem: dbPath = "/home/user/app/db"
#onPrem=True  # uncomment to override automatic detection
print(dbPath)

path = dbPath
client = chromadb.PersistentClient(path=path)
print(client.heartbeat())
print(client.get_version())
print(client.list_collections())

jina_ef = JinaEmbeddingFunction()
embeddingModel = jina_ef
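
# Text-generation backend: a Hugging Face InferenceClient pointed at a hosted
# Mixtral-8x7B-Instruct endpoint (a smaller Mistral-7B-Instruct is left as an alternative).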
from huggingface_hub import InferenceClient
import gradio as gr
import json

inferenceClient = InferenceClient(
    "mistralai/Mixtral-8x7B-Instruct-v0.1"
    #"mistralai/Mistral-7B-Instruct-v0.1"
)
def format_prompt(message, history):
    # Wrap the message in the Mistral/Mixtral instruction format;
    # the chat history is currently not included in the prompt.
    prompt = "<s>"
    #for user_prompt, bot_response in history:
    #  prompt += f"[INST] {user_prompt} [/INST]"
    #  prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt
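
# PDF ingestion: extract text per page with pypdf; if the PDF mostly contains
# images and little extractable text, optionally run OCR (ocrmypdf) and re-extract.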
from pypdf import PdfReader
import ocrmypdf

def convertPDF(pdf_file, allow_ocr=False):
    reader = PdfReader(pdf_file)

    def extract_text_from_pdf(reader):
        full_text = ""
        page_list = []
        page_count = 0
        for idx, page in enumerate(reader.pages):
            text = page.extract_text()
            if len(text) > 0:
                page_list.append(text)
                full_text += f"---- Page {idx} ----\n" + text + "\n\n"
                page_count += 1
        return full_text.strip(), page_count, page_list

    # Extract text:
    full_text, page_count, page_list = extract_text_from_pdf(reader)
    # Check if there are any images
    image_count = sum(len(page.images) for page in reader.pages)
    # If there are images and not much text, perform OCR on the document and re-extract
    if allow_ocr:
        print(f"{image_count} Images")
        if image_count > 0 and len(full_text) < 1000:
            out_pdf_file = pdf_file.replace(".pdf", "_ocr.pdf")
            ocrmypdf.ocr(pdf_file, out_pdf_file, force_ocr=True)
            reader = PdfReader(out_pdf_file)
            full_text, page_count, page_list = extract_text_from_pdf(reader)
    l = len(page_list)
    print(f"{l} Pages")
    # Extract metadata
    metadata = {
        "author": reader.metadata.author,
        "creator": reader.metadata.creator,
        "producer": reader.metadata.producer,
        "subject": reader.metadata.subject,
        "title": reader.metadata.title,
        "image_count": image_count,
        "page_count": page_count,
        "char_count": len(full_text),
    }
    return page_list, full_text, metadata
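
# Split a long text into fixed-size chunks that overlap, so sentences cut at a
# chunk boundary still appear intact in the neighbouring chunk.
# Example: split_with_overlap("abcdef", chunk_size=4, overlap=2) -> ["abcd", "cdef", "ef"]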
def split_with_overlap(text, chunk_size=3500, overlap=700):
    chunks = []
    step = max(1, chunk_size - overlap)
    for i in range(0, len(text), step):
        end = min(i + chunk_size, len(text))
        chunks.append(text[i:end])
    return chunks
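
# Index an uploaded PDF: convert it to text, split it into overlapping chunks,
# and (re)build the ChromaDB collection "test" with cosine similarity over Jina embeddings.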
def add_doc(path):
    print("def add_doc!")
    print(path)
    if str.lower(path).endswith(".pdf"):
        doc = convertPDF(path)
        doc = "\n\n".join(doc[0])  # join the extracted pages into one string
        gr.Info("PDF uploaded, start indexing!")
    else:
        raise gr.Error("Error: Only PDFs are accepted!")
    client = chromadb.PersistentClient(path="output/general_knowledge")
    print(str(client.list_collections()))
    #global collection
    dbName = "test"
    if dbName in [c.name for c in client.list_collections()]:
        client.delete_collection(name=dbName)
    collection = client.create_collection(
        dbName,
        embedding_function=embeddingModel,
        metadata={"hnsw:space": "cosine"})
    corpus = split_with_overlap(doc, 3500, 700)
    print(len(corpus))
    then = datetime.now()
    x = collection.get(include=[])["ids"]
    print(len(x))
    if len(x) == 0:
        chunkSize = 40000
        nBatches = (len(corpus) + chunkSize - 1) // chunkSize  # e.g. 4 batches for 133497 texts
        for i in range(nBatches):
            print("embed batch " + str(i) + " of " + str(nBatches))
            ids = list(range(i * chunkSize, i * chunkSize + chunkSize))
            batch = corpus[i * chunkSize:(i * chunkSize + chunkSize)]
            textIDs = [str(id) for id in ids[0:len(batch)]]
            ids = [str(id + len(x) + 1) for id in ids[0:len(batch)]]  # id refers to the chromadb-unique ID
            collection.add(documents=batch, ids=ids,
                           metadatas=[{"date": str("2024-10-10")} for b in batch])  # "textID":textIDs, "id":ids,
            print("finished batch " + str(i) + " of " + str(nBatches))
    now = datetime.now()
    gr.Info("Indexing complete!")
    print(now - then)  # too much GPU memory for sentence-wise embedding; ~0:00:10.375087 for chunk-wise
    return collection

#split_with_overlap("test me if you can",2,1)
import re
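
# Chat handler: if the user attached a PDF, (re)index it; then retrieve the most
# similar chunk from ChromaDB, prepend it as context to the system instruction,
# and stream the Mixtral completion token by token.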
def multimodalResponse(message, history, retrievalVersion):
    print("def multimodal response!")
    length = str(len(history))
    query = message["text"]
    dbClient = chromadb.PersistentClient(path="output/general_knowledge")
    print(str(dbClient.list_collections()))
    if len(message["files"]) > 0:  # is there at least one file attached?
        collection = add_doc(message["files"][0])
    else:  # no new file attached: reuse the previously indexed collection
        collection = dbClient.get_collection(name="test", embedding_function=embeddingModel)
    x = collection.get(include=[])["ids"]
    context = collection.query(query_texts=[query], n_results=1)
    print(str(context))
    #context=["<context "+str(i+1)+">\n"+c+"\n</context "+str(i+1)+">" for i, c in enumerate(retrievedTexts)]
    #context="\n\n".join(context)
    #return context
    temperature = float(0.9)
    if temperature < 1e-2: temperature = 1e-2
    top_p = float(0.95)
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=5000,
        top_p=top_p,
        repetition_penalty=1.0,
        do_sample=True,
        seed=42,
    )
    system = "Given the following conversation, relevant context, and a follow up question, " + \
        "reply with an answer to the current question the user is asking. " + \
        "Return only your response to the question given the above information " + \
        "following the users instructions as needed.\n\nContext:" + \
        str(context)
    print(system)
    formatted_prompt = format_prompt(system + "\n" + query, history)
    stream = inferenceClient.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""
    for response in stream:
        output += response.token.text
        yield output
    #output=output+"\n\n<br><details open><summary><strong>Sources</strong></summary><br><ul>"+ "".join(["<li>" + s + "</li>" for s in combination])+"</ul></details>"
    yield output
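
# Build the multimodal chat UI; the dropdown value is passed to multimodalResponse
# as its third argument (retrievalVersion).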
i = gr.ChatInterface(multimodalResponse,
    title="pdfChatbot",
    multimodal=True,
    additional_inputs=[
        gr.Dropdown(
            info="select retrieval version",
            choices=["1", "2", "3"],
            value="1",
            label="Retrieval Version")])
i.launch()  #allowed_paths=["."]