|
|
""" |
|
|
Intelligent Audit Report Chatbot UI |
|
|
""" |
|
|
|
|
|
import os |
|
|
import warnings |
|
|
|
|
|
|
|
|
warnings.filterwarnings("ignore", message=".*use_column_width.*") |
|
|
warnings.filterwarnings("ignore", category=DeprecationWarning, module="streamlit") |
|
|
|
|
|
import time |
|
|
import json |
|
|
import uuid |
|
|
import logging |
|
|
import traceback |
|
|
from pathlib import Path |
|
|
|
|
|
from collections import Counter |
|
|
from typing import List, Dict, Any, Optional |
|
|
|
|
|
|
|
|
import pandas as pd |
|
|
import streamlit as st |
|
|
import plotly.express as px |
|
|
from langchain_core.messages import HumanMessage, AIMessage |
|
|
|
|
|
|
|
|
from src.agents import ( |
|
|
get_multi_agent_chatbot, |
|
|
get_smart_chatbot, |
|
|
get_gemini_chatbot, |
|
|
get_visual_chatbot, |
|
|
get_visual_multi_agent_chatbot |
|
|
) |
|
|
from src.feedback import FeedbackManager |
|
|
from src.ui_components import ( |
|
|
get_custom_css, |
|
|
display_chunk_statistics_charts, |
|
|
display_chunk_statistics_table, |
|
|
extract_chunk_statistics, |
|
|
display_visual_search_results |
|
|
) |
|
|
|
|
|
from src.config.paths import ( |
|
|
IS_DEPLOYED, |
|
|
PROJECT_DIR, |
|
|
HF_CACHE_DIR, |
|
|
FEEDBACK_DIR, |
|
|
CONVERSATIONS_DIR, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
omp_threads = os.environ.get("OMP_NUM_THREADS", "") |
|
|
try: |
|
|
if omp_threads: |
|
|
|
|
|
|
|
|
cleaned = ''.join(filter(str.isdigit, omp_threads)) |
|
|
if cleaned: |
|
|
threads = int(cleaned) |
|
|
if threads <= 0: |
|
|
os.environ["OMP_NUM_THREADS"] = "1" |
|
|
else: |
|
|
|
|
|
os.environ["OMP_NUM_THREADS"] = str(threads) |
|
|
else: |
|
|
os.environ["OMP_NUM_THREADS"] = "1" |
|
|
else: |
|
|
os.environ["OMP_NUM_THREADS"] = "1" |
|
|
except (ValueError, TypeError): |
|
|
os.environ["OMP_NUM_THREADS"] = "1" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if IS_DEPLOYED and HF_CACHE_DIR:
    # Deployed environment: point every HuggingFace-family cache variable
    # at the same writable directory before any model libraries load.
    cache_dir = str(HF_CACHE_DIR)
    for _cache_var in (
        "HF_HOME",
        "TRANSFORMERS_CACHE",
        "HF_DATASETS_CACHE",
        "HF_HUB_CACHE",
        "SENTENCE_TRANSFORMERS_HOME",
    ):
        os.environ[_cache_var] = cache_dir

    # Best-effort directory creation; the libraries will surface their own
    # errors later if the cache path truly is unusable.
    try:
        os.makedirs(cache_dir, mode=0o755, exist_ok=True)
    except (PermissionError, OSError):
        pass
else:
    # Local development: pull configuration from a .env file instead.
    from dotenv import load_dotenv
    load_dotenv()
|
|
|
|
|
|
|
|
# Module-wide logging: timestamped INFO-level messages to the default handler.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Page configuration — must run before any other Streamlit command.
st.set_page_config(
    layout="wide",
    page_icon="🤖",
    initial_sidebar_state="expanded",
    page_title="Intelligent Audit Report Chatbot"
)

import sys
import torch

# One-time hardware probe per session: log which accelerator is available.
if "gpu_check" not in st.session_state:
    try:
        has_cuda = torch.cuda.is_available()
        has_mps = torch.backends.mps.is_available() if hasattr(torch.backends, 'mps') else False
        if has_cuda:
            print(f"🎮 CUDA available: {torch.cuda.get_device_name(0)}")
        elif has_mps:
            print("🍎 MPS (Apple Silicon) available")
        else:
            print("💻 CPU only (no GPU acceleration)")
    except Exception as e:
        print(f"⚠️ GPU check error: {e}", file=sys.stderr)
    finally:
        # Mark the probe done even on failure so reruns don't repeat it.
        st.session_state.gpu_check = True

# Inject the app-wide CSS theme.
st.markdown(get_custom_css(), unsafe_allow_html=True)
|
|
|
|
|
|
|
|
def get_system_type():
    """Return a human-readable name for the active chatbot backend."""
    configured = os.environ.get('CHATBOT_SYSTEM', 'multi-agent')
    # Anything other than the explicit 'smart' opt-in uses the multi-agent stack.
    return "Smart Chatbot System" if configured == 'smart' else "Multi-Agent System"
|
|
|
|
|
def get_chatbot(version: str = "v1"):
    """Initialize and return the chatbot instance for the requested version.

    "beta" selects the Gemini FSA backend, "visual" the ColPali visual
    multi-agent backend; any other value falls back to the default stack,
    which is chosen by the CHATBOT_SYSTEM environment variable.
    """
    if version == "beta":
        return get_gemini_chatbot()
    if version == "visual":
        return get_visual_multi_agent_chatbot()
    # Default ("v1"): env var decides between the smart and multi-agent stacks.
    if os.environ.get('CHATBOT_SYSTEM', 'multi-agent') == 'smart':
        return get_smart_chatbot()
    return get_multi_agent_chatbot()
|
|
|
|
|
def serialize_messages(messages):
    """Convert LangChain message objects into plain JSON-serializable dicts.

    Objects without a ``content`` attribute are silently skipped.
    """
    return [
        {"type": type(message).__name__, "content": str(message.content)}
        for message in messages
        if hasattr(message, 'content')
    ]
|
|
|
|
|
def serialize_documents(sources):
    """Serialize document objects to dictionaries with deduplication.

    Documents whose text content is identical to an earlier document's are
    dropped (first occurrence wins). Metadata fields are flattened into the
    result with 'unknown' / 0.0 fallbacks so the output is always uniform.

    Args:
        sources: iterable of document-like objects exposing ``page_content``
            (or ``content``) and an optional ``metadata`` dict.

    Returns:
        list[dict]: one JSON-serializable dict per unique document.
    """
    serialized = []
    seen_content = set()

    for doc in sources:
        content = getattr(doc, 'page_content', getattr(doc, 'content', ''))

        # Deduplicate by exact text content.
        if content in seen_content:
            continue
        seen_content.add(content)

        # Fetch metadata once per document instead of once per field.
        metadata = getattr(doc, 'metadata', {})
        serialized.append({
            "content": content,
            "metadata": metadata,
            # Prefer the reranked score when present, then the original score.
            "score": metadata.get('reranked_score', metadata.get('original_score', 0.0)),
            "id": metadata.get('_id', 'unknown'),
            "source": metadata.get('source', 'unknown'),
            "year": metadata.get('year', 'unknown'),
            "district": metadata.get('district', 'unknown'),
            "page": metadata.get('page', 'unknown'),
            "chunk_id": metadata.get('chunk_id', 'unknown'),
            "page_label": metadata.get('page_label', 'unknown'),
            "original_score": metadata.get('original_score', 0.0),
            "reranked_score": metadata.get('reranked_score', None)
        })

    return serialized
|
|
|
|
|
|
|
|
# Module-level singleton used by the feedback form in main().
feedback_manager = FeedbackManager()
|
|
|
|
|
|
|
|
@st.cache_data
def load_filter_options():
    """Load the precomputed filter dropdown options from JSON config.

    Returns:
        dict: keys 'sources', 'years', 'districts', 'filenames', each a list.
        Falls back to empty lists (with a UI error) when the config file has
        not been generated yet.
    """
    filter_options_path = PROJECT_DIR / "src" / "config" / "filter_options.json"
    try:
        # JSON is defined as UTF-8; be explicit so platform-default
        # encodings (e.g. on Windows) can't corrupt the decode.
        with open(filter_options_path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        st.info(f"Looking for filter_options.json in: {PROJECT_DIR / 'src' / 'config'}")
        st.error("filter_options.json not found. Please run the metadata analysis script.")
        return {"sources": [], "years": [], "districts": [], "filenames": []}
|
|
|
|
|
def main(): |
|
|
|
|
|
if 'messages' not in st.session_state: |
|
|
st.session_state.messages = [] |
|
|
if 'conversation_id' not in st.session_state: |
|
|
st.session_state.conversation_id = f"session_{uuid.uuid4().hex[:8]}" |
|
|
if 'session_start_time' not in st.session_state: |
|
|
st.session_state.session_start_time = time.time() |
|
|
if 'active_filters' not in st.session_state: |
|
|
st.session_state.active_filters = {'sources': [], 'years': [], 'districts': [], 'filenames': []} |
|
|
|
|
|
if 'rag_retrieval_history' not in st.session_state: |
|
|
st.session_state.rag_retrieval_history = [] |
|
|
|
|
|
if 'chatbot_version' not in st.session_state: |
|
|
st.session_state.chatbot_version = "v1" |
|
|
|
|
|
|
|
|
chatbot_version_key = f"chatbot_{st.session_state.chatbot_version}" |
|
|
|
|
|
|
|
|
needs_init = ( |
|
|
chatbot_version_key not in st.session_state or |
|
|
st.session_state.get('_last_version') != st.session_state.chatbot_version |
|
|
) |
|
|
|
|
|
if needs_init: |
|
|
try: |
|
|
|
|
|
if st.session_state.chatbot_version == "beta": |
|
|
spinner_msg = "🔄 Initializing Gemini FSA..." |
|
|
elif st.session_state.chatbot_version == "visual": |
|
|
spinner_msg = "🎨 Initializing Visual Search ... This may take 20-30 seconds..." |
|
|
else: |
|
|
spinner_msg = "🔄 Loading AI models and connecting to database..." |
|
|
|
|
|
with st.spinner(spinner_msg): |
|
|
st.session_state[chatbot_version_key] = get_chatbot(st.session_state.chatbot_version) |
|
|
st.session_state['_last_version'] = st.session_state.chatbot_version |
|
|
st.session_state.chatbot = st.session_state[chatbot_version_key] |
|
|
print("✅ AI system ready!") |
|
|
except Exception as e: |
|
|
st.error(f"❌ Failed to initialize chatbot: {str(e)}") |
|
|
|
|
|
if st.session_state.chatbot_version == "beta": |
|
|
st.error("Please check your environment variables (GEMINI_API_KEY, GEMINI_FILESTORE_NAME for beta)") |
|
|
elif st.session_state.chatbot_version == "visual": |
|
|
st.error("Please check your environment variables (QDRANT_URL, QDRANT_API_KEY, OPENAI_API_KEY for visual)") |
|
|
with st.expander("🐛 Debug Info"): |
|
|
import traceback |
|
|
st.code(traceback.format_exc()) |
|
|
else: |
|
|
st.error("Please check your configuration and ensure all required models and databases are accessible.") |
|
|
|
|
|
st.session_state.chatbot_version = "v1" |
|
|
st.session_state['_last_version'] = "v1" |
|
|
if 'chatbot' in st.session_state: |
|
|
del st.session_state['chatbot'] |
|
|
st.stop() |
|
|
else: |
|
|
|
|
|
st.session_state.chatbot = st.session_state[chatbot_version_key] |
|
|
|
|
|
|
|
|
if 'reset_conversation' in st.session_state and st.session_state.reset_conversation: |
|
|
st.session_state.messages = [] |
|
|
st.session_state.conversation_id = f"session_{uuid.uuid4().hex[:8]}" |
|
|
st.session_state.session_start_time = time.time() |
|
|
st.session_state.rag_retrieval_history = [] |
|
|
st.session_state.feedback_submitted = False |
|
|
st.session_state.reset_conversation = False |
|
|
st.rerun() |
|
|
|
|
|
|
|
|
|
|
|
col1, col2 = st.columns([3, 1]) |
|
|
with col1: |
|
|
st.markdown('<p class="subtitle">Ask questions about audit reports. Use the sidebar filters to narrow down your search!</p>', unsafe_allow_html=True) |
|
|
with col2: |
|
|
st.markdown("<br>", unsafe_allow_html=True) |
|
|
selected_version = st.radio( |
|
|
"**Version:**", |
|
|
options=["v1", "visual", "beta"], |
|
|
index=0 if st.session_state.chatbot_version == "v1" else (1 if st.session_state.chatbot_version == "visual" else 2), |
|
|
horizontal=True, |
|
|
key="version_selector", |
|
|
help="Select v1 (default RAG), visual (ColPali visual search), or beta (Gemini FSA)" |
|
|
) |
|
|
|
|
|
|
|
|
if selected_version != st.session_state.chatbot_version: |
|
|
|
|
|
old_version = st.session_state.chatbot_version |
|
|
st.session_state.chatbot_version = selected_version |
|
|
|
|
|
|
|
|
new_chatbot_key = f"chatbot_{selected_version}" |
|
|
if new_chatbot_key in st.session_state: |
|
|
|
|
|
st.session_state.chatbot = st.session_state[new_chatbot_key] |
|
|
st.session_state['_last_version'] = selected_version |
|
|
else: |
|
|
|
|
|
st.session_state['_last_version'] = old_version |
|
|
|
|
|
st.rerun() |
|
|
|
|
|
|
|
|
if st.session_state.chatbot_version == "beta": |
|
|
st.info("🔬 **Beta Mode**: Using Google Gemini FSA") |
|
|
elif st.session_state.chatbot_version == "visual": |
|
|
st.info("🎨 **Visual Mode**: Using Visual Search (Multi-Modal Embeddings)") |
|
|
|
|
|
|
|
|
duration = int(time.time() - st.session_state.session_start_time) |
|
|
duration_str = f"{duration // 60}m {duration % 60}s" |
|
|
st.markdown(f''' |
|
|
<div class="session-info"> |
|
|
<strong>Session Info:</strong> Messages: {len(st.session_state.messages)} | Duration: {duration_str} | Status: Active | ID: {st.session_state.conversation_id} |
|
|
</div> |
|
|
''', unsafe_allow_html=True) |
|
|
|
|
|
|
|
|
filter_options = load_filter_options() |
|
|
|
|
|
|
|
|
with st.sidebar: |
|
|
|
|
|
with st.expander("📖 How to Use", expanded=True): |
|
|
st.markdown(""" |
|
|
#### 🎯 Using Filters |
|
|
|
|
|
1. **Select filters** from the sidebar to narrow your search: |
|
|
|
|
|
2. **Leave filters empty** to search across all data |
|
|
|
|
|
3. **Type your question** in the chat input at the bottom |
|
|
|
|
|
4. **Click "Send"** to submit your question |
|
|
|
|
|
#### 💡 Tips |
|
|
|
|
|
- Use specific questions for better results |
|
|
- Combine multiple filters for precise searches |
|
|
- Check the "Retrieved Documents" tab to see source material |
|
|
|
|
|
#### ⚠️ Important |
|
|
|
|
|
**When finished, please close the browser window** to free up computational resources. |
|
|
|
|
|
--- |
|
|
|
|
|
For more detailed help, see the example questions at the bottom of the page. |
|
|
""") |
|
|
|
|
|
|
|
|
with st.expander("🔍 Search Filters", expanded=False): |
|
|
st.caption("Select filters to narrow down your search. Leave empty to search all data.") |
|
|
|
|
|
st.markdown('<div class="filter-section">', unsafe_allow_html=True) |
|
|
st.markdown('<div class="filter-title">📄 Specific Reports (Filename Filter)</div>', unsafe_allow_html=True) |
|
|
st.markdown('<p style="font-size: 0.85em; color: #666;">⚠️ Selecting specific reports will ignore all other filters</p>', unsafe_allow_html=True) |
|
|
selected_filenames = st.multiselect( |
|
|
"Select specific reports:", |
|
|
options=filter_options.get('filenames', []), |
|
|
default=st.session_state.active_filters.get('filenames', []), |
|
|
key="filenames_filter", |
|
|
help="Choose specific reports to search. When enabled, all other filters are ignored." |
|
|
) |
|
|
st.markdown('</div>', unsafe_allow_html=True) |
|
|
|
|
|
|
|
|
filename_mode = len(selected_filenames) > 0 |
|
|
|
|
|
|
|
|
st.markdown('<div class="filter-title">📊 Sources</div>', unsafe_allow_html=True) |
|
|
selected_sources = st.multiselect( |
|
|
"Select sources:", |
|
|
options=filter_options['sources'], |
|
|
default=st.session_state.active_filters['sources'], |
|
|
disabled = filename_mode, |
|
|
key="sources_filter", |
|
|
help="Choose which types of reports to search" |
|
|
) |
|
|
st.markdown('</div>', unsafe_allow_html=True) |
|
|
|
|
|
|
|
|
st.markdown('<div class="filter-title">📅 Years</div>', unsafe_allow_html=True) |
|
|
selected_years = st.multiselect( |
|
|
"Select years:", |
|
|
options=filter_options['years'], |
|
|
default=st.session_state.active_filters['years'], |
|
|
disabled = filename_mode, |
|
|
key="years_filter", |
|
|
help="Choose which years to search" |
|
|
) |
|
|
st.markdown('</div>', unsafe_allow_html=True) |
|
|
|
|
|
|
|
|
st.markdown('<div class="filter-title">🏘️ Districts</div>', unsafe_allow_html=True) |
|
|
selected_districts = st.multiselect( |
|
|
"Select districts:", |
|
|
options=filter_options['districts'], |
|
|
default=st.session_state.active_filters['districts'], |
|
|
disabled = filename_mode, |
|
|
key="districts_filter", |
|
|
help="Choose which districts to search" |
|
|
) |
|
|
st.markdown('</div>', unsafe_allow_html=True) |
|
|
|
|
|
|
|
|
if st.button("🗑️ Clear All Filters", key="clear_filters_button"): |
|
|
st.session_state.active_filters = {'sources': [], 'years': [], 'districts': [], 'filenames': []} |
|
|
st.rerun() |
|
|
|
|
|
|
|
|
st.session_state.active_filters = { |
|
|
'sources': selected_sources if not filename_mode else [], |
|
|
'years': selected_years if not filename_mode else [], |
|
|
'districts': selected_districts if not filename_mode else [], |
|
|
'filenames': selected_filenames |
|
|
} |
|
|
|
|
|
|
|
|
if st.session_state.chatbot_version == "visual": |
|
|
with st.expander("🔥 Saliency Maps", expanded=False): |
|
|
st.caption("Visualize which image regions are relevant to your query") |
|
|
|
|
|
show_saliency = st.checkbox( |
|
|
"Enable Saliency Maps", |
|
|
value=st.session_state.get('show_saliency', False), |
|
|
key="saliency_toggle", |
|
|
help="Generate heatmaps showing which parts of each document are most relevant" |
|
|
) |
|
|
st.session_state.show_saliency = show_saliency |
|
|
|
|
|
if show_saliency: |
|
|
|
|
|
colormap_options = ["hot", "jet", "viridis", "plasma", "coolwarm", "RdYlGn"] |
|
|
saliency_colormap = st.selectbox( |
|
|
"Colormap", |
|
|
options=colormap_options, |
|
|
index=colormap_options.index(st.session_state.get('saliency_colormap', 'hot')), |
|
|
key="saliency_colormap_select", |
|
|
help="Color scheme for the heatmap. 'hot' recommended for visibility." |
|
|
) |
|
|
st.session_state.saliency_colormap = saliency_colormap |
|
|
|
|
|
saliency_alpha = st.slider( |
|
|
"Overlay Transparency", |
|
|
min_value=0.1, |
|
|
max_value=0.8, |
|
|
value=st.session_state.get('saliency_alpha', 0.4), |
|
|
step=0.1, |
|
|
key="saliency_alpha_slider", |
|
|
help="0.1 = subtle, 0.8 = intense" |
|
|
) |
|
|
st.session_state.saliency_alpha = saliency_alpha |
|
|
|
|
|
saliency_threshold = st.slider( |
|
|
"Threshold (%)", |
|
|
min_value=0, |
|
|
max_value=80, |
|
|
value=st.session_state.get('saliency_threshold', 50), |
|
|
step=10, |
|
|
key="saliency_threshold_slider", |
|
|
help="Hide patches below this percentile" |
|
|
) |
|
|
st.session_state.saliency_threshold = saliency_threshold |
|
|
|
|
|
|
|
|
tab1, tab2 = st.tabs(["💬 Chat", "📄 Retrieved Documents"]) |
|
|
|
|
|
with tab1: |
|
|
|
|
|
chat_container = st.container() |
|
|
|
|
|
with chat_container: |
|
|
|
|
|
for message in st.session_state.messages: |
|
|
if isinstance(message, HumanMessage): |
|
|
st.markdown(f'<div class="user-message">{message.content}</div>', unsafe_allow_html=True) |
|
|
elif isinstance(message, AIMessage): |
|
|
st.markdown(f'<div class="bot-message">{message.content}</div>', unsafe_allow_html=True) |
|
|
|
|
|
|
|
|
st.markdown("<br>", unsafe_allow_html=True) |
|
|
|
|
|
|
|
|
col1, col2 = st.columns([4, 1]) |
|
|
|
|
|
with col1: |
|
|
|
|
|
if 'input_counter' not in st.session_state: |
|
|
st.session_state.input_counter = 0 |
|
|
|
|
|
|
|
|
if 'pending_question' in st.session_state and st.session_state.pending_question: |
|
|
default_value = st.session_state.pending_question |
|
|
|
|
|
st.session_state.input_counter = (st.session_state.get('input_counter', 0) + 1) % 1000 |
|
|
del st.session_state.pending_question |
|
|
key_suffix = st.session_state.input_counter |
|
|
else: |
|
|
default_value = "" |
|
|
key_suffix = st.session_state.input_counter |
|
|
|
|
|
user_input = st.text_input( |
|
|
"Type your message here...", |
|
|
placeholder="Ask about budget allocations, expenditures, or audit findings...", |
|
|
key=f"user_input_{key_suffix}", |
|
|
label_visibility="collapsed", |
|
|
value=default_value if default_value else None |
|
|
) |
|
|
|
|
|
with col2: |
|
|
send_button = st.button("Send", key="send_button", use_container_width=True) |
|
|
|
|
|
|
|
|
if st.button("🗑️ Clear Chat", key="clear_chat_button"): |
|
|
st.session_state.reset_conversation = True |
|
|
|
|
|
conversations_path = CONVERSATIONS_DIR |
|
|
if conversations_path.exists(): |
|
|
for file in conversations_path.iterdir(): |
|
|
if file.suffix == '.json': |
|
|
file.unlink() |
|
|
st.rerun() |
|
|
|
|
|
|
|
|
if send_button and user_input: |
|
|
|
|
|
filter_context_str = "" |
|
|
if selected_filenames: |
|
|
filter_context_str += "FILTER CONTEXT:\n" |
|
|
filter_context_str += f"Filenames: {', '.join(selected_filenames)}\n" |
|
|
filter_context_str += "USER QUERY:\n" |
|
|
elif selected_sources or selected_years or selected_districts: |
|
|
filter_context_str += "FILTER CONTEXT:\n" |
|
|
if selected_sources: |
|
|
filter_context_str += f"Sources: {', '.join(selected_sources)}\n" |
|
|
if selected_years: |
|
|
filter_context_str += f"Years: {', '.join(selected_years)}\n" |
|
|
if selected_districts: |
|
|
filter_context_str += f"Districts: {', '.join(selected_districts)}\n" |
|
|
filter_context_str += "USER QUERY:\n" |
|
|
|
|
|
full_query = filter_context_str + user_input |
|
|
|
|
|
|
|
|
st.session_state.messages.append(HumanMessage(content=user_input)) |
|
|
|
|
|
|
|
|
with st.spinner("🤔 Thinking..."): |
|
|
try: |
|
|
|
|
|
chat_result = st.session_state.chatbot.chat(full_query, st.session_state.conversation_id) |
|
|
|
|
|
|
|
|
if isinstance(chat_result, dict): |
|
|
response = chat_result['response'] |
|
|
rag_result = chat_result.get('rag_result') |
|
|
st.session_state.last_rag_result = rag_result |
|
|
|
|
|
|
|
|
if rag_result: |
|
|
sources = rag_result.get('sources', []) if isinstance(rag_result, dict) else (rag_result.sources if hasattr(rag_result, 'sources') else []) |
|
|
|
|
|
|
|
|
if not sources or len(sources) == 0: |
|
|
gemini_result = chat_result.get('gemini_result') |
|
|
print(f"🔍 DEBUG: Checking gemini_result for sources...") |
|
|
print(f" gemini_result exists: {gemini_result is not None}") |
|
|
if gemini_result: |
|
|
print(f" gemini_result type: {type(gemini_result)}") |
|
|
print(f" has sources attr: {hasattr(gemini_result, 'sources')}") |
|
|
if hasattr(gemini_result, 'sources'): |
|
|
print(f" sources length: {len(gemini_result.sources) if gemini_result.sources else 0}") |
|
|
|
|
|
if gemini_result and hasattr(gemini_result, 'sources'): |
|
|
|
|
|
if hasattr(st.session_state.chatbot, 'gemini_client'): |
|
|
sources = st.session_state.chatbot.gemini_client.format_sources_for_display(gemini_result) |
|
|
print(f"✅ Formatted {len(sources)} sources from gemini_client") |
|
|
elif hasattr(st.session_state.chatbot, '_format_gemini_sources'): |
|
|
sources = st.session_state.chatbot._format_gemini_sources(gemini_result) |
|
|
print(f"✅ Formatted {len(sources)} sources from _format_gemini_sources") |
|
|
|
|
|
|
|
|
if sources and len(sources) > 0: |
|
|
if isinstance(rag_result, dict): |
|
|
rag_result['sources'] = sources |
|
|
elif hasattr(rag_result, 'sources'): |
|
|
rag_result.sources = sources |
|
|
|
|
|
st.session_state.last_rag_result = rag_result |
|
|
print(f"✅ Updated rag_result with {len(sources)} sources") |
|
|
|
|
|
|
|
|
actual_rag_query = chat_result.get('actual_rag_query', '') |
|
|
if actual_rag_query: |
|
|
|
|
|
timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) |
|
|
formatted_query = f"{timestamp} - INFO - 🔍 ACTUAL RAG QUERY: '{actual_rag_query}'" |
|
|
else: |
|
|
formatted_query = "No RAG query available" |
|
|
|
|
|
|
|
|
filters_used = { |
|
|
"sources": st.session_state.active_filters.get('sources', []), |
|
|
"years": st.session_state.active_filters.get('years', []), |
|
|
"districts": st.session_state.active_filters.get('districts', []), |
|
|
"filenames": st.session_state.active_filters.get('filenames', []) |
|
|
} |
|
|
|
|
|
retrieval_entry = { |
|
|
"conversation_up_to": serialize_messages(st.session_state.messages), |
|
|
"rag_query_expansion": formatted_query, |
|
|
"docs_retrieved": serialize_documents(sources), |
|
|
"filters_applied": filters_used, |
|
|
"timestamp": time.time() |
|
|
} |
|
|
st.session_state.rag_retrieval_history.append(retrieval_entry) |
|
|
|
|
|
|
|
|
print(f"📊 RETRIEVAL TRACKING: {len(sources)} sources stored in retrieval history") |
|
|
else: |
|
|
response = chat_result |
|
|
st.session_state.last_rag_result = None |
|
|
|
|
|
|
|
|
st.session_state.messages.append(AIMessage(content=response)) |
|
|
|
|
|
except Exception as e: |
|
|
error_msg = f"Sorry, I encountered an error: {str(e)}" |
|
|
st.session_state.messages.append(AIMessage(content=error_msg)) |
|
|
|
|
|
|
|
|
st.session_state.input_counter += 1 |
|
|
st.rerun() |
|
|
|
|
|
with tab2: |
|
|
|
|
|
if hasattr(st.session_state, 'last_rag_result') and st.session_state.last_rag_result: |
|
|
rag_result = st.session_state.last_rag_result |
|
|
|
|
|
|
|
|
sources = None |
|
|
if hasattr(rag_result, 'sources'): |
|
|
|
|
|
sources = rag_result.sources |
|
|
elif isinstance(rag_result, dict) and 'sources' in rag_result: |
|
|
|
|
|
sources = rag_result['sources'] |
|
|
|
|
|
|
|
|
if (not sources or len(sources) == 0) and isinstance(rag_result, dict): |
|
|
gemini_result = rag_result.get('gemini_result') |
|
|
if gemini_result and hasattr(gemini_result, 'sources'): |
|
|
|
|
|
if hasattr(st.session_state.chatbot, 'gemini_client'): |
|
|
sources = st.session_state.chatbot.gemini_client.format_sources_for_display(gemini_result) |
|
|
elif hasattr(st.session_state.chatbot, '_format_gemini_sources'): |
|
|
sources = st.session_state.chatbot._format_gemini_sources(gemini_result) |
|
|
|
|
|
|
|
|
is_visual_search = False |
|
|
if sources and len(sources) > 0: |
|
|
first_doc_metadata = getattr(sources[0], 'metadata', {}) |
|
|
is_visual_search = 'num_tiles' in first_doc_metadata or 'num_visual_tokens' in first_doc_metadata |
|
|
|
|
|
if sources and len(sources) > 0: |
|
|
|
|
|
if is_visual_search and st.session_state.chatbot_version == "visual": |
|
|
st.markdown("### 🎨 Visual Search Results") |
|
|
|
|
|
|
|
|
show_saliency = st.session_state.get('show_saliency', False) |
|
|
saliency_alpha = st.session_state.get('saliency_alpha', 0.4) |
|
|
saliency_threshold = st.session_state.get('saliency_threshold', 50) |
|
|
saliency_colormap = st.session_state.get('saliency_colormap', 'hot') |
|
|
|
|
|
|
|
|
qdrant_client = None |
|
|
collection_name = None |
|
|
query_embedding = None |
|
|
|
|
|
if show_saliency: |
|
|
try: |
|
|
|
|
|
chatbot = st.session_state.get('chatbot') |
|
|
if chatbot and hasattr(chatbot, 'visual_search'): |
|
|
visual_search = chatbot.visual_search |
|
|
qdrant_client = visual_search.client |
|
|
collection_name = visual_search.collection_name |
|
|
query_embedding = visual_search.last_query_embedding |
|
|
|
|
|
if query_embedding is None: |
|
|
st.warning("⚠️ Query embedding not available for saliency") |
|
|
show_saliency = False |
|
|
else: |
|
|
logger.info(f"✅ Saliency enabled: colormap={saliency_colormap}, alpha={saliency_alpha}, threshold={saliency_threshold}") |
|
|
except Exception as e: |
|
|
logger.error(f"Failed to get saliency requirements: {e}") |
|
|
st.warning(f"⚠️ Saliency unavailable: {str(e)[:50]}") |
|
|
show_saliency = False |
|
|
|
|
|
|
|
|
stats = extract_chunk_statistics(sources) |
|
|
|
|
|
|
|
|
if len(sources) >= 5: |
|
|
display_chunk_statistics_charts(stats, "Retrieval Statistics") |
|
|
st.markdown("---") |
|
|
|
|
|
display_visual_search_results( |
|
|
sources=sources, |
|
|
show_statistics=True, |
|
|
show_images=True, |
|
|
show_saliency=show_saliency, |
|
|
qdrant_client=qdrant_client, |
|
|
collection_name=collection_name, |
|
|
query_embedding=query_embedding, |
|
|
saliency_alpha=saliency_alpha, |
|
|
saliency_colormap=saliency_colormap, |
|
|
saliency_threshold=saliency_threshold, |
|
|
max_display=20 |
|
|
) |
|
|
else: |
|
|
|
|
|
|
|
|
unique_filenames = set() |
|
|
for doc in sources: |
|
|
filename = getattr(doc, 'metadata', {}).get('filename', 'Unknown') |
|
|
unique_filenames.add(filename) |
|
|
|
|
|
st.markdown(f"**Found {len(sources)} document chunks from {len(unique_filenames)} unique documents (showing top 20):**") |
|
|
if len(unique_filenames) < len(sources): |
|
|
st.info(f"💡 **Note**: Each document is split into multiple chunks. You're seeing {len(sources)} chunks from {len(unique_filenames)} documents.") |
|
|
|
|
|
|
|
|
stats = extract_chunk_statistics(sources) |
|
|
|
|
|
|
|
|
if len(sources) >= 10: |
|
|
display_chunk_statistics_charts(stats, "Retrieval Statistics") |
|
|
|
|
|
st.markdown("---") |
|
|
display_chunk_statistics_table(stats, "Retrieval Distribution") |
|
|
else: |
|
|
display_chunk_statistics_table(stats, "Retrieval Distribution") |
|
|
|
|
|
st.markdown("---") |
|
|
st.markdown("### 📄 Document Details") |
|
|
|
|
|
for i, doc in enumerate(sources): |
|
|
|
|
|
metadata = getattr(doc, 'metadata', {}) |
|
|
|
|
|
score = metadata.get('reranked_score') or metadata.get('original_score') or metadata.get('score') |
|
|
chunk_id = metadata.get('_id') or metadata.get('chunk_id', 'Unknown') |
|
|
if score is not None: |
|
|
try: |
|
|
score_text = f" (Score: {float(score):.3f})" |
|
|
except (ValueError, TypeError): |
|
|
score_text = "" |
|
|
else: |
|
|
score_text = "" |
|
|
if chunk_id and chunk_id != 'Unknown': |
|
|
score_text += f" (ID: {str(chunk_id)[:8]}...)" if score_text else f" (ID: {str(chunk_id)[:8]}...)" |
|
|
|
|
|
with st.expander(f"📄 Document {i+1}: {getattr(doc, 'metadata', {}).get('filename', 'Unknown')[:50]}...{score_text}"): |
|
|
|
|
|
metadata = getattr(doc, 'metadata', {}) |
|
|
col1, col2, col3, col4 = st.columns([2, 1.5, 1, 1]) |
|
|
|
|
|
with col1: |
|
|
st.write(f"📄 **File:** {metadata.get('filename', 'Unknown')}") |
|
|
with col2: |
|
|
st.write(f"🏛️ **Source:** {metadata.get('source', 'Unknown')}") |
|
|
with col3: |
|
|
st.write(f"📅 **Year:** {metadata.get('year', 'Unknown')}") |
|
|
with col4: |
|
|
|
|
|
page = metadata.get('page_label', metadata.get('page', 'Unknown')) |
|
|
chunk_id = metadata.get('_id', 'Unknown') |
|
|
st.write(f"📖 **Page:** {page}") |
|
|
st.write(f"🆔 **ID:** {chunk_id}") |
|
|
|
|
|
|
|
|
content = getattr(doc, 'page_content', 'No content available') |
|
|
st.write(f"**Full Content:**") |
|
|
st.text_area("Full Content", value=content, height=300, disabled=True, label_visibility="collapsed", key=f"preview_{i}") |
|
|
else: |
|
|
st.info("No documents were retrieved for the last query.") |
|
|
else: |
|
|
st.info("No documents have been retrieved yet. Start a conversation to see retrieved documents here.") |
|
|
|
|
|
|
|
|
st.markdown("---") |
|
|
st.markdown("### 💬 Feedback Dashboard") |
|
|
|
|
|
|
|
|
has_conversation = len(st.session_state.messages) > 0 |
|
|
has_retrievals = len(st.session_state.rag_retrieval_history) > 0 |
|
|
|
|
|
if not has_conversation: |
|
|
st.info("💡 Start a conversation to provide feedback!") |
|
|
st.markdown("The feedback dashboard will be enabled once you begin chatting.") |
|
|
else: |
|
|
st.markdown("Help us improve by providing feedback on this conversation.") |
|
|
|
|
|
|
|
|
if 'feedback_submitted' not in st.session_state: |
|
|
st.session_state.feedback_submitted = False |
|
|
|
|
|
|
|
|
if not st.session_state.feedback_submitted: |
|
|
with st.form("feedback_form", clear_on_submit=False): |
|
|
col1, col2 = st.columns([1, 1]) |
|
|
|
|
|
with col1: |
|
|
feedback_score = st.slider( |
|
|
"Rate this conversation (1-5)", |
|
|
min_value=1, |
|
|
max_value=5, |
|
|
help="How satisfied are you with the conversation?" |
|
|
) |
|
|
|
|
|
with col2: |
|
|
is_feedback_about_last_retrieval = st.checkbox( |
|
|
"Feedback about last retrieval only", |
|
|
value=True, |
|
|
help="If checked, feedback applies to the most recent document retrieval" |
|
|
) |
|
|
|
|
|
open_ended_feedback = st.text_area( |
|
|
"Your feedback (optional)", |
|
|
placeholder="Tell us what went well or what could be improved...", |
|
|
height=100 |
|
|
) |
|
|
|
|
|
|
|
|
# Enable the submit button only once a score has been chosen.
# (feedback_score, is_feedback_about_last_retrieval, open_ended_feedback,
# has_retrievals, feedback_manager and logger are defined earlier in the file;
# this code runs inside the feedback form context opened above this chunk.)
submit_disabled = feedback_score is None

submitted = st.form_submit_button(
    "📤 Submit Feedback",
    use_container_width=True,
    disabled=submit_disabled
)

if submitted:
    # Console banner makes the submission easy to spot in server logs.
    print("=" * 80)
    print("🔄 FEEDBACK SUBMISSION: Starting...")
    print("=" * 80)
    st.write("🔍 **Debug: Feedback Data Being Submitted:**")

    # Flatten the chat history into a serializable transcript.
    transcript = feedback_manager.extract_transcript(st.session_state.messages)

    # Pair each retrieval event with the conversation state around it.
    # .copy() guards the session list against mutation by the helper.
    retrievals = feedback_manager.build_retrievals_structure(
        st.session_state.rag_retrieval_history.copy() if st.session_state.rag_retrieval_history else [],
        st.session_state.messages
    )

    # Documents the score most plausibly refers to (last retrieval vs. all),
    # driven by the user's is_feedback_about_last_retrieval choice.
    feedback_score_related_retrieval_docs = feedback_manager.build_feedback_score_related_retrieval_docs(
        is_feedback_about_last_retrieval,
        st.session_state.messages,
        st.session_state.rag_retrieval_history.copy() if st.session_state.rag_retrieval_history else []
    )

    # Legacy field kept for consumers of the old feedback schema.
    retrieved_data_old_format = st.session_state.rag_retrieval_history.copy() if st.session_state.rag_retrieval_history else []

    # Raw payload; also used directly as a fallback if the typed
    # feedback object below cannot be built.
    feedback_dict = {
        "open_ended_feedback": open_ended_feedback,
        "score": feedback_score,
        "is_feedback_about_last_retrieval": is_feedback_about_last_retrieval,
        "conversation_id": st.session_state.conversation_id,
        "timestamp": time.time(),
        "message_count": len(st.session_state.messages),
        "has_retrievals": has_retrievals,
        "retrieval_count": len(st.session_state.rag_retrieval_history) if st.session_state.rag_retrieval_history else 0,
        "transcript": transcript,
        "retrievals": retrievals,
        "feedback_score_related_retrieval_docs": feedback_score_related_retrieval_docs,
        "retrieved_data": retrieved_data_old_format
    }

    print(f"📝 FEEDBACK SUBMISSION: Score={feedback_score}, Retrievals={len(st.session_state.rag_retrieval_history) if st.session_state.rag_retrieval_history else 0}")

    # Build the structured feedback object; on failure, fall back to the
    # raw dict so the local save below still happens.
    feedback_obj = None
    try:
        feedback_obj = feedback_manager.create_feedback_from_dict(feedback_dict)
        print(f"✅ FEEDBACK SUBMISSION: Feedback object created - ID={feedback_obj.feedback_id}")
        st.write(f"✅ **Feedback Object Created**")
        st.write(f"- Feedback ID: {feedback_obj.feedback_id}")
        st.write(f"- Score: {feedback_obj.score}/5")
        st.write(f"- Has Retrievals: {feedback_obj.has_retrievals}")
        feedback_data = feedback_obj.to_dict()
    except Exception as e:
        print(f"❌ FEEDBACK SUBMISSION: Failed to create feedback object: {e}")
        st.error(f"Failed to create feedback object: {e}")
        feedback_data = feedback_dict

    # Show the exact payload that will be persisted (debug aid).
    st.json(feedback_data)

    # Ensure the feedback directory exists; fall back to a relative
    # "feedback" dir when the configured path is not writable
    # (e.g. in a deployed/read-only environment).
    feedback_dir = FEEDBACK_DIR
    try:
        feedback_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
    except (PermissionError, OSError) as e:
        logger.warning(f"Could not create feedback directory at {feedback_dir}: {e}")
        feedback_dir = Path("feedback")
        feedback_dir.mkdir(parents=True, mode=0o777, exist_ok=True)

    # One file per submission, keyed by conversation id + epoch seconds.
    feedback_file = feedback_dir / f"feedback_{st.session_state.conversation_id}_{int(time.time())}.json"

    try:
        feedback_file.parent.mkdir(parents=True, mode=0o777, exist_ok=True)

        # Local JSON save happens first, so feedback survives a Snowflake outage.
        print(f"💾 FEEDBACK SAVE: Saving to local file: {feedback_file}")
        with open(feedback_file, 'w') as f:
            # default=str stringifies values json can't serialize natively.
            json.dump(feedback_data, f, indent=2, default=str)
        print(f"✅ FEEDBACK SAVE: Local file saved successfully")

        logger.info("🔄 FEEDBACK SAVE: Starting Snowflake save process...")
        logger.info(f"📊 FEEDBACK SAVE: feedback_obj={'exists' if feedback_obj else 'None'}")

        # Optional Snowflake persistence, gated by the SNOWFLAKE_ENABLED env var.
        snowflake_success = False
        try:
            snowflake_enabled = os.getenv("SNOWFLAKE_ENABLED", "false").lower() == "true"
            logger.info(f"🔍 SNOWFLAKE CHECK: enabled={snowflake_enabled}")

            if snowflake_enabled:
                if feedback_obj:
                    try:
                        logger.info("📤 SNOWFLAKE UI: Attempting to save feedback to Snowflake...")
                        print("📤 SNOWFLAKE UI: Attempting to save feedback to Snowflake...")

                        with st.spinner("💾 Saving feedback to Snowflake... This may take 10-15 seconds (connecting to database, preparing data, and executing query)"):
                            snowflake_success = feedback_manager.save_to_snowflake(feedback_obj)

                        if snowflake_success:
                            logger.info("✅ SNOWFLAKE UI: Successfully saved to Snowflake")
                            print("✅ SNOWFLAKE UI: Successfully saved to Snowflake")
                        else:
                            logger.warning("⚠️ SNOWFLAKE UI: Save failed")
                            print("⚠️ SNOWFLAKE UI: Save failed")
                    except Exception as e:
                        logger.error(f"❌ SNOWFLAKE UI ERROR: {e}")
                        print(f"❌ SNOWFLAKE UI ERROR: {e}")
                        traceback.print_exc()
                        snowflake_success = False
                else:
                    # Snowflake save requires the typed object; the raw-dict
                    # fallback cannot be saved there.
                    logger.warning("⚠️ SNOWFLAKE UI: Skipping (feedback object not created)")
                    print("⚠️ SNOWFLAKE UI: Skipping (feedback object not created)")
                    snowflake_success = False
            else:
                logger.info("💡 SNOWFLAKE UI: Integration disabled")
                print("💡 SNOWFLAKE UI: Integration disabled")

                # Disabled integration counts as success so the UI shows the
                # normal "saved" message (the local save already happened).
                snowflake_success = True

        except Exception as e:
            logger.error(f"❌ Exception in Snowflake save: {type(e).__name__}: {e}")
            print(f"❌ Exception in Snowflake save: {type(e).__name__}: {e}")
            snowflake_success = False

        if snowflake_success:
            st.success("✅ Thank you for your feedback! It has been saved successfully.")
            st.balloons()
        else:
            st.warning("⚠️ Feedback saved locally, but Snowflake save failed. Please check logs.")

        # Mark this conversation as rated so the form is hidden on rerun.
        st.session_state.feedback_submitted = True

        print("=" * 80)
        print(f"✅ FEEDBACK SUBMISSION: Completed successfully")
        print("=" * 80)

        st.info(f"📁 Feedback saved to: {feedback_file}")

    except Exception as e:
        # Catch-all for the local-save path so the UI never crashes on submit.
        print(f"❌ FEEDBACK SUBMISSION: Error saving feedback: {e}")
        print(f"❌ FEEDBACK SUBMISSION: Error type: {type(e).__name__}")
        traceback.print_exc()
        st.error(f"❌ Error saving feedback: {e}")
        st.write(f"Debug error: {str(e)}")
|
# Matches an `if` opened before this chunk (feedback not yet submitted
# branch above): shown once feedback has already been recorded.
else:
    st.success("✅ Feedback already submitted for this conversation!")
    col1, col2 = st.columns([1, 1])
    with col1:
        # Reset the submitted flag so the feedback form reappears on rerun.
        if st.button("🔄 Submit New Feedback", key="new_feedback_button", use_container_width=True):
            try:
                st.session_state.feedback_submitted = False
                st.rerun()
            except Exception as e:
                logger.error(f"Error resetting feedback state: {e}")
                st.error(f"Error resetting feedback. Please refresh the page.")
    with col2:
        # Placeholder button — viewing the conversation is not implemented yet.
        if st.button("📋 View Conversation", key="view_conversation_button", use_container_width=True):
            pass
|
|
|
|
|
|
|
|
# Debug/inspection panel: one sub-section per RAG retrieval recorded
# in this session (query, filters, conversation context, documents).
if st.session_state.rag_retrieval_history:
    st.markdown("---")
    st.markdown("#### 📊 Retrieval History")

    with st.expander(f"View {len(st.session_state.rag_retrieval_history)} retrieval entries", expanded=False):
        for idx, entry in enumerate(st.session_state.rag_retrieval_history, 1):
            st.markdown(f"### **Retrieval #{idx}**")

            # When the retrieval happened (epoch seconds -> local time string).
            if entry.get("timestamp"):
                timestamp_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(entry["timestamp"]))
                st.caption(f"🕐 {timestamp_str}")

            # The (possibly expanded) query that was sent to the retriever.
            rag_query_expansion = entry.get("rag_query_expansion", "No query available")
            st.markdown("**🔍 RAG Query:**")
            st.code(rag_query_expansion, language="text")

            # Metadata filters that narrowed the search, if any were set.
            filters_applied = entry.get("filters_applied", {})
            if filters_applied and any(filters_applied.values()):
                st.markdown("**🎯 Filters Applied:**")
                filter_display = {}
                if filters_applied.get("sources"):
                    filter_display["Sources"] = filters_applied["sources"]
                if filters_applied.get("years"):
                    filter_display["Years"] = filters_applied["years"]
                if filters_applied.get("districts"):
                    filter_display["Districts"] = filters_applied["districts"]
                if filters_applied.get("filenames"):
                    filter_display["Filenames"] = filters_applied["filenames"]

                if filter_display:
                    st.json(filter_display)
                else:
                    st.info("No filters applied")
            else:
                st.info("No filters applied")

            # Conversation messages captured up to the moment of this retrieval.
            conversation_up_to = entry.get("conversation_up_to", [])
            if conversation_up_to:
                st.markdown("**💬 Conversation History (up to retrieval point):**")
                with st.expander(f"View {len(conversation_up_to)} messages", expanded=False):
                    for msg_idx, msg in enumerate(conversation_up_to, 1):
                        role = msg.get("type", "unknown")
                        content = msg.get("content", "")

                        # Long messages are truncated to a 200-char preview.
                        if role == "HumanMessage" or role == "human":
                            st.markdown(f"**👤 User {msg_idx}:** {content[:200]}{'...' if len(content) > 200 else ''}")
                        elif role == "AIMessage" or role == "ai":
                            st.markdown(f"**🤖 Assistant {msg_idx}:** {content[:200]}{'...' if len(content) > 200 else ''}")
            else:
                st.info("No conversation history available")

            # Documents returned by this retrieval: metadata + content preview.
            docs_retrieved = entry.get("docs_retrieved", [])
            if docs_retrieved:
                st.markdown(f"**📄 Documents Retrieved ({len(docs_retrieved)}):**")
                with st.expander(f"View {len(docs_retrieved)} documents", expanded=False):
                    for doc_idx, doc in enumerate(docs_retrieved, 1):
                        st.markdown(f"**Document {doc_idx}:**")

                        metadata = doc.get("metadata", {})
                        if metadata:
                            col1, col2, col3 = st.columns(3)
                            with col1:
                                st.write(f"📄 **File:** {metadata.get('filename', 'Unknown')}")
                            with col2:
                                st.write(f"🏛️ **Source:** {metadata.get('source', 'Unknown')}")
                            with col3:
                                st.write(f"📅 **Year:** {metadata.get('year', 'Unknown')}")

                            # Optional metadata fields, shown only when present.
                            if metadata.get('district'):
                                st.write(f"📍 **District:** {metadata.get('district')}")
                            if metadata.get('page'):
                                st.write(f"📖 **Page:** {metadata.get('page')}")
                            # Score may be numeric (formatted to 3 dp) or a raw string.
                            if metadata.get('score') is not None:
                                st.write(f"⭐ **Score:** {metadata.get('score'):.3f}" if isinstance(metadata.get('score'), (int, float)) else f"⭐ **Score:** {metadata.get('score')}")

                        # Document text may live under "content" or "page_content".
                        content = doc.get("content", doc.get("page_content", ""))
                        if content:
                            st.markdown("**Content Preview:**")
                            st.text_area(
                                "Content Preview",
                                value=content[:200] + ("..." if len(content) > 200 else ""),
                                height=100,
                                disabled=True,
                                label_visibility="collapsed",
                                # Unique widget key per retrieval/doc avoids
                                # Streamlit duplicate-key errors.
                                key=f"retrieval_{idx}_doc_{doc_idx}_preview"
                            )

                        # Separator between documents (skipped after the last one).
                        if doc_idx < len(docs_retrieved):
                            st.markdown("---")
            else:
                st.info("No documents retrieved")

            st.markdown("**📊 Summary:**")
            st.json({
                "conversation_length": len(conversation_up_to),
                "documents_retrieved": len(docs_retrieved)
            })

            # Separator between retrieval entries (skipped after the last one).
            if idx < len(st.session_state.rag_retrieval_history):
                st.markdown("---")
