sunhill committed on
Commit
df29cfd
·
1 Parent(s): ecd5dc6

regular input

Browse files
Files changed (2) hide show
  1. cider.py +24 -11
  2. tests.py +13 -7
cider.py CHANGED
@@ -63,12 +63,20 @@ class CIDEr(evaluate.Metric):
63
  citation=_CITATION,
64
  inputs_description=_KWARGS_DESCRIPTION,
65
  # This defines the format of each prediction and reference
66
- features=datasets.Features(
67
- {
68
- "predictions": datasets.List((datasets.Value("string"))),
69
- "references": datasets.List(datasets.Value("string")),
70
- }
71
- ),
 
 
 
 
 
 
 
 
72
  # Homepage of the module for documentation
73
  homepage="https://huggingface.co/spaces/sunhill/cider",
74
  # Additional links to the codebase or references
@@ -84,10 +92,6 @@ class CIDEr(evaluate.Metric):
84
  ],
85
  )
86
 
87
- def _download_and_prepare(self, dl_manager):
88
- """Optional: download external resources useful to compute the scores"""
89
- pass
90
-
91
  def _compute(self, predictions, references):
92
  """Returns the scores"""
93
  assert len(predictions) == len(references), (
@@ -96,6 +100,15 @@ class CIDEr(evaluate.Metric):
96
  )
97
  cider_scorer = CiderScorer(n=4, sigma=6.0)
98
  for pred, ref in zip(predictions, references):
99
- cider_scorer += (pred[0], ref)
 
 
 
 
 
 
 
 
 
100
  score, _ = cider_scorer.compute_score()
101
  return {"cider_score": score.item()}
 
63
  citation=_CITATION,
64
  inputs_description=_KWARGS_DESCRIPTION,
65
  # This defines the format of each prediction and reference
66
+ features=[
67
+ datasets.Features(
68
+ {
69
+ "predictions": datasets.Value("string"),
70
+ "references": datasets.Value("string"),
71
+ }
72
+ ),
73
+ datasets.Features(
74
+ {
75
+ "predictions": datasets.Value("string"),
76
+ "references": datasets.Sequence(datasets.Value("string")),
77
+ }
78
+ ),
79
+ ],
80
  # Homepage of the module for documentation
81
  homepage="https://huggingface.co/spaces/sunhill/cider",
82
  # Additional links to the codebase or references
 
92
  ],
93
  )
94
 
 
 
 
 
95
  def _compute(self, predictions, references):
96
  """Returns the scores"""
97
  assert len(predictions) == len(references), (
 
100
  )
101
  cider_scorer = CiderScorer(n=4, sigma=6.0)
102
  for pred, ref in zip(predictions, references):
103
+ assert isinstance(pred, str), (
104
+ f"Each prediction should be a string. Got {type(pred)}."
105
+ )
106
+ if isinstance(ref, str):
107
+ ref = [ref]
108
+ assert isinstance(ref, list) and all(isinstance(r, str) for r in ref), (
109
+ "Each reference should be a list of strings. "
110
+ f"Got {type(ref)} with elements of type {[type(r) for r in ref]}."
111
+ )
112
+ cider_scorer += (pred, ref)
113
  score, _ = cider_scorer.compute_score()
114
  return {"cider_score": score.item()}
tests.py CHANGED
@@ -3,7 +3,7 @@ import evaluate
3
 
4
  test_cases = [
5
  {
6
- "predictions": [["train traveling down a track in front of a road"]],
7
  "references": [
8
  [
9
  "a train traveling down tracks next to lights",
@@ -12,12 +12,18 @@ test_cases = [
12
  "a passenger train pulls into a train station",
13
  "a train coming down the tracks arriving at a station",
14
  ]
15
- ]
 
 
 
 
 
 
16
  },
17
  {
18
  "predictions": [
19
- ["plane is flying through the sky"],
20
- ["birthday cake sitting on top of a white plate"],
21
  ],
22
  "references": [
23
  [
@@ -28,16 +34,16 @@ test_cases = [
28
  "the plane is flying over top of the cars",
29
  ],
30
  ["a blue plate filled with marshmallows chocolate chips and banana"],
31
- ]
32
  },
33
  ]
34
 
35
- metric = evaluate.load("sunhill/cider")
36
  for i, test_case in enumerate(test_cases):
37
  results = metric.compute(
38
  predictions=test_case["predictions"], references=test_case["references"]
39
  )
40
- print(f"Test case {i+1}:")
41
  print("Predictions:", test_case["predictions"])
42
  print("References:", test_case["references"])
43
  print(results)
 
3
 
4
  test_cases = [
5
  {
6
+ "predictions": ["train traveling down a track in front of a road"],
7
  "references": [
8
  [
9
  "a train traveling down tracks next to lights",
 
12
  "a passenger train pulls into a train station",
13
  "a train coming down the tracks arriving at a station",
14
  ]
15
+ ],
16
+ },
17
+ {
18
+ "predictions": ["birthday cake sitting on top of a white plate"],
19
+ "references": [
20
+ "a blue plate filled with marshmallows chocolate chips and banana"
21
+ ],
22
  },
23
  {
24
  "predictions": [
25
+ "plane is flying through the sky",
26
+ "birthday cake sitting on top of a white plate",
27
  ],
28
  "references": [
29
  [
 
34
  "the plane is flying over top of the cars",
35
  ],
36
  ["a blue plate filled with marshmallows chocolate chips and banana"],
37
+ ],
38
  },
39
  ]
40
 
41
+ metric = evaluate.load("./cider.py")
42
  for i, test_case in enumerate(test_cases):
43
  results = metric.compute(
44
  predictions=test_case["predictions"], references=test_case["references"]
45
  )
46
+ print(f"Test case {i + 1}:")
47
  print("Predictions:", test_case["predictions"])
48
  print("References:", test_case["references"])
49
  print(results)