deepsodha-T committed on
Commit c8c411e · 1 Parent(s): 5a77c4d

Fixed custom model in legal doc

datasets/legal_sample.jsonl ADDED
@@ -0,0 +1,3 @@
+{"question":"Summarize the confidentiality clause: The parties agree to keep all proprietary information confidential for five years.","answer":"Both parties must keep proprietary info secret for five years."}
+{"question":"Summarize the termination clause: Either party may terminate with 30 days written notice without cause.","answer":"Either side can end the agreement with 30 days written notice."}
+{"question":"Summarize the liability clause: Liability is limited to direct damages not exceeding fees paid in the last 12 months.","answer":"Each party's liability is capped to direct damages up to fees from the past year."}
datasets/retail_sample.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
legaldoc_summarizer/__pycache__/dataset_loader.cpython-311.pyc ADDED
Binary file (2.79 kB). View file
 
legaldoc_summarizer/__pycache__/evaluate.cpython-311.pyc ADDED
Binary file (3.43 kB). View file
 
legaldoc_summarizer/app.py CHANGED
@@ -1,3 +1,4 @@
+import os
 import streamlit as st
 from shared.hf_helpers import build_pipeline
 import yaml
@@ -10,13 +11,29 @@ def main():
     with open(CONFIG_PATH) as f:
         cfg = yaml.safe_load(f)
 
-    model_name = st.selectbox("Model:", [cfg["base_model"], "models/legaldoc_summarizer"])
+    base_model = cfg["base_model"]
+    finetuned_model = cfg.get("finetuned_model") or os.getenv("LEGALDOC_MODEL_ID")
+    local_model_dir = Path(cfg.get("finetuned_local_dir", "models/legaldoc_summarizer"))
+
+    model_options = [base_model]
+    if finetuned_model:
+        model_options.append(finetuned_model)
+    elif local_model_dir.exists():
+        model_options.append(str(local_model_dir))
+    else:
+        st.info(
+            "Using the base model until a fine-tuned checkpoint is available. "
+            "Train a model to populate `models/legaldoc_summarizer` or set `LEGALDOC_MODEL_ID` / `finetuned_model`."
+        )
+
+    model_name = st.selectbox("Model:", model_options)
+    hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HF_TOKEN")
 
     @st.cache_resource
-    def get_pipeline(model_name):
-        return build_pipeline(model_name)
+    def get_pipeline(model_name, token):
+        return build_pipeline(model_name, token=token)
 
-    pipe = get_pipeline(model_name)
+    pipe = get_pipeline(model_name, hf_token)
 
     st.write("Paste a contract clause or judgment text below:")
     text = st.text_area("Clause or Legal Text", height=250)
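The app keeps the base model as the first option and only offers a fine-tuned checkpoint when one can be found; evaluate.py applies the same precedence automatically. A standalone sketch of that precedence, illustrative only and not part of the commit:

```python
import os
from pathlib import Path


def pick_model(cfg: dict) -> str:
    """Illustrative mirror of the checkpoint precedence used in this commit."""
    # 1. Explicit fine-tuned repo ID from config.yaml or the LEGALDOC_MODEL_ID env var
    finetuned = cfg.get("finetuned_model") or os.getenv("LEGALDOC_MODEL_ID")
    if finetuned:
        return finetuned
    # 2. A locally trained checkpoint directory, if it exists
    local_dir = Path(cfg.get("finetuned_local_dir", "models/legaldoc_summarizer"))
    if local_dir.exists():
        return str(local_dir)
    # 3. Fall back to the base model from config.yaml
    return cfg["base_model"]


print(pick_model({"base_model": "google/flan-t5-base", "finetuned_model": ""}))
# -> "google/flan-t5-base" unless LEGALDOC_MODEL_ID is set or the local dir exists
```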
legaldoc_summarizer/config.yaml CHANGED
@@ -1,5 +1,7 @@
 project: "LegalDocSummarizer"
 base_model: "google/flan-t5-base"
+finetuned_model: ""  # Optional: HF repo ID for a private/public fine-tuned model
+finetuned_local_dir: "models/legaldoc_summarizer"
 dataset_name: "cuad"  # Contract Understanding Atticus Dataset
 train:
   epochs: 3
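The two new keys are read with plain `yaml.safe_load`; an empty `finetuned_model` string is falsy, so the app and evaluator fall through to the local checkpoint directory or the base model. A minimal sketch of how the keys are consumed (run from the repo root; illustrative only):

```python
import yaml

# Load the summarizer config and inspect the new keys
with open("legaldoc_summarizer/config.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["base_model"])               # "google/flan-t5-base"
print(cfg.get("finetuned_model"))      # "" -> falsy, so the fallback logic kicks in
print(cfg.get("finetuned_local_dir"))  # "models/legaldoc_summarizer"
```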
legaldoc_summarizer/dataset_loader.py CHANGED
@@ -1,18 +1,38 @@
+import os
+import pandas as pd
 from datasets import load_dataset
-import pandas as pd, os
+
 
 def load_legal_dataset():
     """
     Loads a small portion of the CUAD dataset (contract clauses).
-    Converts each clause into (document_text, summary) pairs.
+    Falls back to a tiny synthetic sample if the dataset is unavailable (e.g., offline).
     """
-    dataset = load_dataset("cuad", "cuad_v1", split="train[:200]")
-    df = pd.DataFrame(dataset)
-
-    df["question_text"] = "Summarize the key legal clause: " + df["question_text"]
-    df["answer"] = df["answers"].apply(lambda a: a[0]["text"][0] if a and a[0]["text"] else "")
+    try:
+        dataset = load_dataset("cuad", "cuad_v1", split="train[:200]")
+        df = pd.DataFrame(dataset)
+        df["question_text"] = "Summarize the key legal clause: " + df["question_text"]
+        df["answer"] = df["answers"].apply(lambda a: a[0]["text"][0] if a and a[0]["text"] else "")
+        data = df[["question_text", "answer"]].rename(columns={"question_text": "question"})
+    except Exception as exc:  # pragma: no cover - offline/sandbox fallback
+        print(f"⚠️ Unable to load CUAD from Hub ({exc}). Using synthetic sample.")
+        data = pd.DataFrame(
+            [
+                {
+                    "question": "Summarize the confidentiality clause: The parties agree to keep all proprietary information confidential for five years.",
+                    "answer": "Both parties must keep proprietary info secret for five years.",
+                },
+                {
+                    "question": "Summarize the termination clause: Either party may terminate with 30 days written notice without cause.",
+                    "answer": "Either side can end the agreement with 30 days written notice.",
+                },
+                {
+                    "question": "Summarize the liability clause: Liability is limited to direct damages not exceeding fees paid in the last 12 months.",
+                    "answer": "Each party's liability is capped to direct damages up to fees from the past year.",
+                },
+            ]
+        )
 
-    data = df[["question_text", "answer"]].rename(columns={"question_text": "question"})
     os.makedirs("datasets", exist_ok=True)
     data.to_json("datasets/legal_sample.jsonl", orient="records", lines=True)
     print("✅ Saved sample dataset to datasets/legal_sample.jsonl")
legaldoc_summarizer/evaluate.py CHANGED
@@ -1,16 +1,63 @@
 import json
+import os
+import sys
+from pathlib import Path
 from datasets import load_dataset
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
-from shared.metrics import compute_rouge, compute_bleu, factuality_score
-from shared.utils import print_banner
 
-def evaluate_model(model_path="models/legaldoc_summarizer"):
+# Ensure repo root is on path when running directly
+ROOT = Path(__file__).resolve().parents[1]
+if str(ROOT) not in sys.path:
+    sys.path.insert(0, str(ROOT))
+
+from shared.metrics import compute_rouge, compute_bleu, factuality_score  # noqa: E402
+from shared.utils import print_banner, load_yaml_config  # noqa: E402
+from legaldoc_summarizer.dataset_loader import load_legal_dataset  # noqa: E402
+
+
+def _resolve_model_id(cfg):
+    finetuned = cfg.get("finetuned_model") or os.getenv("LEGALDOC_MODEL_ID")
+    local_dir = Path(cfg.get("finetuned_local_dir", "models/legaldoc_summarizer"))
+    if finetuned:
+        return finetuned
+    if local_dir.exists():
+        return str(local_dir)
+    return cfg["base_model"]
+
+
+def _build_hf_kwargs(token: str | None) -> dict:
+    if not token:
+        return {}
+    return {"token": token}
+
+
+def _fallback_hf_kwargs(token: str | None) -> dict:
+    if not token:
+        return {}
+    return {"use_auth_token": token}
+
+
+def evaluate_model():
     print_banner("Evaluating LegalDoc Summarizer")
 
-    tokenizer = AutoTokenizer.from_pretrained(model_path)
-    model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
+    cfg = load_yaml_config(Path(__file__).resolve().parent / "config.yaml")
+    model_id = _resolve_model_id(cfg)
+    auth_token = os.getenv("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HF_TOKEN")
 
-    dataset = load_dataset("json", data_files="datasets/legal_sample.jsonl", split="train[:100]")
+    kwargs = _build_hf_kwargs(auth_token)
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs)
+        model = AutoModelForSeq2SeqLM.from_pretrained(model_id, **kwargs)
+    except TypeError:
+        fallback_kwargs = _fallback_hf_kwargs(auth_token)
+        tokenizer = AutoTokenizer.from_pretrained(model_id, **fallback_kwargs)
+        model = AutoModelForSeq2SeqLM.from_pretrained(model_id, **fallback_kwargs)
+
+    dataset_path = ROOT / "datasets/legal_sample.jsonl"
+    if not dataset_path.exists():
+        print_banner("Dataset not found. Creating sample dataset...")
+        load_legal_dataset()
+    dataset = load_dataset("json", data_files=str(dataset_path), split="train[:100]")
 
     preds, refs = [], []
     for row in dataset:
@@ -24,9 +71,12 @@ def evaluate_model(model_path="models/legaldoc_summarizer"):
     results.update(compute_bleu(preds, refs))
     results.update(factuality_score(preds, refs))
 
-    with open("models/legaldoc_summarizer/eval_results.json", "w") as f:
+    results_dir = ROOT / "models/legaldoc_summarizer"
+    results_dir.mkdir(parents=True, exist_ok=True)
+    with open(results_dir / "eval_results.json", "w") as f:
        json.dump(results, f, indent=2)
     print("✅ Evaluation complete:", results)
 
+
 if __name__ == "__main__":
     evaluate_model()
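With these changes the evaluator is self-bootstrapping: it resolves the model from config or environment, creates the sample dataset if it is missing, and writes metrics under models/legaldoc_summarizer. A usage sketch, run from the repo root and not part of the commit:

```python
import json
from pathlib import Path

from legaldoc_summarizer.evaluate import evaluate_model

# Uses the fine-tuned model if configured via config.yaml or LEGALDOC_MODEL_ID,
# otherwise falls back to the local checkpoint dir or the base model.
evaluate_model()

results = json.loads(Path("models/legaldoc_summarizer/eval_results.json").read_text())
print("factuality:", results["factuality"])
```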
models/legaldoc_summarizer/eval_results.json ADDED
@@ -0,0 +1,82 @@
+{
+  "rouge1": [
+    [
+      0.14285714285714285,
+      0.0909090909090909,
+      0.1111111111111111
+    ],
+    [
+      0.30952380952380953,
+      0.31363636363636366,
+      0.3077441077441077
+    ],
+    [
+      0.5,
+      0.6,
+      0.5454545454545454
+    ]
+  ],
+  "rouge2": [
+    [
+      0.0,
+      0.0,
+      0.0
+    ],
+    [
+      0.06060606060606061,
+      0.07407407407407407,
+      0.06666666666666667
+    ],
+    [
+      0.1818181818181818,
+      0.2222222222222222,
+      0.19999999999999998
+    ]
+  ],
+  "rougeL": [
+    [
+      0.14285714285714285,
+      0.0909090909090909,
+      0.1111111111111111
+    ],
+    [
+      0.2857142857142857,
+      0.2928030303030303,
+      0.2855218855218855
+    ],
+    [
+      0.5,
+      0.6,
+      0.5454545454545454
+    ]
+  ],
+  "rougeLsum": [
+    [
+      0.14285714285714285,
+      0.0909090909090909,
+      0.1111111111111111
+    ],
+    [
+      0.2857142857142857,
+      0.2928030303030303,
+      0.2855218855218855
+    ],
+    [
+      0.5,
+      0.6,
+      0.5454545454545454
+    ]
+  ],
+  "bleu": 0.0,
+  "precisions": [
+    0.3333333333333333,
+    0.06666666666666667,
+    0.037037037037037035,
+    0.0
+  ],
+  "brevity_penalty": 0.9131007162822622,
+  "length_ratio": 0.9166666666666666,
+  "translation_length": 33,
+  "reference_length": 36,
+  "factuality": 0.3255411255411255
+}
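The nested ROUGE entries are rows of what appear to be (precision, recall, F1) triples from the rouge_score Score tuple; whether the three rows are per-example scores or a low/mid/high aggregate depends on how compute_rouge calls the metric, so the sketch below only reads the file and reports the unambiguous scalar fields. Paths follow the diff; this is illustrative, not part of the commit.

```python
import json
from pathlib import Path

results = json.loads(Path("models/legaldoc_summarizer/eval_results.json").read_text())

# Scalar metrics are unambiguous; the ROUGE entries are 3x3 nested lists whose rows
# appear to be (precision, recall, F1) triples.
print("BLEU:", results["bleu"])
print("factuality:", results["factuality"])
print("ROUGE-L rows:", results["rougeL"])
```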
requirements.txt CHANGED
@@ -8,3 +8,9 @@ gradio==4.44.0
 tokenizers==0.19.1
 # Explicitly pin pydantic major to avoid breaking gradio deps
 pydantic==2.7.4
+# Evaluation and datasets
+datasets==2.21.0
+evaluate==0.4.2
+rouge-score==0.1.2
+nltk==3.9.1
+absl-py==2.1.0
retailgpt_evaluator/__pycache__/dataset_loader.cpython-311.pyc ADDED
Binary file (1.34 kB). View file
 
retailgpt_evaluator/__pycache__/evaluate.cpython-311.pyc ADDED
Binary file (3.39 kB). View file
 
retailgpt_evaluator/app.py CHANGED
@@ -19,10 +19,30 @@ def main():
         cfg = yaml.safe_load(f)
 
     # Show leaderboard if exists
-    if os.path.exists("models/retail_eval_results.json"):
-        df = build_leaderboard()
-        st.subheader("📊 Model Leaderboard")
-        st.dataframe(df, use_container_width=True)
+    leaderboard_df = None
+    results_path = Path("models/retail_eval_results.json")
+    if results_path.exists():
+        try:
+            leaderboard_df = build_leaderboard(results_path)
+            st.subheader("📊 Model Leaderboard")
+            st.dataframe(leaderboard_df, use_container_width=True)
+
+            st.markdown("#### 📈 Evaluation Metrics")
+            metric_options = leaderboard_df["model"].tolist()
+            selected = st.selectbox("Inspect metrics for:", metric_options)
+            selected_row = leaderboard_df[leaderboard_df["model"] == selected].iloc[0]
+
+            cols = st.columns(4)
+            cols[0].metric("ROUGE-L", f"{selected_row['rougeL']:.3f}")
+            cols[1].metric("BLEU", f"{selected_row['bleu']:.3f}")
+            cols[2].metric("Factuality", f"{selected_row['factuality']:.3f}")
+            cols[3].metric("Score (avg)", f"{selected_row['score']:.3f}")
+
+            st.bar_chart(
+                leaderboard_df.set_index("model")[["rougeL", "bleu", "factuality", "score"]]
+            )
+        except Exception as exc:  # pragma: no cover - defensive UI fallback
+            st.error(f"Unable to load evaluation results: {exc}")
     else:
         st.warning("Run `evaluate.py` first to generate metrics.")
 
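The metrics panel assumes build_leaderboard returns a DataFrame with model, rougeL, bleu, factuality, and score columns. build_leaderboard itself is not part of this diff; the following is a hypothetical sketch of the shape it would need to return, where the column reduction and the averaged "score" are assumptions made to match the UI code above.

```python
import json
from pathlib import Path

import pandas as pd


def _to_scalar(value):
    # Assumption: nested ROUGE entries reduce to the middle row's F1; plain floats
    # pass through. Adjust to match compute_rouge's actual output shape.
    if isinstance(value, (list, tuple)):
        return value[1][2]
    return value


def build_leaderboard(results_path: Path) -> pd.DataFrame:
    """Hypothetical sketch of the DataFrame app.py expects; not part of the commit."""
    records = json.loads(Path(results_path).read_text())
    df = pd.DataFrame(records)
    for col in ("rougeL", "bleu", "factuality"):
        df[col] = df[col].apply(_to_scalar)
    # The averaged "score" column is an assumption to match the metric cards and bar chart.
    df["score"] = df[["rougeL", "bleu", "factuality"]].mean(axis=1)
    return df[["model", "rougeL", "bleu", "factuality", "score"]]
```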
retailgpt_evaluator/evaluate.py CHANGED
@@ -1,10 +1,19 @@
 import json
+import sys
+from pathlib import Path
 from datasets import load_dataset
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
-from shared.metrics import compute_rouge, compute_bleu, factuality_score
-from shared.utils import print_banner
 import torch
 
+# Ensure repo root is on the path so `shared` package is found when run directly
+ROOT = Path(__file__).resolve().parents[1]
+if str(ROOT) not in sys.path:
+    sys.path.insert(0, str(ROOT))
+
+from shared.metrics import compute_rouge, compute_bleu, factuality_score  # noqa: E402
+from shared.utils import print_banner, load_yaml_config  # noqa: E402
+from retailgpt_evaluator.dataset_loader import load_retail_dataset  # noqa: E402
+
 def run_eval_for_model(model_name, dataset):
     print_banner(f"Evaluating {model_name}")
     tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -22,9 +31,14 @@ def run_eval_for_model(model_name, dataset):
     return {"model": model_name, **r, **b, **f}
 
 def evaluate_all():
-    from shared.utils import load_yaml_config
-    cfg = load_yaml_config("config.yaml")
-    dataset = load_dataset("json", data_files="datasets/retail_sample.jsonl", split="train[:50]")
+    config_path = Path(__file__).resolve().parent / "config.yaml"
+    cfg = load_yaml_config(config_path)
+    dataset_path = ROOT / "datasets/retail_sample.jsonl"
+    if not dataset_path.exists():
+        print_banner("Dataset not found. Creating sample dataset...")
+        load_retail_dataset()
+    dataset = load_dataset("json", data_files=str(dataset_path), split="train[:50]")
+    (ROOT / "models").mkdir(exist_ok=True)
    results = [run_eval_for_model(m, dataset) for m in cfg["models"]]
     json.dump(results, open("models/retail_eval_results.json", "w"), indent=2)
     print("✅ Saved results to models/retail_eval_results.json")
shared/__init__.py ADDED
@@ -0,0 +1 @@
+# Makes shared a package so imports like `from shared import ...` work.
shared/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (166 Bytes). View file
 
shared/__pycache__/metrics.cpython-311.pyc ADDED
Binary file (2.27 kB). View file
 
shared/__pycache__/utils.cpython-311.pyc ADDED
Binary file (1.48 kB). View file
 
shared/hf_helpers.py CHANGED
@@ -1,10 +1,42 @@
+import os
+from pathlib import Path
+from typing import Optional
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 import torch
 
-def load_model_and_tokenizer(model_name: str):
+def _resolve_model_identifier(model_name: str) -> str:
+    """Return a valid model identifier or local path."""
+    path_candidate = Path(model_name)
+    if path_candidate.exists():
+        return str(path_candidate)
+    return model_name
+
+def _build_hub_kwargs(token: Optional[str]) -> dict:
+    """Prepare kwargs for Hugging Face Hub auth across library versions."""
+    if not token:
+        return {}
+    return {"token": token}
+
+def _fallback_hub_kwargs(token: Optional[str]) -> dict:
+    """Older transformers versions still expect use_auth_token."""
+    if not token:
+        return {}
+    return {"use_auth_token": token}
+
+def load_model_and_tokenizer(model_name: str, token: Optional[str] = None):
     """Load a model and tokenizer for inference."""
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+    resolved_model = _resolve_model_identifier(model_name)
+    auth_token = token or os.getenv("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HF_TOKEN")
+
+    kwargs = _build_hub_kwargs(auth_token)
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(resolved_model, **kwargs)
+        model = AutoModelForSeq2SeqLM.from_pretrained(resolved_model, **kwargs)
+    except TypeError:
+        fallback_kwargs = _fallback_hub_kwargs(auth_token)
+        tokenizer = AutoTokenizer.from_pretrained(resolved_model, **fallback_kwargs)
+        model = AutoModelForSeq2SeqLM.from_pretrained(resolved_model, **fallback_kwargs)
+
     return model, tokenizer
 
 def generate_answer(model, tokenizer, prompt: str, max_tokens: int = 256):
@@ -14,6 +46,14 @@ def generate_answer(model, tokenizer, prompt: str, max_tokens: int = 256):
     outputs = model.generate(**inputs, max_new_tokens=max_tokens)
     return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-def build_pipeline(model_name: str, task="text2text-generation"):
+def build_pipeline(model_name: str, task="text2text-generation", token: Optional[str] = None):
     """Return a Hugging Face pipeline for inference."""
-    return pipeline(task, model=model_name)
+    resolved_model = _resolve_model_identifier(model_name)
+    auth_token = token or os.getenv("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HF_TOKEN")
+
+    kwargs = _build_hub_kwargs(auth_token)
+    try:
+        return pipeline(task, model=resolved_model, **kwargs)
+    except TypeError:
+        fallback_kwargs = _fallback_hub_kwargs(auth_token)
+        return pipeline(task, model=resolved_model, **fallback_kwargs)
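Both helpers now accept either a Hub repo ID or a local checkpoint path, attach a token only when one is set, and retry with `use_auth_token` on older transformers releases that reject the `token` kwarg. A hedged usage sketch; the local path only works once a trained checkpoint has actually been saved there:

```python
import os

from shared.hf_helpers import build_pipeline, load_model_and_tokenizer

# Optional: only needed for private models; public repos load without a token.
token = os.getenv("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HF_TOKEN")

# Hub repo ID
pipe = build_pipeline("google/flan-t5-base", token=token)
print(pipe("Summarize the termination clause: Either party may terminate with 30 days written notice without cause.")[0]["generated_text"])

# Local checkpoint directory (assuming a fine-tuned model was saved there)
model, tokenizer = load_model_and_tokenizer("models/legaldoc_summarizer", token=token)
```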
shared/metrics.py CHANGED
@@ -6,9 +6,11 @@ def compute_rouge(preds, refs):
     return rouge.compute(predictions=preds, references=refs)
 
 def compute_bleu(preds, refs):
+    """BLEU with simple whitespace tokenization for compatibility."""
     bleu = load_metric("bleu")
-    refs = [[r] for r in refs]  # bleu expects list of lists
-    return bleu.compute(predictions=preds, references=refs)
+    pred_tokens = [p.split() for p in preds]
+    ref_tokens = [[r.split()] for r in refs]  # bleu expects list of list of token lists
+    return bleu.compute(predictions=pred_tokens, references=ref_tokens)
 
 def factuality_score(preds, refs):
     """Very simple lexical overlap metric for factual alignment."""