Spaces:
Running
Running
| # conversational_rag.py - Adapter for AgriCritiqueRAG | |
| """ | |
| Lightweight conversational adapter that exposes a simple interface expected by the Gradio app: | |
| - ask(message) -> { 'answer': str, 'evidence': [...], 'session_id': str } | |
| - retrieve_with_context(question, history, top_k=5) -> list[evidence_dicts] | |
| - validate_answer(question, proposed_answer, evidence) -> str (critique) | |
| - start_session() -> session_id | |
| - conversation_manager property -> object with get_session_history(session_id) | |
| - current_session property -> session id (or None) | |
| This adapter intentionally keeps calls thin and compatible with older UI code. | |
| It also provides a factory function `create_rag_from_config` to build a RAG instance with optional overrides. | |
| """ | |
| from typing import List, Dict, Any, Optional | |
| import os | |
| import json | |
| # Try to import the main RAG implementation | |
| try: | |
| from rag_pipeline import AgriCritiqueRAG | |
| except Exception as e: | |
| AgriCritiqueRAG = None | |
| _import_error = e | |
class ConversationRAGInterface:
    """
    Adapter around AgriCritiqueRAG exposing the interface the Gradio app expects.

    Every method is a thin proxy to the wrapped RAG object. `current_session`
    and `conversation_manager` are read-only properties, so UI code can write
    `iface.current_session` or
    `iface.conversation_manager.get_session_history(session_id)`.
    """

    def __init__(self, rag: Optional["AgriCritiqueRAG"] = None, **kwargs):
        """
        Wrap an existing RAG instance, or construct a new AgriCritiqueRAG.

        Args:
            rag: An already-built RAG object to wrap. When given, `kwargs`
                are not used.
            **kwargs: Forwarded to the AgriCritiqueRAG constructor when `rag`
                is None.

        Raises:
            ImportError: if `rag` is None and AgriCritiqueRAG could not be
                imported.
        """
        if rag is not None:
            self.rag = rag
        else:
            if AgriCritiqueRAG is None:
                raise ImportError(f"Could not import AgriCritiqueRAG: {_import_error}")
            # Create with kwargs forwarded to the RAG constructor
            self.rag = AgriCritiqueRAG(**kwargs)

    def ask(self, message: str) -> Dict[str, Any]:
        """
        Ask a top-level question.

        Returns whatever the wrapped RAG's `ask` returns. Per the module
        contract this is a dict with 'answer', 'evidence' and 'session_id'.
        """
        return self.rag.ask(message)

    def retrieve_with_context(self, question: str, history: List[Dict[str, Any]], top_k: int = 5) -> List[Dict[str, Any]]:
        """
        Proxy to the wrapped RAG's retrieval method.

        Args:
            question: The user question.
            history: Prior conversation turns, passed through unchanged.
            top_k: Number of evidence items requested from the RAG.

        Returns:
            The evidence dicts produced by the wrapped RAG.
        """
        return self.rag.retrieve_with_context(question, history, top_k=top_k)

    def validate_answer(self, question: str, proposed_answer: str, evidence: List[Dict[str, Any]]) -> str:
        """Proxy to the wrapped RAG's `validate_answer`; returns its critique string."""
        return self.rag.validate_answer(question, proposed_answer, evidence)

    def start_session(self, metadata: Optional[Dict[str, Any]] = None) -> str:
        """Start a new conversation session on the wrapped RAG and return its id."""
        return self.rag.start_session(metadata)

    @property
    def current_session(self) -> Optional[str]:
        """Session id reported by the wrapped RAG, or None if it exposes none."""
        return getattr(self.rag, "current_session", None)

    @property
    def conversation_manager(self) -> Any:
        """The wrapped RAG's conversation manager, or None if it exposes none."""
        return getattr(self.rag, "conversation_manager", None)
def _load_json_config(config_path: Optional[str]) -> Dict[str, Any]:
    """
    Best-effort load of a JSON object from `config_path`.

    Returns an empty dict (after printing a warning) when the path is missing,
    unreadable, not valid JSON, or does not contain a top-level JSON object.
    A falsy `config_path` returns {} silently.
    """
    if not config_path:
        return {}
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            loaded = json.load(f)
    except FileNotFoundError:
        print(f"Warning: config file {config_path} not found; ignoring it")
        return {}
    except (OSError, ValueError) as e:  # ValueError covers JSONDecodeError/UnicodeDecodeError
        print(f"Warning: could not load config {config_path}: {e}")
        return {}
    if not isinstance(loaded, dict):
        # The config is splatted into a constructor, so only an object makes sense.
        print(
            f"Warning: config {config_path} must contain a JSON object, "
            f"got {type(loaded).__name__}; ignoring it"
        )
        return {}
    return loaded


# Convenience factory that can be used by higher-level code
def create_rag_from_config(config_path: Optional[str] = None, **overrides) -> "ConversationRAGInterface":
    """
    Create a ConversationRAGInterface from an optional JSON config file and overrides.

    Example config file (json):
        {
          "base_model_id": "meta-llama/Llama-3.2-1B-Instruct",
          "model_id": "sayande/AgriScholarQA",
          "index_repo": "sayande/agri-critique-index",
          "metadata_local_path": "/path/to/meta.json"
        }

    Args:
        config_path: Optional path to a JSON file whose top level is an object.
            A missing or malformed file produces a warning and is ignored.
        **overrides: Keyword arguments that supersede keys from the file.

    Returns:
        A ConversationRAGInterface wrapping a new AgriCritiqueRAG built with
        the merged config.

    Raises:
        ImportError: if AgriCritiqueRAG could not be imported.
    """
    config = _load_json_config(config_path)
    # Apply overrides
    config.update(overrides)
    # Instantiate the underlying RAG
    if AgriCritiqueRAG is None:
        raise ImportError(f"AgriCritiqueRAG is not available: {_import_error}")
    rag = AgriCritiqueRAG(**config)
    return ConversationRAGInterface(rag=rag)
# If this module is run directly, build an interface from an optional config
# file and run one sample question through it as a smoke test.
if __name__ == '__main__':
    import argparse
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default=None, help='Path to JSON config file')
    args = parser.parse_args()

    # Keep construction and querying failures distinct so the message is accurate,
    # and exit non-zero so scripted callers can detect failure.
    try:
        iface = create_rag_from_config(args.config)
    except Exception as e:  # top-level boundary: report and exit
        print('Error creating RAG interface:', e, file=sys.stderr)
        sys.exit(1)

    try:
        sess = iface.start_session()
        print('Started session:', sess)
        q = 'What is the effect of drought on rice yield?'
        resp = iface.ask(q)
    except Exception as e:  # top-level boundary: report and exit
        print('Error running demo query:', e, file=sys.stderr)
        sys.exit(1)

    # Tolerate a missing/None answer or evidence list instead of crashing
    # on slicing or len().
    answer = resp.get('answer') or ''
    print('Answer:', str(answer)[:400])
    print('Evidence count:', len(resp.get('evidence') or []))