sunhill committed on
Commit
5689a44
·
1 Parent(s): 03c869a

regular input

Files changed (2)
  1. spice.py +62 -61
  2. tests.py +12 -6
spice.py CHANGED
@@ -137,12 +137,20 @@ class SPICE(evaluate.Metric):
             citation=_CITATION,
             inputs_description=_KWARGS_DESCRIPTION,
             # This defines the format of each prediction and reference
-            features=datasets.Features(
-                {
-                    "predictions": datasets.List(datasets.Value("string")),
-                    "references": datasets.List(datasets.Value("string")),
-                }
-            ),
+            features=[
+                datasets.Features(
+                    {
+                        "predictions": datasets.Value("string"),
+                        "references": datasets.Value("string"),
+                    }
+                ),
+                datasets.Features(
+                    {
+                        "predictions": datasets.Value("string"),
+                        "references": datasets.Sequence(datasets.Value("string")),
+                    }
+                ),
+            ],
             # Homepage of the module for documentation
             homepage="https://huggingface.co/spaces/sunhill/spice",
             # Additional links to the codebase or references
@@ -182,51 +190,42 @@ class SPICE(evaluate.Metric):
 
     def _compute_batch(self, scores: List[Dict]) -> Dict[str, float]:
         """Compute average scores over all images in the batch."""
-        aggregate_scores = {}
+
+        # Initialize aggregate_scores with zero values
+        aggregate_scores = {
+            "pr": 0.0,
+            "re": 0.0,
+            "f": 0.0,
+            "fn": 0.0,
+            "numImages": 0.0,
+            "fp": 0.0,
+            "tp": 0.0,
+        }
         num_images = len(scores)
         if num_images == 0:
             return aggregate_scores
 
-        # Initialize aggregate_scores with zero values
-        for category in scores[0].keys():
-            aggregate_scores[category] = {
-                "pr": 0.0,
-                "re": 0.0,
-                "f": 0.0,
-                "fn": 0.0,
-                "numImages": 0.0,
-                "fp": 0.0,
-                "tp": 0.0,
-            }
-
         # Sum up scores for each category
         for score in scores:
-            for category, score_dict in score.items():
-                for k, v in score_dict.items():
-                    if k in ["fn", "fp", "tp"]:
-                        aggregate_scores[category][k] += v
-            aggregate_scores[category]["numImages"] += 1
+            for k, v in score.items():
+                if k in ["fn", "fp", "tp"]:
+                    aggregate_scores[k] += v
+            aggregate_scores["numImages"] += 1
 
         # Compute average scores
-        for category, score_dict in aggregate_scores.items():
-            tp = score_dict["tp"]
-            fp = score_dict["fp"]
-            fn = score_dict["fn"]
-
-            precision = tp / (tp + fp) if (tp + fp) > 0 else float("nan")
-            recall = tp / (tp + fn) if (tp + fn) > 0 else float("nan")
-            f_score = (
-                2 * precision * recall / (precision + recall)
-                if precision is not None
-                and recall is not None
-                and (precision + recall) > 0
-                else float("nan")
-            )
-
-            aggregate_scores[category]["pr"] = precision
-            aggregate_scores[category]["re"] = recall
-            aggregate_scores[category]["f"] = f_score
-
+        tp = aggregate_scores["tp"]
+        fp = aggregate_scores["fp"]
+        fn = aggregate_scores["fn"]
+        precision = tp / (tp + fp) if (tp + fp) > 0 else float("nan")
+        recall = tp / (tp + fn) if (tp + fn) > 0 else float("nan")
+        f_score = (
+            2 * precision * recall / (precision + recall)
+            if precision is not None and recall is not None and (precision + recall) > 0
+            else float("nan")
+        )
+        aggregate_scores["pr"] = precision
+        aggregate_scores["re"] = recall
+        aggregate_scores["f"] = f_score
         return aggregate_scores
 
     def _compute(self, predictions, references, spice_name="All"):
@@ -237,11 +236,19 @@ class SPICE(evaluate.Metric):
         )
         input_data = []
         for i, (prediction, reference) in enumerate(zip(predictions, references)):
-            assert len(prediction) == 1 and len(reference) >= 1, (
-                "SPICE expects exactly one prediction and at least one reference per image. "
-                f"Got {len(prediction)} predictions and {len(reference)} references."
+            assert isinstance(prediction, str), (
+                "Each prediction should be a string. "
+                f"Got {type(prediction)} for image {i}."
+            )
+            if isinstance(reference, str):
+                reference = [reference]
+            assert isinstance(reference, list) and all(
+                isinstance(ref, str) for ref in reference
+            ), (
+                "Each reference should be a list of strings. "
+                f"Got {type(reference)} with elements of type {[type(ref) for ref in reference]} for index {i}."
             )
-            input_data.append({"image_id": i, "test": prediction[0], "refs": reference})
+            input_data.append({"image_id": i, "test": prediction, "refs": reference})
 
         in_file = tempfile.NamedTemporaryFile(delete=False)
         in_file.write(json.dumps(input_data, indent=2).encode("utf-8"))
@@ -281,17 +288,11 @@ class SPICE(evaluate.Metric):
         os.remove(in_file.name)
         os.remove(out_file.name)
 
-        img_id_to_scores = {item["image_id"]: item["scores"] for item in results}
-        scores = []
-        for image_id in range(len(predictions)):
-            # Convert none to NaN before saving scores over subcategories
-            score_set = {}
-            for category, score_tuple in img_id_to_scores[image_id].items():
-                score_set[category] = {
-                    k: self.float_convert(v) for k, v in score_tuple.items()
-                }
-            scores.append(score_set)
-        result_score = {}
-        for k, v in self._compute_batch(scores)[spice_name].items():
-            result_score["spice_" + spice_name.lower() + "_" + k] = v
-        return result_score
+        img_id_to_scores = {
+            item["image_id"]: item["scores"][spice_name] for item in results
+        }
+        scores = [
+            {k: self.float_convert(v) for k, v in img_id_to_scores[image_id].items()}
+            for image_id in range(len(predictions))
+        ]
+        return {f"spice_{k}": v for k, v in self._compute_batch(scores).items()}
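With this change the metric takes "regular input": one plain string per prediction, and per image either a single reference string or a list of reference strings. A minimal usage sketch under those assumptions (the evaluate.load path is illustrative; it is not shown in this diff):

import evaluate

metric = evaluate.load("sunhill/spice")  # assumed load path for this space; not part of the diff
results = metric.compute(
    predictions=["train traveling down a track in front of a road"],
    references=[["a train traveling down tracks next to lights"]],
)
# _compute now flattens the selected spice_name category into keys such as
# "spice_pr", "spice_re", "spice_f", "spice_tp", "spice_fp", "spice_fn", "spice_numImages"
print(results)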
 
 
 
 
 
 
tests.py CHANGED
@@ -3,7 +3,7 @@ import evaluate
 
 test_cases = [
     {
-        "predictions": [["train traveling down a track in front of a road"]],
+        "predictions": ["train traveling down a track in front of a road"],
         "references": [
             [
                 "a train traveling down tracks next to lights",
@@ -12,12 +12,18 @@ test_cases = [
                 "a passenger train pulls into a train station",
                 "a train coming down the tracks arriving at a station",
             ]
-        ]
+        ],
+    },
+    {
+        "predictions": ["birthday cake sitting on top of a white plate"],
+        "references": [
+            "a blue plate filled with marshmallows chocolate chips and banana"
+        ],
     },
     {
         "predictions": [
-            ["plane is flying through the sky"],
-            ["birthday cake sitting on top of a white plate"],
+            "plane is flying through the sky",
+            "birthday cake sitting on top of a white plate",
         ],
         "references": [
             [
@@ -28,7 +34,7 @@
                 "the plane is flying over top of the cars",
             ],
             ["a blue plate filled with marshmallows chocolate chips and banana"],
-        ]
+        ],
     },
 ]
 
@@ -37,7 +43,7 @@ for i, test_case in enumerate(test_cases):
     results = metric.compute(
         predictions=test_case["predictions"], references=test_case["references"]
    )
-    print(f"Test case {i+1}:")
+    print(f"Test case {i + 1}:")
     print("Predictions:", test_case["predictions"])
     print("References:", test_case["references"])
    print(results)
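For the second test case above, where the reference is a single string, the updated _compute wraps it into a list before serialization. The content written to the temporary input file would look roughly like the following sketch (derived from the diff, not actual output of this commit):

# illustrative input_data for one image with a single string reference
input_data = [
    {
        "image_id": 0,
        "test": "birthday cake sitting on top of a white plate",
        "refs": ["a blue plate filled with marshmallows chocolate chips and banana"],
    }
]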