Spaces:
Sleeping
Sleeping
import json
import os
import sys
from pathlib import Path

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Ensure repo root is on path when running directly (not via `python -m`),
# so the `shared` and `legaldoc_summarizer` packages below resolve.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

# These imports must come after the sys.path fix-up, hence the E402 suppressions.
from shared.metrics import compute_rouge, compute_bleu, factuality_score  # noqa: E402
from shared.utils import print_banner, load_yaml_config  # noqa: E402
from legaldoc_summarizer.dataset_loader import load_legal_dataset  # noqa: E402
| def _resolve_model_id(cfg): | |
| finetuned = cfg.get("finetuned_model") or os.getenv("LEGALDOC_MODEL_ID") | |
| local_dir = Path(cfg.get("finetuned_local_dir", "models/legaldoc_summarizer")) | |
| if finetuned: | |
| return finetuned | |
| if local_dir.exists(): | |
| return str(local_dir) | |
| return cfg["base_model"] | |
| def _build_hf_kwargs(token: str | None) -> dict: | |
| if not token: | |
| return {} | |
| return {"token": token} | |
| def _fallback_hf_kwargs(token: str | None) -> dict: | |
| if not token: | |
| return {} | |
| return {"use_auth_token": token} | |
def evaluate_model():
    """Evaluate the LegalDoc summarizer on the sample dataset.

    Loads the configured model (fine-tuned if available, otherwise the
    base model), generates summaries for up to 100 rows of the sample
    JSONL dataset, scores them with ROUGE / BLEU / factuality, and
    writes the metrics to models/legaldoc_summarizer/eval_results.json.
    """
    print_banner("Evaluating LegalDoc Summarizer")
    cfg = load_yaml_config(Path(__file__).resolve().parent / "config.yaml")
    model_id = _resolve_model_id(cfg)

    # Newer transformers accepts `token=`; older releases only accept
    # `use_auth_token=` and raise TypeError on the new kwarg, hence the fallback.
    auth_token = os.getenv("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HF_TOKEN")
    kwargs = _build_hf_kwargs(auth_token)
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_id, **kwargs)
    except TypeError:
        fallback_kwargs = _fallback_hf_kwargs(auth_token)
        tokenizer = AutoTokenizer.from_pretrained(model_id, **fallback_kwargs)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_id, **fallback_kwargs)
    # Fix: put the model in eval mode so dropout/regularization layers are
    # disabled — otherwise generation (and thus the metrics) is nondeterministic.
    model.eval()

    dataset_path = ROOT / "datasets/legal_sample.jsonl"
    if not dataset_path.exists():
        print_banner("Dataset not found. Creating sample dataset...")
        load_legal_dataset()
    dataset = load_dataset("json", data_files=str(dataset_path), split="train[:100]")

    preds, refs = [], []
    for row in dataset:
        # assumes each row has "question" (input text) and "answer" (reference
        # summary) — matches what load_legal_dataset emits; confirm if schema changes
        inputs = tokenizer(row["question"], return_tensors="pt", truncation=True)
        output = model.generate(**inputs, max_new_tokens=256)
        preds.append(tokenizer.decode(output[0], skip_special_tokens=True))
        refs.append(row["answer"])

    results = {}
    results.update(compute_rouge(preds, refs))
    results.update(compute_bleu(preds, refs))
    results.update(factuality_score(preds, refs))

    results_dir = ROOT / "models/legaldoc_summarizer"
    results_dir.mkdir(parents=True, exist_ok=True)
    # Fix: pin the encoding so the results file is identical across platforms
    # (default encoding on Windows is not UTF-8).
    with open(results_dir / "eval_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    print("β Evaluation complete:", results)


if __name__ == "__main__":
    evaluate_model()