|
|
|
|
|
|
|
|
# Section separator before the example-questions area.
st.markdown("---")
|
|
|
|
|
|
|
|
# Seed the editable example questions once per session; any edits the
# user already made in st.session_state are left untouched.
_QUESTION_DEFAULTS = (
    ("custom_question_1", "How were administrative costs managed in the PDM implementation, and what issues arose with budget execution regarding staff salaries?"),
    ("custom_question_2", "What did the National Coordinator say about the release of funds for PDM administrative costs in the letter dated 29th September 2022 and how did the funding received affect the activities of the PDCs and PDM SACCOs in the FY 2022/23?"),
)
for _state_key, _default_text in _QUESTION_DEFAULTS:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _default_text
|
|
|
|
|
|
|
|
# Example-questions header plus one ready-made question the user can
# copy into the chat input with a single click.
header_col, q1_col = st.columns([1, 2])

with header_col:
    st.markdown("### 💡 Example Questions")
    st.caption(" Click **Use ...** or edit")

with q1_col:
    # Template question; expects the user to have selected a file first.
    example_q1 = "List couple of insights from the filename."
    st.markdown("**📄 File Insights** _(select a file first)_")
    q1_inner1, q1_inner2 = st.columns([3, 1])
    with q1_inner1:
        st.code(example_q1, language=None)
    with q1_inner2:
        if st.button("📋 Use question !", key="use_example_1", use_container_width=True):
            # Queue the question and bump the input counter so the chat
            # input widget gets a fresh key on rerun.
            st.session_state.pending_question = example_q1
            st.session_state.input_counter = (st.session_state.get('input_counter', 0) + 1) % 1000
            st.rerun()

