# AgriScholarQA / conversational_rag.py
# Last change by sayande: commit 7cde118 "Update conversational_rag.py" (verified)
# conversational_rag.py - Adapter for AgriCritiqueRAG
"""
Lightweight conversational adapter that exposes a simple interface expected by the Gradio app:
- ask(message) -> { 'answer': str, 'evidence': [...], 'session_id': str }
- retrieve_with_context(question, history, top_k=5) -> list[evidence_dicts]
- validate_answer(question, proposed_answer, evidence) -> str (critique)
- start_session(metadata=None) -> session_id
- conversation_manager property -> object with get_session_history(session_id)
- current_session property -> session id (or None)
This adapter intentionally keeps calls thin and compatible with older UI code.
It also provides a factory function `create_rag_from_config` to build a RAG instance with optional overrides.
"""
from typing import List, Dict, Any, Optional
import os
import json
# Try to import the main RAG implementation. The import is guarded so that this
# module stays importable even when rag_pipeline (or something it imports) fails
# to load; the failure is remembered in _import_error and reported later, when
# code actually tries to construct an AgriCritiqueRAG.
# _import_error is bound on every path (None on success) so it can always be read.
_import_error: Optional[Exception] = None
try:
    from rag_pipeline import AgriCritiqueRAG
except Exception as e:  # broad on purpose: a dependency can raise non-ImportError errors at import time
    AgriCritiqueRAG = None
    _import_error = e
class ConversationRAGInterface:
    """
    Adapter around AgriCritiqueRAG exposing the interface the Gradio app expects.

    Every method is a direct pass-through to the wrapped ``self.rag`` object, so
    return values and exceptions are whatever the underlying RAG produces. The
    ``current_session`` and ``conversation_manager`` properties return None when
    the wrapped object does not define those attributes.
    """

    def __init__(self, rag: Optional["AgriCritiqueRAG"] = None, **kwargs):
        """
        Wrap an existing RAG instance, or construct a new AgriCritiqueRAG.

        Args:
            rag: An already-built RAG object. When given, it is used as-is and
                any ``kwargs`` are ignored.
            **kwargs: Forwarded to the ``AgriCritiqueRAG`` constructor, used only
                when ``rag`` is None.

        Raises:
            ImportError: if ``rag`` is None and ``AgriCritiqueRAG`` could not be
                imported when this module was loaded. The original import
                failure is attached as the exception's cause.
        """
        if rag is not None:
            # A ready-made instance wins; kwargs only matter when we build the
            # RAG ourselves, so they are intentionally ignored here.
            self.rag = rag
            return
        if AgriCritiqueRAG is None:
            # Chain the original failure so its traceback is not lost.
            raise ImportError(f"Could not import AgriCritiqueRAG: {_import_error}") from _import_error
        self.rag = AgriCritiqueRAG(**kwargs)

    def ask(self, message: str) -> Dict[str, Any]:
        """
        Ask a question and return the wrapped RAG's response unchanged.

        The app expects a dict with 'answer', 'evidence' and 'session_id'.
        NOTE(review): that shape is not validated here; it comes from the RAG.
        """
        return self.rag.ask(message)

    def retrieve_with_context(self, question: str, history: List[Dict[str, Any]], top_k: int = 5) -> List[Dict[str, Any]]:
        """
        Retrieve evidence for ``question`` given the conversation ``history``.

        Forwards to the wrapped RAG's ``retrieve_with_context``, passing
        ``top_k`` by keyword.
        """
        return self.rag.retrieve_with_context(question, history, top_k=top_k)

    def validate_answer(self, question: str, proposed_answer: str, evidence: List[Dict[str, Any]]) -> str:
        """
        Ask the wrapped RAG to critique ``proposed_answer`` against ``evidence``.
        """
        return self.rag.validate_answer(question, proposed_answer, evidence)

    def start_session(self, metadata: Optional[Dict[str, Any]] = None) -> str:
        """
        Start a new conversation session on the wrapped RAG and return its id.

        ``metadata`` (possibly None) is passed through positionally.
        """
        return self.rag.start_session(metadata)

    @property
    def current_session(self) -> Optional[str]:
        """The wrapped RAG's ``current_session`` attribute, or None if it has none."""
        return getattr(self.rag, "current_session", None)

    @property
    def conversation_manager(self):
        """The wrapped RAG's ``conversation_manager`` attribute, or None if it has none."""
        return getattr(self.rag, "conversation_manager", None)
# Convenience factory that can be used by higher-level code
def create_rag_from_config(config_path: Optional[str] = None, **overrides) -> ConversationRAGInterface:
"""
Create a ConversationRAGInterface from an optional JSON config file and overrides.
Example config file (json):
{
"base_model_id": "meta-llama/Llama-3.2-1B-Instruct",
"model_id": "sayande/AgriScholarQA",
"index_repo": "sayande/agri-critique-index",
"metadata_local_path": "/path/to/meta.json"
}
Any keys passed as overrides will supersede the file.
"""
config = {}
if config_path and os.path.exists(config_path):
try:
with open(config_path, 'r', encoding='utf-8') as f:
config = json.load(f)
except Exception as e:
print(f"Warning: could not load config {config_path}: {e}")
# Apply overrides
config.update(overrides)
# Instantiate the underlying RAG
if AgriCritiqueRAG is None:
raise ImportError(f"AgriCritiqueRAG is not available: {_import_error}")
rag = AgriCritiqueRAG(**config)
return ConversationRAGInterface(rag=rag)
# If this module is run directly, create a small demo instance (no HF tokens required for retrieval-only)
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default=None, help='Path to JSON config file')
args = parser.parse_args()
try:
iface = create_rag_from_config(args.config)
sess = iface.start_session()
print('Started session:', sess)
q = 'What is the effect of drought on rice yield?'
resp = iface.ask(q)
print('Answer:', resp.get('answer')[:400])
print('Evidence count:', len(resp.get('evidence', [])))
except Exception as e:
print('Error creating RAG interface:', e)