deepsodha-T committed
Commit c8c411e · 1 parent: 5a77c4d
Fixed custom model in legal doc
Files changed:
- datasets/legal_sample.jsonl +3 -0
- datasets/retail_sample.jsonl +0 -0
- legaldoc_summarizer/__pycache__/dataset_loader.cpython-311.pyc +0 -0
- legaldoc_summarizer/__pycache__/evaluate.cpython-311.pyc +0 -0
- legaldoc_summarizer/app.py +21 -4
- legaldoc_summarizer/config.yaml +2 -0
- legaldoc_summarizer/dataset_loader.py +28 -8
- legaldoc_summarizer/evaluate.py +57 -7
- models/legaldoc_summarizer/eval_results.json +82 -0
- requirements.txt +6 -0
- retailgpt_evaluator/__pycache__/dataset_loader.cpython-311.pyc +0 -0
- retailgpt_evaluator/__pycache__/evaluate.cpython-311.pyc +0 -0
- retailgpt_evaluator/app.py +24 -4
- retailgpt_evaluator/evaluate.py +19 -5
- shared/__init__.py +1 -0
- shared/__pycache__/__init__.cpython-311.pyc +0 -0
- shared/__pycache__/metrics.cpython-311.pyc +0 -0
- shared/__pycache__/utils.cpython-311.pyc +0 -0
- shared/hf_helpers.py +45 -5
- shared/metrics.py +4 -2
datasets/legal_sample.jsonl
ADDED
{"question":"Summarize the confidentiality clause: The parties agree to keep all proprietary information confidential for five years.","answer":"Both parties must keep proprietary info secret for five years."}
{"question":"Summarize the termination clause: Either party may terminate with 30 days written notice without cause.","answer":"Either side can end the agreement with 30 days written notice."}
{"question":"Summarize the liability clause: Liability is limited to direct damages not exceeding fees paid in the last 12 months.","answer":"Each party's liability is capped to direct damages up to fees from the past year."}
datasets/retail_sample.jsonl
ADDED
The diff for this file is too large to render. See raw diff.
legaldoc_summarizer/__pycache__/dataset_loader.cpython-311.pyc
ADDED
Binary file (2.79 kB).
legaldoc_summarizer/__pycache__/evaluate.cpython-311.pyc
ADDED
Binary file (3.43 kB).
legaldoc_summarizer/app.py
CHANGED
@@ -1,3 +1,5 @@
+import os
+from pathlib import Path
 import streamlit as st
 from shared.hf_helpers import build_pipeline
 import yaml
@@ -10,13 +12,29 @@ def main():
     with open(CONFIG_PATH) as f:
         cfg = yaml.safe_load(f)
 
-
+    base_model = cfg["base_model"]
+    finetuned_model = cfg.get("finetuned_model") or os.getenv("LEGALDOC_MODEL_ID")
+    local_model_dir = Path(cfg.get("finetuned_local_dir", "models/legaldoc_summarizer"))
+
+    model_options = [base_model]
+    if finetuned_model:
+        model_options.append(finetuned_model)
+    elif local_model_dir.exists():
+        model_options.append(str(local_model_dir))
+    else:
+        st.info(
+            "Using the base model until a fine-tuned checkpoint is available. "
+            "Train a model to populate `models/legaldoc_summarizer` or set `LEGALDOC_MODEL_ID` / `finetuned_model`."
+        )
+
+    model_name = st.selectbox("Model:", model_options)
+    hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HF_TOKEN")
 
     @st.cache_resource
-    def get_pipeline(model_name):
-        return build_pipeline(model_name)
+    def get_pipeline(model_name, token):
+        return build_pipeline(model_name, token=token)
 
-    pipe = get_pipeline(model_name)
+    pipe = get_pipeline(model_name, hf_token)
 
     st.write("Paste a contract clause or judgment text below:")
     text = st.text_area("Clause or Legal Text", height=250)
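The selector above resolves the model in a fixed order: explicit Hub repo ID, then a local checkpoint directory, then the base model. A minimal sketch of exercising the same resolution outside Streamlit (running from the repo root is assumed; paths and env var names are taken from the diff):

import os
from pathlib import Path

import yaml

with open("legaldoc_summarizer/config.yaml") as f:
    cfg = yaml.safe_load(f)

# Same precedence as app.py: Hub repo ID > local checkpoint > base model.
finetuned = cfg.get("finetuned_model") or os.getenv("LEGALDOC_MODEL_ID")
local_dir = Path(cfg.get("finetuned_local_dir", "models/legaldoc_summarizer"))
model_name = finetuned or (str(local_dir) if local_dir.exists() else cfg["base_model"])
print("Would load:", model_name)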
legaldoc_summarizer/config.yaml
CHANGED
@@ -1,5 +1,7 @@
 project: "LegalDocSummarizer"
 base_model: "google/flan-t5-base"
+finetuned_model: ""  # Optional: HF repo ID for a private/public fine-tuned model
+finetuned_local_dir: "models/legaldoc_summarizer"
 dataset_name: "cuad"  # Contract Understanding Atticus Dataset
 train:
   epochs: 3
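For example, once a checkpoint exists, either new field can point the apps at it; the repo ID below is a hypothetical placeholder:

finetuned_model: "your-username/legaldoc-summarizer-flan-t5"  # hypothetical Hub repo ID
finetuned_local_dir: "models/legaldoc_summarizer"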
legaldoc_summarizer/dataset_loader.py
CHANGED
@@ -1,18 +1,38 @@
+import os
+import pandas as pd
 from datasets import load_dataset
-
+
 
 def load_legal_dataset():
     """
     Loads a small portion of the CUAD dataset (contract clauses).
-
+    Falls back to a tiny synthetic sample if the dataset is unavailable (e.g., offline).
     """
-
-
-
-
-
+    try:
+        dataset = load_dataset("cuad", "cuad_v1", split="train[:200]")
+        df = pd.DataFrame(dataset)
+        df["question_text"] = "Summarize the key legal clause: " + df["question_text"]
+        df["answer"] = df["answers"].apply(lambda a: a[0]["text"][0] if a and a[0]["text"] else "")
+        data = df[["question_text", "answer"]].rename(columns={"question_text": "question"})
+    except Exception as exc:  # pragma: no cover - offline/sandbox fallback
+        print(f"⚠️ Unable to load CUAD from Hub ({exc}). Using synthetic sample.")
+        data = pd.DataFrame(
+            [
+                {
+                    "question": "Summarize the confidentiality clause: The parties agree to keep all proprietary information confidential for five years.",
+                    "answer": "Both parties must keep proprietary info secret for five years.",
+                },
+                {
+                    "question": "Summarize the termination clause: Either party may terminate with 30 days written notice without cause.",
+                    "answer": "Either side can end the agreement with 30 days written notice.",
+                },
+                {
+                    "question": "Summarize the liability clause: Liability is limited to direct damages not exceeding fees paid in the last 12 months.",
+                    "answer": "Each party's liability is capped to direct damages up to fees from the past year.",
+                },
+            ]
+        )
 
-    data = df[["question_text", "answer"]].rename(columns={"question_text": "question"})
     os.makedirs("datasets", exist_ok=True)
     data.to_json("datasets/legal_sample.jsonl", orient="records", lines=True)
     print("✅ Saved sample dataset to datasets/legal_sample.jsonl")
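A quick way to regenerate and inspect the sample file (assumes the repo root as working directory):

import pandas as pd

from legaldoc_summarizer.dataset_loader import load_legal_dataset

load_legal_dataset()  # writes datasets/legal_sample.jsonl (CUAD slice or synthetic fallback)
sample = pd.read_json("datasets/legal_sample.jsonl", lines=True)
print(sample.head())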
legaldoc_summarizer/evaluate.py
CHANGED
@@ -1,16 +1,63 @@
 import json
+import os
+import sys
+from pathlib import Path
 from datasets import load_dataset
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
-from shared.metrics import compute_rouge, compute_bleu, factuality_score
-from shared.utils import print_banner
 
-def evaluate_model(model_path="models/legaldoc_summarizer"):
+# Ensure repo root is on path when running directly
+ROOT = Path(__file__).resolve().parents[1]
+if str(ROOT) not in sys.path:
+    sys.path.insert(0, str(ROOT))
+
+from shared.metrics import compute_rouge, compute_bleu, factuality_score  # noqa: E402
+from shared.utils import print_banner, load_yaml_config  # noqa: E402
+from legaldoc_summarizer.dataset_loader import load_legal_dataset  # noqa: E402
+
+
+def _resolve_model_id(cfg):
+    finetuned = cfg.get("finetuned_model") or os.getenv("LEGALDOC_MODEL_ID")
+    local_dir = Path(cfg.get("finetuned_local_dir", "models/legaldoc_summarizer"))
+    if finetuned:
+        return finetuned
+    if local_dir.exists():
+        return str(local_dir)
+    return cfg["base_model"]
+
+
+def _build_hf_kwargs(token: str | None) -> dict:
+    if not token:
+        return {}
+    return {"token": token}
+
+
+def _fallback_hf_kwargs(token: str | None) -> dict:
+    if not token:
+        return {}
+    return {"use_auth_token": token}
+
+
+def evaluate_model():
     print_banner("Evaluating LegalDoc Summarizer")
 
-
-
+    cfg = load_yaml_config(Path(__file__).resolve().parent / "config.yaml")
+    model_id = _resolve_model_id(cfg)
+    auth_token = os.getenv("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HF_TOKEN")
 
-
+    kwargs = _build_hf_kwargs(auth_token)
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs)
+        model = AutoModelForSeq2SeqLM.from_pretrained(model_id, **kwargs)
+    except TypeError:
+        fallback_kwargs = _fallback_hf_kwargs(auth_token)
+        tokenizer = AutoTokenizer.from_pretrained(model_id, **fallback_kwargs)
+        model = AutoModelForSeq2SeqLM.from_pretrained(model_id, **fallback_kwargs)
+
+    dataset_path = ROOT / "datasets/legal_sample.jsonl"
+    if not dataset_path.exists():
+        print_banner("Dataset not found. Creating sample dataset...")
+        load_legal_dataset()
+    dataset = load_dataset("json", data_files=str(dataset_path), split="train[:100]")
 
     preds, refs = [], []
     for row in dataset:
@@ -24,9 +71,12 @@ def evaluate_model(model_path="models/legaldoc_summarizer"):
     results.update(compute_bleu(preds, refs))
     results.update(factuality_score(preds, refs))
 
-
+    results_dir = ROOT / "models/legaldoc_summarizer"
+    results_dir.mkdir(parents=True, exist_ok=True)
+    with open(results_dir / "eval_results.json", "w") as f:
         json.dump(results, f, indent=2)
     print("✅ Evaluation complete:", results)
 
+
 if __name__ == "__main__":
     evaluate_model()
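A direct run only needs the repo root as working directory; the token env vars are optional and only matter for private checkpoints:

from legaldoc_summarizer.evaluate import evaluate_model

evaluate_model()  # writes models/legaldoc_summarizer/eval_results.json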
models/legaldoc_summarizer/eval_results.json
ADDED
{
  "rouge1": [
    [0.14285714285714285, 0.0909090909090909, 0.1111111111111111],
    [0.30952380952380953, 0.31363636363636366, 0.3077441077441077],
    [0.5, 0.6, 0.5454545454545454]
  ],
  "rouge2": [
    [0.0, 0.0, 0.0],
    [0.06060606060606061, 0.07407407407407407, 0.06666666666666667],
    [0.1818181818181818, 0.2222222222222222, 0.19999999999999998]
  ],
  "rougeL": [
    [0.14285714285714285, 0.0909090909090909, 0.1111111111111111],
    [0.2857142857142857, 0.2928030303030303, 0.2855218855218855],
    [0.5, 0.6, 0.5454545454545454]
  ],
  "rougeLsum": [
    [0.14285714285714285, 0.0909090909090909, 0.1111111111111111],
    [0.2857142857142857, 0.2928030303030303, 0.2855218855218855],
    [0.5, 0.6, 0.5454545454545454]
  ],
  "bleu": 0.0,
  "precisions": [0.3333333333333333, 0.06666666666666667, 0.037037037037037035, 0.0],
  "brevity_penalty": 0.9131007162822622,
  "length_ratio": 0.9166666666666666,
  "translation_length": 33,
  "reference_length": 36,
  "factuality": 0.3255411255411255
}
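Reading the file back, each ROUGE entry appears to be a per-example triple; treating it as (precision, recall, F1) is an assumption about how compute_rouge was configured:

import json

with open("models/legaldoc_summarizer/eval_results.json") as f:
    results = json.load(f)

# Assumption: each rougeL entry is (precision, recall, F1) for one example.
f1 = [triple[2] for triple in results["rougeL"]]
print(f"mean ROUGE-L F1: {sum(f1) / len(f1):.3f}")
print(f"factuality: {results['factuality']:.3f}")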
requirements.txt
CHANGED
@@ -8,3 +8,9 @@ gradio==4.44.0
 tokenizers==0.19.1
 # Explicitly pin pydantic major to avoid breaking gradio deps
 pydantic==2.7.4
+# Evaluation and datasets
+datasets==2.21.0
+evaluate==0.4.2
+rouge-score==0.1.2
+nltk==3.9.1
+absl-py==2.1.0
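To reproduce the evaluation environment (Python 3.11 is assumed, matching the committed bytecode):

pip install -r requirements.txt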
retailgpt_evaluator/__pycache__/dataset_loader.cpython-311.pyc
ADDED
Binary file (1.34 kB).
retailgpt_evaluator/__pycache__/evaluate.cpython-311.pyc
ADDED
Binary file (3.39 kB).
retailgpt_evaluator/app.py
CHANGED
@@ -19,10 +19,30 @@ def main():
         cfg = yaml.safe_load(f)
 
     # Show leaderboard if exists
-
-
-
-
+    leaderboard_df = None
+    results_path = Path("models/retail_eval_results.json")
+    if results_path.exists():
+        try:
+            leaderboard_df = build_leaderboard(results_path)
+            st.subheader("🏆 Model Leaderboard")
+            st.dataframe(leaderboard_df, use_container_width=True)
+
+            st.markdown("#### 📊 Evaluation Metrics")
+            metric_options = leaderboard_df["model"].tolist()
+            selected = st.selectbox("Inspect metrics for:", metric_options)
+            selected_row = leaderboard_df[leaderboard_df["model"] == selected].iloc[0]
+
+            cols = st.columns(4)
+            cols[0].metric("ROUGE-L", f"{selected_row['rougeL']:.3f}")
+            cols[1].metric("BLEU", f"{selected_row['bleu']:.3f}")
+            cols[2].metric("Factuality", f"{selected_row['factuality']:.3f}")
+            cols[3].metric("Score (avg)", f"{selected_row['score']:.3f}")
+
+            st.bar_chart(
+                leaderboard_df.set_index("model")[["rougeL", "bleu", "factuality", "score"]]
+            )
+        except Exception as exc:  # pragma: no cover - defensive UI fallback
+            st.error(f"Unable to load evaluation results: {exc}")
     else:
         st.warning("Run `evaluate.py` first to generate metrics.")
 
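build_leaderboard itself is not part of this diff; a hypothetical sketch consistent with how the UI consumes it (the column set comes from the metric calls above, and computing "score" as the mean of the three metrics is an assumption):

import json
from pathlib import Path

import pandas as pd


def build_leaderboard(results_path: Path) -> pd.DataFrame:
    """Hypothetical helper: load per-model results and rank them."""
    rows = json.loads(Path(results_path).read_text())
    df = pd.DataFrame(rows)
    # Assumption: "score" is the average of the headline metrics.
    df["score"] = df[["rougeL", "bleu", "factuality"]].mean(axis=1)
    return df.sort_values("score", ascending=False).reset_index(drop=True)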
retailgpt_evaluator/evaluate.py
CHANGED
@@ -1,10 +1,19 @@
 import json
+import sys
+from pathlib import Path
 from datasets import load_dataset
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
-from shared.metrics import compute_rouge, compute_bleu, factuality_score
-from shared.utils import print_banner
 import torch
 
+# Ensure repo root is on the path so `shared` package is found when run directly
+ROOT = Path(__file__).resolve().parents[1]
+if str(ROOT) not in sys.path:
+    sys.path.insert(0, str(ROOT))
+
+from shared.metrics import compute_rouge, compute_bleu, factuality_score  # noqa: E402
+from shared.utils import print_banner, load_yaml_config  # noqa: E402
+from retailgpt_evaluator.dataset_loader import load_retail_dataset  # noqa: E402
+
 def run_eval_for_model(model_name, dataset):
     print_banner(f"Evaluating {model_name}")
     tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -22,9 +31,14 @@ def run_eval_for_model(model_name, dataset):
     return {"model": model_name, **r, **b, **f}
 
 def evaluate_all():
-
-    cfg = load_yaml_config(
-
+    config_path = Path(__file__).resolve().parent / "config.yaml"
+    cfg = load_yaml_config(config_path)
+    dataset_path = ROOT / "datasets/retail_sample.jsonl"
+    if not dataset_path.exists():
+        print_banner("Dataset not found. Creating sample dataset...")
+        load_retail_dataset()
+    dataset = load_dataset("json", data_files=str(dataset_path), split="train[:50]")
+    (ROOT / "models").mkdir(exist_ok=True)
     results = [run_eval_for_model(m, dataset) for m in cfg["models"]]
     json.dump(results, open("models/retail_eval_results.json", "w"), indent=2)
     print("✅ Saved results to models/retail_eval_results.json")
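As with the legal evaluator, a direct run from the repo root regenerates everything the app needs:

from retailgpt_evaluator.evaluate import evaluate_all

evaluate_all()  # writes models/retail_eval_results.json for the app's leaderboard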
shared/__init__.py
ADDED
# Makes shared a package so imports like `from shared import ...` work.
shared/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (166 Bytes).
shared/__pycache__/metrics.cpython-311.pyc
ADDED
Binary file (2.27 kB).
shared/__pycache__/utils.cpython-311.pyc
ADDED
Binary file (1.48 kB).
shared/hf_helpers.py
CHANGED
@@ -1,10 +1,42 @@
+import os
+from pathlib import Path
+from typing import Optional
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 import torch
 
-def load_model_and_tokenizer(model_name: str):
+def _resolve_model_identifier(model_name: str) -> str:
+    """Return a valid model identifier or local path."""
+    path_candidate = Path(model_name)
+    if path_candidate.exists():
+        return str(path_candidate)
+    return model_name
+
+def _build_hub_kwargs(token: Optional[str]) -> dict:
+    """Prepare kwargs for Hugging Face Hub auth across library versions."""
+    if not token:
+        return {}
+    return {"token": token}
+
+def _fallback_hub_kwargs(token: Optional[str]) -> dict:
+    """Older transformers versions still expect use_auth_token."""
+    if not token:
+        return {}
+    return {"use_auth_token": token}
+
+def load_model_and_tokenizer(model_name: str, token: Optional[str] = None):
     """Load a model and tokenizer for inference."""
-
-
+    resolved_model = _resolve_model_identifier(model_name)
+    auth_token = token or os.getenv("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HF_TOKEN")
+
+    kwargs = _build_hub_kwargs(auth_token)
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(resolved_model, **kwargs)
+        model = AutoModelForSeq2SeqLM.from_pretrained(resolved_model, **kwargs)
+    except TypeError:
+        fallback_kwargs = _fallback_hub_kwargs(auth_token)
+        tokenizer = AutoTokenizer.from_pretrained(resolved_model, **fallback_kwargs)
+        model = AutoModelForSeq2SeqLM.from_pretrained(resolved_model, **fallback_kwargs)
+
     return model, tokenizer
 
 def generate_answer(model, tokenizer, prompt: str, max_tokens: int = 256):
@@ -14,6 +46,14 @@ def generate_answer(model, tokenizer, prompt: str, max_tokens: int = 256):
     outputs = model.generate(**inputs, max_new_tokens=max_tokens)
     return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-def build_pipeline(model_name: str, task="text2text-generation"):
+def build_pipeline(model_name: str, task="text2text-generation", token: Optional[str] = None):
    """Return a Hugging Face pipeline for inference."""
+    resolved_model = _resolve_model_identifier(model_name)
+    auth_token = token or os.getenv("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HF_TOKEN")
+
+    kwargs = _build_hub_kwargs(auth_token)
+    try:
+        return pipeline(task, model=resolved_model, **kwargs)
+    except TypeError:
+        fallback_kwargs = _fallback_hub_kwargs(auth_token)
+        return pipeline(task, model=resolved_model, **fallback_kwargs)
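A usage sketch for the helpers (model ID and prompts are examples only; the token is picked up from the environment when omitted):

from shared.hf_helpers import build_pipeline, generate_answer, load_model_and_tokenizer

# Pipeline path: simplest way to run inference.
pipe = build_pipeline("google/flan-t5-base")
print(pipe("Summarize: The parties agree to keep all proprietary information confidential.")[0]["generated_text"])

# Lower-level path, useful when generate() kwargs are needed.
model, tokenizer = load_model_and_tokenizer("google/flan-t5-base")
print(generate_answer(model, tokenizer, "Summarize the termination clause.", max_tokens=64))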
shared/metrics.py
CHANGED
@@ -6,9 +6,11 @@ def compute_rouge(preds, refs):
     return rouge.compute(predictions=preds, references=refs)
 
 def compute_bleu(preds, refs):
+    """BLEU with simple whitespace tokenization for compatibility."""
     bleu = load_metric("bleu")
-
-
+    pred_tokens = [p.split() for p in preds]
+    ref_tokens = [[r.split()] for r in refs]  # bleu expects list of list of token lists
+    return bleu.compute(predictions=pred_tokens, references=ref_tokens)
 
 def factuality_score(preds, refs):
     """Very simple lexical overlap metric for factual alignment."""
|