sunhill commited on
Commit
03c869a
·
1 Parent(s): 8ef39b1

compute batch result

Browse files
Files changed (1) hide show
  1. spice.py +55 -2
spice.py CHANGED
@@ -5,6 +5,7 @@ import shutil
5
  import subprocess
6
  import json
7
  import tempfile
 
8
 
9
  import evaluate
10
  import datasets
@@ -179,7 +180,56 @@ class SPICE(evaluate.Metric):
179
  except (ValueError, TypeError):
180
  return float("nan")
181
 
182
- def _compute(self, predictions, references):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  """Returns the scores"""
184
  assert len(predictions) == len(references), (
185
  "The number of predictions and references should be the same. "
@@ -241,4 +291,7 @@ class SPICE(evaluate.Metric):
241
  k: self.float_convert(v) for k, v in score_tuple.items()
242
  }
243
  scores.append(score_set)
244
- return scores
 
 
 
 
5
  import subprocess
6
  import json
7
  import tempfile
8
+ from typing import List, Dict
9
 
10
  import evaluate
11
  import datasets
 
180
  except (ValueError, TypeError):
181
  return float("nan")
182
 
183
+ def _compute_batch(self, scores: List[Dict]) -> Dict[str, float]:
184
+ """Compute average scores over all images in the batch."""
185
+ aggregate_scores = {}
186
+ num_images = len(scores)
187
+ if num_images == 0:
188
+ return aggregate_scores
189
+
190
+ # Initialize aggregate_scores with zero values
191
+ for category in scores[0].keys():
192
+ aggregate_scores[category] = {
193
+ "pr": 0.0,
194
+ "re": 0.0,
195
+ "f": 0.0,
196
+ "fn": 0.0,
197
+ "numImages": 0.0,
198
+ "fp": 0.0,
199
+ "tp": 0.0,
200
+ }
201
+
202
+ # Sum up scores for each category
203
+ for score in scores:
204
+ for category, score_dict in score.items():
205
+ for k, v in score_dict.items():
206
+ if k in ["fn", "fp", "tp"]:
207
+ aggregate_scores[category][k] += v
208
+ aggregate_scores[category]["numImages"] += 1
209
+
210
+ # Compute average scores
211
+ for category, score_dict in aggregate_scores.items():
212
+ tp = score_dict["tp"]
213
+ fp = score_dict["fp"]
214
+ fn = score_dict["fn"]
215
+
216
+ precision = tp / (tp + fp) if (tp + fp) > 0 else float("nan")
217
+ recall = tp / (tp + fn) if (tp + fn) > 0 else float("nan")
218
+ f_score = (
219
+ 2 * precision * recall / (precision + recall)
220
+ if precision is not None
221
+ and recall is not None
222
+ and (precision + recall) > 0
223
+ else float("nan")
224
+ )
225
+
226
+ aggregate_scores[category]["pr"] = precision
227
+ aggregate_scores[category]["re"] = recall
228
+ aggregate_scores[category]["f"] = f_score
229
+
230
+ return aggregate_scores
231
+
232
+ def _compute(self, predictions, references, spice_name="All"):
233
  """Returns the scores"""
234
  assert len(predictions) == len(references), (
235
  "The number of predictions and references should be the same. "
 
291
  k: self.float_convert(v) for k, v in score_tuple.items()
292
  }
293
  scores.append(score_set)
294
+ result_score = {}
295
+ for k, v in self._compute_batch(scores)[spice_name].items():
296
+ result_score["spice_" + spice_name.lower() + "_" + k] = v
297
+ return result_score