| | import streamlit as st |
| | from PIL import Image |
| | from transformers import BlipProcessor, BlipForConditionalGeneration |
| | import pandas as pd |
| | import os |
| |
|
| | from evaluate import MisinformationPredictor |
| | from src.evidence.im2im_retrieval import ImageCorpus |
| | from src.evidence.text2text_retrieval import SemanticSimilarity |
| | from src.utils.path_utils import get_project_root |
| | from typing import List, Optional, Tuple |
| | from dataclasses import dataclass |
| |
|
| | |
@st.cache_resource
def _load_blip():
    """Load the BLIP captioning processor and model once per server process.

    Streamlit re-executes this script top-to-bottom on every interaction;
    without caching, the large BLIP checkpoint would be re-loaded on each
    rerun. ``st.cache_resource`` shares the single loaded instance.
    """
    blip_processor = BlipProcessor.from_pretrained(
        "Salesforce/blip-image-captioning-large"
    )
    blip_model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-large"
    )
    return blip_processor, blip_model


# Module-level names preserved: generate_caption() reads these globals.
processor, model = _load_blip()

# Absolute repository root; all data/artifact paths below are joined onto it.
PROJECT_ROOT = get_project_root()
| |
|
| |
|
@dataclass
class Evidence:
    """One retrieved piece of evidence plus its classification verdicts."""

    # Split-local id as a string (parsed from ids like "train_42").
    evidence_id: str
    # Which split the row came from: "train" or "test".
    dataset: str
    # Enriched evidence text from the CSV row (may be missing/NaN — TODO confirm).
    text: Optional[str]
    # Loaded evidence image, when the row had one and it opened successfully.
    image: Optional[Image.Image]
    # Stored caption for the evidence image, if any.
    caption: Optional[str]
    # Absolute path the evidence image was loaded from (None when no image).
    image_path: Optional[str]
    # Four modality-pair verdicts (text|text, text|image, image|text,
    # image|image); populated after classification.
    classification_result_all: Optional[Tuple[str, str, str, str]] = None
    # Fused verdict derived from the four pairwise ones.
    classification_result_final: Optional[str] = None
| |
|
| |
|
# The closed set of stance labels an evidence/claim pair can receive.
CLASSIFICATION_CATEGORIES = [
    "support",
    "refute",
    "not_enough_information",
]
| |
|
| |
|
def generate_caption(image: Image.Image) -> str:
    """Caption *image* with the module-level BLIP model; "" on failure."""
    try:
        with st.spinner("Generating caption..."):
            encoded = processor(image, return_tensors="pt")
            generated_ids = model.generate(**encoded)
            caption = processor.decode(generated_ids[0], skip_special_tokens=True)
        return caption
    except Exception as e:
        # Surface the problem in the UI instead of crashing the rerun.
        st.error(f"Error generating caption: {e}")
        return ""
| |
|
| |
|
def enrich_text_with_caption(text: str, image_caption: str) -> str:
    """Combine claim text with the generated image caption.

    Returns ``"<text>. <caption>"`` when both parts are non-empty. Falls back
    to whichever part is present, so an image-only claim (empty ``text``) no
    longer produces a string with a dangling ". " prefix.

    Args:
        text: Claim text entered by the user (may be empty).
        image_caption: BLIP caption of the claim image (may be empty).

    Returns:
        The enriched claim text.
    """
    if not image_caption:
        return text
    if not text:
        # Previously this returned ". <caption>" for image-only claims.
        return image_caption
    return f"{text}. {image_caption}"
| |
|
| |
|
@st.cache_data
def get_train_df():
    """Load the preprocessed, caption-enriched training split (cached)."""
    csv_path = os.path.join(
        PROJECT_ROOT, "data", "preprocessed", "train_enriched.csv"
    )
    return pd.read_csv(csv_path)
| |
|
| |
|
@st.cache_data
def get_test_df():
    """Load the preprocessed, caption-enriched test split (cached)."""
    csv_path = os.path.join(
        PROJECT_ROOT, "data", "preprocessed", "test_enriched.csv"
    )
    return pd.read_csv(csv_path)
| |
|
| |
|
@st.cache_resource
def get_semantic_similarity(
    train_embeddings_file: str,
    test_embeddings_file: str,
    train_df: pd.DataFrame,
    test_df: pd.DataFrame,
):
    """Build once and share the text-to-text semantic retriever.

    Cached with ``st.cache_resource`` rather than ``st.cache_data``: the
    returned ``SemanticSimilarity`` is a stateful retrieval object, and
    ``cache_data`` would attempt to pickle-copy it on every access (and fail
    outright if it holds unpicklable members such as open HDF5 handles).
    ``cache_resource`` is the same cache kind already used for
    ``get_predictor``.

    Args:
        train_embeddings_file: Path to the HDF5 file with train embeddings.
        test_embeddings_file: Path to the HDF5 file with test embeddings.
        train_df: Enriched train split used to resolve evidence rows.
        test_df: Enriched test split used to resolve evidence rows.

    Returns:
        The shared ``SemanticSimilarity`` instance.
    """
    return SemanticSimilarity(
        train_embeddings_file=train_embeddings_file,
        test_embeddings_file=test_embeddings_file,
        train_df=train_df,
        test_df=test_df,
    )