st.markdown("---")
|
|
|
|
|
|
|
|
# Two side-by-side editable questions the user can tweak and send.
st.markdown("#### ✏️ Customizable Questions")
q_col1, q_col2 = st.columns(2)

with q_col1:
    st.caption("🔄 _This question will trigger follow-up prompts for year/district_")
    custom_q1 = st.text_area(
        "Question 2:",
        value=st.session_state.custom_question_1,
        height=100,
        key="edit_question_2",
        help="Modify this question to fit your needs"
    )
    if st.button("📋 Use Question 2", key="use_custom_1", use_container_width=True):
        if custom_q1.strip():
            # Queue the question, persist the edit, and bump the input
            # counter so the chat input widget re-keys on rerun.
            st.session_state.pending_question = custom_q1.strip()
            st.session_state.custom_question_1 = custom_q1.strip()
            st.session_state.input_counter = (st.session_state.get('input_counter', 0) + 1) % 1000
            st.rerun()
        else:
            st.warning("Please enter a question first!")

with q_col2:
    st.caption("✅ _Complete question - has year & context, no follow-up needed_")
    custom_q2 = st.text_area(
        "Question 3:",
        value=st.session_state.custom_question_2,
        height=100,
        key="edit_question_3",
        help="Modify this question to fit your needs"
    )
    if st.button("📋 Use Question 3", key="use_custom_2", use_container_width=True):
        if custom_q2.strip():
            # Same queue-and-rerun flow as Question 2 above.
            st.session_state.pending_question = custom_q2.strip()
            st.session_state.custom_question_2 = custom_q2.strip()
            st.session_state.input_counter = (st.session_state.get('input_counter', 0) + 1) % 1000
            st.rerun()
        else:
            st.warning("Please enter a question first!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Attempt to auto-scroll the page to the newest content.
# NOTE(review): Streamlit sanitizes <script> tags injected through
# st.markdown even with unsafe_allow_html=True, so this JS most likely
# never executes — confirm, and consider streamlit.components.v1.html
# if auto-scrolling is actually required.
st.markdown("""
<script>
window.scrollTo(0, document.body.scrollHeight);
</script>
""", unsafe_allow_html=True)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Guard against running the app with plain `python app.py` instead of
    # `streamlit run app.py`. Two cases are detected:
    #   1. streamlit is importable but there is no active script-run
    #      context (i.e. not launched by `streamlit run`);
    #   2. the streamlit runtime module cannot be imported at all.
    # The original code duplicated the warning banner (and `import sys`)
    # in both branches, and the copies had already drifted apart by one
    # line; a single helper keeps them consistent.

    def _exit_with_run_hint(extra_note: bool) -> None:
        """Print instructions for launching via `streamlit run` and exit(1).

        Args:
            extra_note: When True, also print the "app will not function
                correctly" line (matches the original non-context branch).
        """
        import sys

        banner = "=" * 80
        print(banner)
        print("⚠️ WARNING: This is a Streamlit app!")
        print(banner)
        print("\nPlease run this app using:")
        print("  streamlit run app.py")
        print("\nNot: python app.py")
        if extra_note:
            print("\nThe app will not function correctly when run with 'python app.py'")
        print(banner)
        sys.exit(1)

    try:
        from streamlit.runtime.scriptrunner import get_script_run_ctx

        # No script-run context => launched with `python app.py`.
        if get_script_run_ctx() is None:
            _exit_with_run_hint(extra_note=True)
    except ImportError:
        # Streamlit runtime internals unavailable — same remedy applies.
        _exit_with_run_hint(extra_note=False)

    main()
|
|
|