import json
import os
import sys
from pathlib import Path

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Ensure repo root is on path when running directly
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from shared.metrics import compute_rouge, compute_bleu, factuality_score  # noqa: E402
from shared.utils import print_banner, load_yaml_config  # noqa: E402
from legaldoc_summarizer.dataset_loader import load_legal_dataset  # noqa: E402


def _resolve_model_id(cfg):
    """Pick a model ID: explicit fine-tuned ID first, then a local checkpoint, then the base model."""
    finetuned = cfg.get("finetuned_model") or os.getenv("LEGALDOC_MODEL_ID")
    local_dir = Path(cfg.get("finetuned_local_dir", "models/legaldoc_summarizer"))
    if finetuned:
        return finetuned
    if local_dir.exists():
        return str(local_dir)
    return cfg["base_model"]


def _build_hf_kwargs(token: str | None) -> dict:
    # Recent transformers releases authenticate via the `token=` kwarg.
    if not token:
        return {}
    return {"token": token}


def _fallback_hf_kwargs(token: str | None) -> dict:
    # Older transformers releases expect the deprecated `use_auth_token=` kwarg instead.
    if not token:
        return {}
    return {"use_auth_token": token}


def evaluate_model():
    print_banner("Evaluating LegalDoc Summarizer")
    cfg = load_yaml_config(Path(__file__).resolve().parent / "config.yaml")
    model_id = _resolve_model_id(cfg)
    auth_token = os.getenv("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HF_TOKEN")
    kwargs = _build_hf_kwargs(auth_token)

    # Load the tokenizer and model; if the installed transformers version
    # rejects `token=`, retry with the legacy auth kwarg.
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_id, **kwargs)
    except TypeError:
        fallback_kwargs = _fallback_hf_kwargs(auth_token)
        tokenizer = AutoTokenizer.from_pretrained(model_id, **fallback_kwargs)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_id, **fallback_kwargs)

    # Bootstrap the sample dataset if it has not been generated yet.
    dataset_path = ROOT / "datasets/legal_sample.jsonl"
    if not dataset_path.exists():
        print_banner("Dataset not found. Creating sample dataset...")
        load_legal_dataset()

    # Evaluate on the first 100 rows of the sample dataset.
    dataset = load_dataset("json", data_files=str(dataset_path), split="train[:100]")
    preds, refs = [], []
    for row in dataset:
        inputs = tokenizer(row["question"], return_tensors="pt", truncation=True)
        output = model.generate(**inputs, max_new_tokens=256)
        preds.append(tokenizer.decode(output[0], skip_special_tokens=True))
        refs.append(row["answer"])

    # Aggregate ROUGE, BLEU, and factuality metrics into a single report.
    results = {}
    results.update(compute_rouge(preds, refs))
    results.update(compute_bleu(preds, refs))
    results.update(factuality_score(preds, refs))

    results_dir = ROOT / "models/legaldoc_summarizer"
    results_dir.mkdir(parents=True, exist_ok=True)
    with open(results_dir / "eval_results.json", "w") as f:
        json.dump(results, f, indent=2)
    print("✅ Evaluation complete:", results)


if __name__ == "__main__":
    evaluate_model()