| |
|
| |
|
def retrieve_evidences_by_text(
    query: str,
    top_k: int = 5,
) -> List[Evidence]:
    """
    Retrieves evidence rows from preloaded embeddings and CSV data using semantic similarity.

    Args:
        query (str): The query text to perform the search.
        top_k (int): Number of top results to retrieve.

    Returns:
        List[Evidence]: A list of retrieved evidence objects.
    """
    train_embeddings_file = os.path.join(PROJECT_ROOT, "train_embeddings.h5")
    test_embeddings_file = os.path.join(PROJECT_ROOT, "test_embeddings.h5")
    similarity = get_semantic_similarity(
        train_embeddings_file=train_embeddings_file,
        test_embeddings_file=test_embeddings_file,
        train_df=get_train_df(),
        test_df=get_test_df(),
    )
    evidences: List[Evidence] = []
    try:
        results = similarity.search(query=query, top_k=top_k)

        for evidence_id, _score in results:
            # Ids look like "train_<n>" / "test_<n>"; parse them exactly once.
            parts = evidence_id.split("_")
            if len(parts) < 2:
                continue
            evidence_dataset, evidence_id_number = parts[0], parts[1]

            if evidence_dataset == "train":
                df = similarity.train_csv
            elif evidence_dataset == "test":
                df = similarity.test_csv
            else:
                continue

            matches = df[df["id"] == int(evidence_id_number)]
            if matches.empty:
                # Previously a dangling id raised IndexError inside the try
                # and silently dropped ALL remaining results; skip it instead.
                st.warning(f"No row found for evidence id {evidence_id}; skipping.")
                continue
            row = matches.iloc[0]
            evidence_text = row.get("evidence_enriched")
            evidence_image_caption = row.get("evidence_image_caption")
            evidence_image_path = row.get("evidence_image")
            evidence_image = None
            full_image_path = None

            # Load the evidence image lazily; a broken image is reported but
            # does not invalidate the textual evidence.
            if pd.notna(evidence_image_path):
                full_image_path = os.path.join(PROJECT_ROOT, evidence_image_path)
                try:
                    evidence_image = Image.open(full_image_path).convert("RGB")
                except Exception as e:
                    st.error(f"Failed to load image {evidence_image_path}: {e}")

            evidences.append(
                Evidence(
                    text=evidence_text,
                    image=evidence_image,
                    caption=evidence_image_caption,
                    evidence_id=evidence_id_number,
                    dataset=evidence_dataset,
                    image_path=full_image_path,
                )
            )
    except Exception as e:
        st.error(f"Error performing semantic search: {e}")

    return evidences
| |
|
| |
|
@st.cache_resource
def get_image_corpus(image_features):
    """Build once and share the image-retrieval corpus.

    Uses ``st.cache_resource`` (as ``get_predictor`` does) because the corpus
    holds loaded feature data and should be shared as a single instance;
    ``st.cache_data`` would pickle-copy the returned object on every access.

    Args:
        image_features: Path to the pickled evidence feature file.

    Returns:
        The shared ``ImageCorpus`` instance.
    """
    return ImageCorpus(image_features)
| |
|
| |
|
def retrieve_evidences_by_image(
    image_path: str,
    top_k: int = 5,
) -> List[Evidence]:
    """
    Retrieves evidence rows whose stored images are visually similar to the query image.

    Args:
        image_path (str): Path (or file-like object) of the query image.
        top_k (int): Number of top results to retrieve.

    Returns:
        List[Evidence]: A list of retrieved evidence objects.
    """
    image_features = os.path.join(PROJECT_ROOT, "evidence_features.pkl")
    image_corpus = get_image_corpus(image_features)
    evidences: List[Evidence] = []
    try:
        results = image_corpus.retrieve_similar_images(image_path, top_k=top_k)

        for evidence_path, _score in results:
            # Result paths look like ".../<dataset>/<id>_<suffix>": the row id
            # prefixes the file name, the dataset is the parent directory.
            file_name = evidence_path.split("/")[-1]
            evidence_id_number = file_name.split("_")[0]

            if "train" in evidence_path:
                df = get_train_df()
            elif "test" in evidence_path:
                df = get_test_df()
            else:
                continue

            matches = df[df["id"] == int(evidence_id_number)]
            if matches.empty:
                # Previously a dangling id raised IndexError inside the try
                # and silently dropped ALL remaining results; skip it instead.
                st.warning(
                    f"No row found for evidence image {evidence_path}; skipping."
                )
                continue
            row = matches.iloc[0]
            evidence_text = row.get("evidence_enriched")
            evidence_image_caption = row.get("evidence_image_caption")
            evidence_image_path = row.get("evidence_image")
            evidence_image = None
            full_image_path = None

            # Load the evidence image; a broken file is reported but does not
            # invalidate the rest of the evidence row.
            if pd.notna(evidence_image_path):
                full_image_path = os.path.join(PROJECT_ROOT, evidence_image_path)
                try:
                    evidence_image = Image.open(full_image_path).convert("RGB")
                except Exception as e:
                    st.error(f"Failed to load image {evidence_image_path}: {e}")

            evidences.append(
                Evidence(
                    text=evidence_text,
                    image=evidence_image,
                    caption=evidence_image_caption,
                    dataset=evidence_path.split("/")[-2],
                    evidence_id=evidence_id_number,
                    image_path=full_image_path,
                )
            )
    except Exception as e:
        st.error(f"Error performing semantic search: {e}")

    return evidences
