# Source: Hugging Face repository file (commit c8c411e, author deepsodha-T,
# commit message: "Fixed cutom model in legel doc").
import json
import os
import sys
from pathlib import Path
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Ensure repo root is on path when running directly
# (script lives one level below the repo root, so parents[1] is the root).
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
# These imports must come AFTER the sys.path mutation above, hence the
# noqa: E402 suppressions for "import not at top of file".
from shared.metrics import compute_rouge, compute_bleu, factuality_score  # noqa: E402
from shared.utils import print_banner, load_yaml_config  # noqa: E402
from legaldoc_summarizer.dataset_loader import load_legal_dataset  # noqa: E402
def _resolve_model_id(cfg):
finetuned = cfg.get("finetuned_model") or os.getenv("LEGALDOC_MODEL_ID")
local_dir = Path(cfg.get("finetuned_local_dir", "models/legaldoc_summarizer"))
if finetuned:
return finetuned
if local_dir.exists():
return str(local_dir)
return cfg["base_model"]
def _build_hf_kwargs(token: str | None) -> dict:
if not token:
return {}
return {"token": token}
def _fallback_hf_kwargs(token: str | None) -> dict:
if not token:
return {}
return {"use_auth_token": token}
def evaluate_model():
    """Evaluate the LegalDoc summarizer on the sample legal dataset.

    Resolves the model id (fine-tuned > local checkpoint > base model),
    loads tokenizer and model — retrying with the legacy ``use_auth_token``
    keyword when the installed transformers version rejects ``token`` —
    generates answers for up to 100 rows of ``datasets/legal_sample.jsonl``,
    scores them with ROUGE, BLEU and a factuality metric, and writes the
    combined metrics to ``models/legaldoc_summarizer/eval_results.json``.
    """
    print_banner("Evaluating LegalDoc Summarizer")
    cfg = load_yaml_config(Path(__file__).resolve().parent / "config.yaml")
    model_id = _resolve_model_id(cfg)
    auth_token = os.getenv("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HF_TOKEN")

    kwargs = _build_hf_kwargs(auth_token)
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_id, **kwargs)
    except TypeError:
        # Older transformers versions do not accept `token`; retry with the
        # deprecated `use_auth_token` keyword instead.
        fallback_kwargs = _fallback_hf_kwargs(auth_token)
        tokenizer = AutoTokenizer.from_pretrained(model_id, **fallback_kwargs)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_id, **fallback_kwargs)

    dataset_path = ROOT / "datasets/legal_sample.jsonl"
    if not dataset_path.exists():
        # Best effort: the loader is expected to materialize the sample file.
        # TODO(review): confirm load_legal_dataset() writes to dataset_path.
        print_banner("Dataset not found. Creating sample dataset...")
        load_legal_dataset()
    dataset = load_dataset("json", data_files=str(dataset_path), split="train[:100]")

    preds, refs = [], []
    for row in dataset:
        # Rows are assumed to carry "question"/"answer" fields — matches the
        # sample-dataset loader above.
        inputs = tokenizer(row["question"], return_tensors="pt", truncation=True)
        output = model.generate(**inputs, max_new_tokens=256)
        preds.append(tokenizer.decode(output[0], skip_special_tokens=True))
        refs.append(row["answer"])

    # Merge all metric dicts into a single flat results mapping.
    results = {}
    results.update(compute_rouge(preds, refs))
    results.update(compute_bleu(preds, refs))
    results.update(factuality_score(preds, refs))

    results_dir = ROOT / "models/legaldoc_summarizer"
    results_dir.mkdir(parents=True, exist_ok=True)
    # Explicit UTF-8: without it the output encoding is platform-dependent.
    with open(results_dir / "eval_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    print("✅ Evaluation complete:", results)
# Script entry point: run the evaluation only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    evaluate_model()