| """ | |
| Full Conversational RAG Pipeline for Agri-Critique | |
| Includes: Session management, context-aware retrieval, memory management | |
| Loads everything from HuggingFace Hub | |
| """ | |
| import os | |
| import json | |
| import sqlite3 | |
| import uuid | |
| from datetime import datetime | |
| from typing import List, Dict, Any | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from peft import PeftModel | |
| from sentence_transformers import SentenceTransformer | |
| import faiss | |
| import numpy as np | |
| from huggingface_hub import hf_hub_download | |
| import warnings | |
| # Suppress PEFT warnings about unexpected keys in LoraConfig | |
| warnings.filterwarnings("ignore", category=UserWarning, module="peft") | |
class ConversationManager:
    """Manages conversation sessions with persistent storage"""

    def __init__(self, db_path="conversations.db"):
        self.db_path = db_path
        self.conn = sqlite3.connect(db_path, check_same_thread=False)
        self.cursor = self.conn.cursor()
        self._init_db()

    def _init_db(self):
        """Initialize session database"""
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS sessions (
                session_id TEXT PRIMARY KEY,
                created_at TEXT,
                last_updated TEXT,
                metadata TEXT
            )
        """)
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS messages (
                message_id INTEGER PRIMARY KEY AUTOINCREMENT,
                session_id TEXT,
                role TEXT,
                content TEXT,
                timestamp TEXT,
                evidence TEXT,
                FOREIGN KEY (session_id) REFERENCES sessions(session_id)
            )
        """)
        self.conn.commit()
    def create_session(self, metadata=None):
        """Create a new conversation session"""
        session_id = str(uuid.uuid4())
        now = datetime.utcnow().isoformat()
        self.cursor.execute("""
            INSERT INTO sessions (session_id, created_at, last_updated, metadata)
            VALUES (?, ?, ?, ?)
        """, (session_id, now, now, json.dumps(metadata or {})))
        self.conn.commit()
        return session_id

    def add_message(self, session_id, role, content, evidence=None):
        """Add a message to a session"""
        now = datetime.utcnow().isoformat()
        self.cursor.execute("""
            INSERT INTO messages (session_id, role, content, timestamp, evidence)
            VALUES (?, ?, ?, ?, ?)
        """, (session_id, role, content, now, json.dumps(evidence) if evidence else None))
        # Update session timestamp
        self.cursor.execute("""
            UPDATE sessions SET last_updated = ? WHERE session_id = ?
        """, (now, session_id))
        self.conn.commit()
    def get_session_history(self, session_id, limit=None):
        """Get conversation history for a session"""
        query = """
            SELECT role, content, timestamp, evidence
            FROM messages
            WHERE session_id = ?
            ORDER BY timestamp ASC
        """
        if limit:
            query += f" LIMIT {limit}"
        self.cursor.execute(query, (session_id,))
        messages = []
        for row in self.cursor.fetchall():
            messages.append({
                'role': row[0],
                'content': row[1],
                'timestamp': row[2],
                'evidence': json.loads(row[3]) if row[3] else None
            })
        return messages
    def summarize_old_messages(self, session_id, keep_recent=4):
        """Summarize old messages to save context window"""
        messages = self.get_session_history(session_id)
        if len(messages) <= keep_recent:
            return messages
        # Keep recent messages
        recent = messages[-keep_recent:]
        old = messages[:-keep_recent]
        # Create summary of old messages
        summary = "Previous conversation summary:\n"
        for msg in old[::2]:  # Sample every other message
            summary += f"- {msg['role']}: {msg['content'][:100]}...\n"
        # Return summary + recent messages
        return [{'role': 'system', 'content': summary}] + recent
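# Illustrative usage of ConversationManager (a sketch; the metadata and message
# text below are hypothetical):
#   cm = ConversationManager("conversations.db")
#   sid = cm.create_session({"source": "demo"})
#   cm.add_message(sid, "user", "How does drought stress affect wheat yield?")
#   history = cm.get_session_history(sid)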
class AgriCritiqueRAG:
    """Full RAG system with conversational capabilities"""

    def __init__(self):
        print("🔄 Initializing Agri-Critique Conversational RAG System...")
        # Model paths
        self.model_id = "sayande/AgriScholarQA-CoT"
        self.base_model_id = "Qwen/Qwen3-4B-Thinking-2507"
        self.index_repo = "sayande/agri-critique-index"
        # Conversation manager
        self.conversation_manager = ConversationManager()
        self.current_session = None
        # Load retriever
        print("📥 Loading retriever...")
        self.retriever = SentenceTransformer("all-mpnet-base-v2")

        # ------------------------------------------------------------------
        # Load FAISS indices (local first, then HF fallback)
        # ------------------------------------------------------------------
        print("📥 Loading FAISS indices...")
        self.chunk_index = None
        self.paper_index = None
        self.index = None  # alias kept for backward compatibility

        base_dir = os.path.dirname(__file__) if "__file__" in globals() else os.getcwd()
        local_chunk_path = os.path.join(base_dir, "faiss.index")
        local_paper_path = os.path.join(base_dir, "faiss_papers.index")
        local_meta_path = os.path.join(base_dir, "meta.json")

        # ---- Try LOCAL chunk index ----
        try:
            if os.path.exists(local_chunk_path):
                print(f"📁 Found local chunk index: {local_chunk_path}")
                self.chunk_index = faiss.read_index(local_chunk_path)
                self.index = self.chunk_index
                print(f"✅ Loaded local chunk FAISS index with {self.chunk_index.ntotal} vectors")
            else:
                print("ℹ️ Local chunk index 'faiss.index' not found, will try HuggingFace Hub...")
        except Exception as e:
            print(f"⚠️ Could not load local chunk index: {e}")
            self.chunk_index = None
            self.index = None

        # ---- If no local chunk index, fall back to HF ----
        if self.chunk_index is None:
            print("📥 Loading FAISS index from HuggingFace dataset...")
            try:
                index_path = hf_hub_download(
                    repo_id=self.index_repo,
                    filename="faiss.index",
                    repo_type="dataset"
                )
                self.chunk_index = faiss.read_index(index_path)
                self.index = self.chunk_index
                print(f"✅ Loaded HF FAISS index with {self.chunk_index.ntotal} vectors")
            except Exception as e:
                print(f"⚠️ Could not load FAISS index from HF: {e}")
                self.chunk_index = None
                self.index = None
        # ---- Optional: paper-level index (not strictly required) ----
        try:
            if os.path.exists(local_paper_path):
                print(f"📁 Found local paper index: {local_paper_path}")
                self.paper_index = faiss.read_index(local_paper_path)
                print(f"✅ Loaded local paper FAISS index with {self.paper_index.ntotal} vectors")
            else:
                print("ℹ️ Local paper index 'faiss_papers.index' not found (this is optional).")
        except Exception as e:
            print(f"⚠️ Could not load local paper index: {e}")
            self.paper_index = None
        # ------------------------------------------------------------------
        # Load metadata (local first, then HF)
        # ------------------------------------------------------------------
        print("📥 Loading metadata...")
        self.metadata = []

        # Try local meta.json
        try:
            if os.path.exists(local_meta_path):
                print(f"📁 Found local metadata: {local_meta_path}")
                with open(local_meta_path, "r", encoding="utf-8") as f:
                    self.metadata = json.load(f)
                print(f"✅ Loaded local metadata for {len(self.metadata)} chunks")
            else:
                print("ℹ️ Local 'meta.json' not found, will try HuggingFace Hub...")
        except Exception as e:
            print(f"⚠️ Could not load local metadata: {e}")
            self.metadata = []

        # If still empty, try HF
        if not self.metadata:
            print("📥 Loading metadata from HuggingFace dataset...")
            try:
                meta_path = hf_hub_download(
                    repo_id=self.index_repo,
                    filename="meta.json",
                    repo_type="dataset"
                )
                with open(meta_path, "r", encoding="utf-8") as f:
                    self.metadata = json.load(f)
                print(f"✅ Loaded HF metadata for {len(self.metadata)} chunks")
            except Exception as e:
                print(f"⚠️ Could not load metadata from HF: {e}")
                self.metadata = []
        # Model will be loaded lazily on first use
        self.model = None
        self.tokenizer = None
        self.model_loaded = False
        print("✅ Agri-Critique Conversational RAG System initialized!")
        print("ℹ️ Model will load on first query (Qwen3-4B with INT4 quantization)")
    def _ensure_model_loaded(self):
        """Lazy load model on first use"""
        if self.model_loaded:
            return
        print("📥 Loading Agri-Critique model (this may take 1-2 minutes)...")
        # Get HF token from environment
        hf_token = os.getenv("HF_TOKEN")
        if not hf_token:
            raise ValueError("HF_TOKEN not found. Please add it to Space secrets.")

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.base_model_id,
            token=hf_token
        )

        from transformers import AutoConfig
        config = AutoConfig.from_pretrained(self.base_model_id, token=hf_token)
        # Qwen models work well with default config
        # No special rope_scaling adjustments needed
| print("🖥️ Loading Qwen3-4B model with INT4 quantization for speed") | |
| # Try to use INT4 quantization for faster inference (better for 4B models on CPU) | |
| try: | |
| from transformers import BitsAndBytesConfig | |
| quantization_config = BitsAndBytesConfig( | |
| load_in_4bit=True, # INT4 is better for larger models on CPU | |
| bnb_4bit_compute_dtype=torch.float16, | |
| bnb_4bit_use_double_quant=True, | |
| bnb_4bit_quant_type="nf4" | |
| ) | |
| print("✅ Using INT4 (NF4) quantization - optimized for Qwen 4B on CPU") | |
| except ImportError: | |
| print("⚠️ bitsandbytes not available, using float32") | |
| quantization_config = None | |
        base_model = AutoModelForCausalLM.from_pretrained(
            self.base_model_id,
            config=config,
            quantization_config=quantization_config,
            torch_dtype=torch.float32 if quantization_config is None else None,
            device_map="auto" if quantization_config else "cpu",
            low_cpu_mem_usage=True,
            token=hf_token,
        )
        self.model = PeftModel.from_pretrained(
            base_model,
            self.model_id,
            token=hf_token,
        )
        self.model.eval()
        self.model_loaded = True
        print("✅ Model loaded successfully!")
    def _refine_query_with_llm(self, query):
        """Use LLM to extract core search terms (Query Understanding/NER)"""
        if not self.model_loaded:
            return query  # Can't refine if model not loaded yet
        prompt = [
            {"role": "system", "content": "You are a search query optimizer. Extract ONLY the most important agricultural keywords, entities (crops, diseases, chemicals), and timeframes from the user's question. Return a concise string of keywords."},
            {"role": "user", "content": f"Query: {query}"}
        ]
        try:
            input_text = self.tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
            inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=256,  # Short output
                    temperature=0.3,
                    do_sample=True,  # required for temperature to take effect
                )
            refined = self.tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True).strip()
            # print(f"DEBUG: Refined '{query}' -> '{refined}'")
            return refined
        except Exception:
            # Fallback: return the original query unchanged
            return query
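    # Illustrative refinement (hypothetical model output):
    #   "What did 2024 trials report on nitrogen use efficiency in rice?"
    #   -> "nitrogen use efficiency rice 2024"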
    def start_session(self, metadata=None):
        """Start a new conversation session"""
        self.current_session = self.conversation_manager.create_session(metadata)
        print(f"📝 Started new session: {self.current_session[:8]}...")
        return self.current_session
    def retrieve_with_context(self, query, conversation_history, top_k=5):
        """Context-aware retrieval: considers conversation history"""
        # Use chunk_index (local or HF). If missing, no retrieval.
        if self.chunk_index is None or not self.metadata:
            return []

        # Combine current query with recent context
        context_queries = [query]

        # EXPERIMENTAL: Query Understanding / Refinement
        # If the model is already loaded, we can use it to "understand" the query
        # and extract better search terms (NER-lite).
        if self.model_loaded:
            refined = self._refine_query_with_llm(query)
            if refined and refined != query:
                # Add refined keywords with high importance
                context_queries.append(refined)

        # Add recent user questions for context
        for msg in conversation_history[-4:] if conversation_history else []:
            if msg["role"] == "user":
                context_queries.append(msg["content"])

        # Encode all queries
        embeddings = self.retriever.encode(context_queries, convert_to_numpy=True)

        # Weighting: Original Query (high), Refined Query (med), History (low)
        # Dynamic weighting based on what we have
        num_q = len(context_queries)
        if num_q == 1:
            weights = [1.0]
        else:
            # Simple heuristic: First item (Original) gets 0.6
            # Others share the remaining 0.4
            weights = [0.6] + [0.4 / (num_q - 1)] * (num_q - 1)
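        # Worked example (illustrative): with the original query, one refined
        # query, and one history question, num_q == 3, so
        # weights == [0.6, 0.2, 0.2] and the original query dominates the
        # averaged embedding.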
        weighted_embedding = np.average(embeddings, axis=0, weights=weights).reshape(1, -1).astype("float32")
        faiss.normalize_L2(weighted_embedding)

        # Extract year from query (e.g., 2024, 2025)
        import re
        year_match = re.search(r"\b(20\d{2})\b", query)
        target_year = year_match.group(1) if year_match else None

        # Search over chunk index
        # Fetch more candidates to allow for temporal re-ranking
        # If a year is detected, fetch deep (e.g. 100) to find the year-match chunks
        if target_year:
            initial_k = 100
        else:
            initial_k = top_k * 3
        distances, indices = self.chunk_index.search(weighted_embedding, initial_k)
        candidates = []
        for idx, dist in zip(indices[0], distances[0]):
            if 0 <= idx < len(self.metadata):
                chunk_info = self.metadata[idx]
                # Check for year match in paper_id
                is_year_match = False
                if target_year and target_year in chunk_info.get("paper_id", ""):
                    is_year_match = True
                candidates.append({
                    "data": chunk_info,
                    "dist": float(dist),
                    "is_year_match": is_year_match
                })

        # Soft Boost Logic:
        # Instead of force-sorting year matches to the top (which brings in irrelevant junk),
        # we improve their distance score by a fixed amount (e.g., 0.5).
        # Assuming L2 distance (smaller is better): new_dist = old_dist - 0.5
        # This lets a "Relevant Year-Match" beat a "Relevant Non-Match",
        # but a "Totally Irrelevant Year-Match" will still lose to "Relevant Content".
        for cand in candidates:
            if cand["is_year_match"]:
                cand["effective_dist"] = cand["dist"] - 0.5
            else:
                cand["effective_dist"] = cand["dist"]

        # Sort by effective distance (ascending)
        candidates.sort(key=lambda x: x["effective_dist"])

        # Select top_k
        final_candidates = candidates[:top_k]

        evidence = []
        for cand in final_candidates:
            ev = dict(cand["data"])
            ev["score"] = cand["dist"]
            # FALLBACK: If 'text' is missing in metadata (common issue with this dataset version),
            # construct a proxy text from the section and paper ID so the RAG doesn't see empty strings.
            if "text" not in ev or not ev["text"]:
                paper = ev.get("paper_id", "Unknown Paper")
                sect = ev.get("section", "General")
                ev["text"] = f"[Note: Full text missing in metadata] Section '{sect}' from paper '{paper}'."
            evidence.append(ev)
        return evidence
    def _clean_paper_id(self, paper_id):
        """Clean paper ID for display"""
        if not isinstance(paper_id, str):
            return str(paper_id)
        clean = paper_id.strip("-_")
        clean = clean.replace("_", " ").replace("-", " ")
        return clean.title()
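    # Illustrative example (hypothetical ID): _clean_paper_id("wheat_rust_survey-2023")
    # returns "Wheat Rust Survey 2023".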
    def validate_and_answer(self, question, evidence, conversation_history):
        """Generate validated answer with reasoning - OPTIMIZED single-call version"""
        self._ensure_model_loaded()

        # Format evidence text for the model
        # Include title/paper_id so the model knows the source date/context
        evidence_text = "\n\n".join(
            [
                f"[{i+1}] {ev.get('paper_title') or ev.get('paper_id')}\n{ev.get('text', '')}"
                for i, ev in enumerate(evidence)
            ]
        )

        # OPTIMIZED: Single model call for both validation and answer
        # This reduces inference time by ~50%
        combined_messages = [
            {
                "role": "system",
                "content": (
                    "You are an agricultural research assistant. Your task is to:\n"
                    "1. Validate the question against the evidence\n"
                    "2. Provide a clear, comprehensive answer based ONLY on the evidence\n"
                    "3. Cite sources as [1], [2], etc.\n\n"
                    "Check: Is the question relevant? Are there conflicting facts? "
                    "Is there enough information?"
                ),
            },
            {
                "role": "user",
                "content": f"""EVIDENCE:
{evidence_text}
QUESTION: {question}
TASK: Provide a validated answer to the question. First briefly explain your reasoning, then give the final answer. Be detailed and thorough.""",
            },
        ]

        input_text = self.tokenizer.apply_chat_template(
            combined_messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)

        # UPDATED: Increased token limit for more detailed answers
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=2048,  # Increased from 256 for detail
                temperature=0.5,  # Balanced temperature
                do_sample=True,
                pad_token_id=self.tokenizer.pad_token_id,
            )
        full_response = self.tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        ).strip()

        # Split response into reasoning and answer
        # The model should naturally provide reasoning first, then answer
        if "\n\n" in full_response:
            parts = full_response.split("\n\n", 1)
            reasoning = parts[0]
            answer = parts[1] if len(parts) > 1 else full_response
        else:
            # Fallback: use entire response as answer
            reasoning = "Validated against evidence."
            answer = full_response
        return reasoning, answer
    def ask(self, question):
        """Full RAG pipeline with validation"""
        if not self.current_session:
            self.start_session()

        # Get conversation history
        history = self.conversation_manager.get_session_history(self.current_session)

        # Context-aware retrieval
        evidence = self.retrieve_with_context(question, history, top_k=5)

        # Generate validated answer
        reasoning, answer = self.validate_and_answer(question, evidence, history)

        # Save to session
        self.conversation_manager.add_message(
            self.current_session, "user", question, evidence
        )
        self.conversation_manager.add_message(
            self.current_session, "assistant", answer
        )

        # Format evidence for return (include all useful fields)
        formatted_evidence = []
        for ev in evidence:
            paper_id = ev.get("paper_id", "unknown")
            display_paper = ev.get("paper_title") or self._clean_paper_id(paper_id)
            formatted_evidence.append(
                {
                    "paper_id": display_paper,
                    "raw_paper_id": paper_id,
                    "text": ev.get("text", ""),
                    "score": ev.get("score", 0.0),
                }
            )
        return {
            "answer": answer,
            "reasoning": reasoning,
            "evidence": formatted_evidence,
            "session_id": self.current_session,
        }
    def validate_answer(self, question: str, proposed_answer: str, evidence: List[Dict[str, Any]]) -> str:
        """
        Validate a proposed answer against evidence and return a critique.

        This method uses the fine-tuned Qwen model (AgriScholarQA-CoT PEFT adapter)
        to critique the answer by checking:
        - Are all claims supported by the evidence?
        - Are there any hallucinations or fake findings?
        - Are citations accurate?
        - Are there temporal or causal errors?

        Args:
            question: The original question
            proposed_answer: The answer to validate
            evidence: List of evidence chunks with a 'text' field

        Returns:
            Critique string identifying issues or confirming validity
        """
        self._ensure_model_loaded()

        # Format evidence text for validation
        evidence_text = "\n\n".join([
            f"[{i+1}] {ev.get('text', '')}"
            for i, ev in enumerate(evidence[:5])  # Limit to top 5 for context window
        ])
        if not evidence_text.strip():
            evidence_text = "(no evidence provided)"

        # Create validation prompt
        validation_messages = [
            {
                "role": "system",
                "content": (
                    "You are a strict agricultural research validator. "
                    "Your job is to critique the proposed answer by checking:\n"
                    "1. Are all claims supported by the evidence?\n"
                    "2. Are there any hallucinations or fake findings?\n"
                    "3. Are citations accurate and properly used?\n"
                    "4. Are there temporal or causal errors?\n"
                    "5. Are there any unsupported extrapolations?\n\n"
                    "Provide a concise critique. If the answer is well-supported, say so. "
                    "If there are issues, clearly identify them."
                ),
            },
            {
                "role": "user",
                "content": f"""QUESTION: {question}
EVIDENCE:
{evidence_text}
PROPOSED ANSWER:
{proposed_answer}
TASK: Critique this answer. Identify any unsupported claims, hallucinations, citation errors, or other issues.""",
            },
        ]
        # Generate critique using the model
        input_text = self.tokenizer.apply_chat_template(
            validation_messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=1024,
                temperature=0.3,
                do_sample=True,
                pad_token_id=self.tokenizer.pad_token_id,
            )
        critique = self.tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        ).strip()
        return critique
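# The block below is an illustrative usage sketch, not part of the original Space
# app: it assumes HF_TOKEN is set in the environment and that the FAISS index and
# metadata are reachable locally or on the Hub. The question string is purely
# hypothetical.
if __name__ == "__main__":
    question = "Which irrigation practices improve wheat yield under drought stress?"
    rag = AgriCritiqueRAG()
    rag.start_session({"user": "demo"})
    result = rag.ask(question)
    print("ANSWER:\n", result["answer"])
    for ev in result["evidence"]:
        print(f"- {ev['paper_id']} (score: {ev['score']:.3f})")
    # Optional second pass: critique the generated answer against its own evidence
    critique = rag.validate_answer(question, result["answer"], result["evidence"])
    print("CRITIQUE:\n", critique)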