| |
|
| |
|
@st.cache_resource
def get_predictor():
    # Cached as a resource so the CPU checkpoint is loaded once per server
    # process and the same predictor instance is shared across reruns.
    return MisinformationPredictor(model_path="ckpts/model.pt", device="cpu")
| |
|
| |
|
def classify_evidence(
    claim_text: str, claim_image_path: str, evidence_text: str, evidence_image_path: str
) -> Tuple[str, str, str, str]:
    """Classify a claim/evidence pair with the trained misinformation predictor.

    (The old docstring claimed a "random classification" — leftover from a
    stub; the function actually queries the cached model.)

    Args:
        claim_text: Enriched claim text.
        claim_image_path: Claim image (path or uploaded file, passed through).
        evidence_text: Retrieved evidence text.
        evidence_image_path: Path to the retrieved evidence image, if any.

    Returns:
        The four modality-pair verdicts, in the fixed order
        (text|text, text|image, image|text, image|image). Any missing
        prediction — or a falsy prediction dict — degrades to
        "not_enough_information".
    """
    predictor = get_predictor()
    predictions = predictor.evaluate(
        claim_text, claim_image_path, evidence_text, evidence_image_path
    )
    pair_keys = ("text_text", "text_image", "image_text", "image_image")
    safe_predictions = predictions or {}
    a, b, c, d = (
        safe_predictions.get(key, "not_enough_information") for key in pair_keys
    )
    return a, b, c, d
| |
|
| |
|
def display_evidence_tab(evidences: List[Evidence], tab_label: str):
    """Render each evidence (metadata, image, caption, text, verdicts) in order."""
    with st.container():
        for idx, item in enumerate(evidences):
            with st.container():
                st.subheader(f"Evidence {idx + 1}")
                st.write(f"Evidence Dataset: {item.dataset}")
                st.write(f"Evidence ID: {item.evidence_id}")
                if item.image:
                    st.image(
                        item.image,
                        caption="Evidence Image",
                        use_container_width=True,
                    )
                # Unique widget keys per tab/position keep Streamlit state stable.
                st.text_area(
                    "Evidence Caption",
                    value=item.caption or "No caption available.",
                    height=100,
                    key=f"caption_{tab_label}_{idx}",
                    disabled=True,
                )
                st.text_area(
                    "Evidence Text",
                    value=item.text or "No text available.",
                    height=100,
                    key=f"text_{tab_label}_{idx}",
                    disabled=True,
                )
                if item.classification_result_all:
                    st.write("**Classification:**")
                    pair_labels = ("text|text", "text|image", "image|text", "image|image")
                    for label, verdict in zip(pair_labels, item.classification_result_all):
                        st.write(f"**{label}:** {verdict}")
                    st.write(
                        f"**Final classification result:** {item.classification_result_final}"
                    )
| |
|
| |
|
def get_final_classification(results: Tuple[str, str, str, str]) -> str:
    """Fuse the four modality-pair verdicts into a single final stance.

    Pairs are consulted in priority order: (text|text, image|image) first,
    then (text|image, image|text). Within a pair, a decisive verdict
    ("support"/"refute") wins when the two verdicts agree or the other one is
    "not_enough_information"; conflicting decisive verdicts cancel out.
    """
    text_text, text_image, image_text, image_image = results

    def resolve(first: str, second: str) -> str:
        verdicts = {first, second}
        decisive = verdicts & {"support", "refute"}
        # Exactly one decisive label, accompanied at most by NEI.
        if len(decisive) == 1 and verdicts <= decisive | {"not_enough_information"}:
            return next(iter(decisive))
        return "not_enough_information"

    for pair in ((text_text, image_image), (text_image, image_text)):
        verdict = resolve(*pair)
        if verdict != "not_enough_information":
            return verdict
    return "not_enough_information"
| |
|
| |
|
def main():
    """Streamlit page: collect a claim (image and/or text), retrieve evidence,
    and classify the claim against each retrieved evidence."""
    st.title("Multimodal Evidence-Based Misinformation Classification")
    st.write("Upload claims that have image and/or text content to verify.")

    # Claim image (optional).
    uploaded_image = st.file_uploader(
        "Upload an image (1 max)", type=["jpg", "jpeg", "png"], key="image_uploader"
    )

    if uploaded_image:
        try:
            image = Image.open(uploaded_image).convert("RGB")
            st.image(image, caption="Uploaded Image", use_container_width=True)
        except Exception as e:
            # NOTE(review): on failure `image` stays unbound while
            # `uploaded_image` is still truthy, so Step 1 below would hit a
            # NameError — confirm whether this path can occur and guard it.
            st.error(f"Failed to display the image: {e}")

    # Claim text (optional).
    input_text = st.text_area("Enter text (max 4096 characters)", "", max_chars=4096)

    # Retrieval depth controls, one per modality.
    col1, col2 = st.columns(2)
    with col1:
        top_k_text = st.slider(
            "Top-k Text Evidences", min_value=1, max_value=5, value=2, key="top_k_text"
        )
    with col2:
        top_k_image = st.slider(
            "Top-k Image Evidences",
            min_value=1,
            max_value=5,
            value=2,
            key="top_k_image",
        )

    if st.button("Verify Claim"):
        # At least one modality is required to verify anything.
        if not uploaded_image and not input_text:
            st.warning("Please upload an image or enter text.")
            return

        progress = st.progress(0)

        # Step 1: caption the claim image with BLIP (caption stays "" when
        # no image was uploaded).
        progress.progress(10)
        st.write("### Step 1: Generating caption...")
        image_caption = ""
        if uploaded_image:
            image_caption = generate_caption(image)
            st.write("**Generated Image Caption:**", image_caption)

        # Step 2: fold the caption into the claim text so text retrieval also
        # reflects the image content.
        progress.progress(40)
        st.write("### Step 2: Enriching text...")
        enriched_text = enrich_text_with_caption(input_text, image_caption)
        st.write("**Enriched Text:**")
        st.write(enriched_text)

        # Step 3: text-based retrieval (skipped when the claim has no text).
        progress.progress(50)
        st.write("### Step 3: Retrieving evidences by text...")
        if input_text:
            text_evidences = retrieve_evidences_by_text(enriched_text, top_k=top_k_text)
            st.write(f"Retrieved {len(text_evidences)} text evidences.")
        else:
            text_evidences = None
            st.write("Text modality is missing from the input claim!")

        # Step 4: image-based retrieval (skipped when the claim has no image).
        progress.progress(70)
        st.write("### Step 4: Retrieving evidences by image...")
        if uploaded_image:
            # NOTE(review): `uploaded_image` is a Streamlit UploadedFile, not
            # a filesystem path — confirm retrieve_similar_images accepts
            # file-like objects.
            image_evidences = retrieve_evidences_by_image(
                uploaded_image, top_k=top_k_image
            )
            st.write(f"Retrieved {len(image_evidences)} image evidences.")
        else:
            image_evidences = None
            st.write("Image modality is missing from the input claim!")

        # Step 5: classify the claim against every retrieved evidence and fuse
        # the four modality-pair verdicts into a final one.
        progress.progress(90)
        st.write("### Step 5: Verifying claim with retrieved evidences...")
        for evidence in (text_evidences or []) + (image_evidences or []):
            a, b, c, d = classify_evidence(
                claim_text=enriched_text,
                claim_image_path=uploaded_image,
                evidence_text=evidence.text,
                evidence_image_path=evidence.image_path,
            )
            evidence.classification_result_all = a, b, c, d
            evidence.classification_result_final = get_final_classification(
                evidence.classification_result_all
            )

        # Render results, one tab per retrieval modality.
        progress.progress(100)
        if text_evidences or image_evidences:
            st.write("## Results")
            tabs = st.tabs(["Text Evidences", "Image Evidences"])

            with tabs[0]:
                if text_evidences:
                    st.write("### Text Evidences")
                    display_evidence_tab(text_evidences, "text")
                else:
                    st.write("Text modality is missing from the input claim!")

            with tabs[1]:
                if image_evidences:
                    st.write("### Image Evidences")
                    display_evidence_tab(image_evidences, "image")
                else:
                    st.write("Image modality is missing from the input claim!")
|
| |
|
# Script entry point: `streamlit run <this file>` executes the module as
# __main__, so the guard launches the app UI.
if __name__ == "__main__":
    main()
| |
|