sunhill committed
Commit 92aad95 · 1 Parent(s): ec09ccb

update SPICE metric

Files changed (7)
  1. .gitignore +2 -1
  2. .spice.py +0 -95
  3. README.md +5 -5
  4. app.py +1 -1
  5. get_stanford_models.sh +0 -23
  6. spice.py +132 -53
  7. tests.py +42 -12
.gitignore CHANGED
@@ -218,4 +218,5 @@ __marimo__/
  # Custom additions
  lib/stanford*.jar
  !lib/
- **/.DS_Store
+ **/.DS_Store
+ .vscode/
.spice.py DELETED
@@ -1,95 +0,0 @@
- import os
- import subprocess
- import json
- import numpy as np
- import tempfile
-
- # Assumes spice.jar is in the same directory as spice.py. Change as needed.
- SPICE_JAR = "spice-1.0.jar"
- TEMP_DIR = "tmp"
- CACHE_DIR = "cache"
-
-
- class Spice:
-     """
-     Main Class to compute the SPICE metric
-     """
-
-     def float_convert(self, obj):
-         try:
-             return float(obj)
-         except (ValueError, TypeError):
-             return np.nan
-
-     def compute_score(self, gts, res):
-         assert sorted(gts.keys()) == sorted(res.keys())
-         imgIds = sorted(gts.keys())
-
-         # Prepare temp input file for the SPICE scorer
-         input_data = []
-         for id in imgIds:
-             hypo = res[id]
-             ref = gts[id]
-
-             # Sanity check.
-             assert type(hypo) is list
-             assert len(hypo) == 1
-             assert type(ref) is list
-             assert len(ref) >= 1
-
-             input_data.append({"image_id": id, "test": hypo[0], "refs": ref})
-
-         cwd = os.path.dirname(os.path.abspath(__file__))
-         temp_dir = os.path.join(cwd, TEMP_DIR)
-         if not os.path.exists(temp_dir):
-             os.makedirs(temp_dir)
-         in_file = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir)
-         json.dump(input_data, in_file, indent=2)
-         in_file.close()
-
-         # Start job
-         out_file = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir)
-         out_file.close()
-         cache_dir = os.path.join(cwd, CACHE_DIR)
-         if not os.path.exists(cache_dir):
-             os.makedirs(cache_dir)
-         spice_cmd = [
-             "java",
-             "-jar",
-             "-Xmx8G",
-             SPICE_JAR,
-             in_file.name,
-             "-cache",
-             cache_dir,
-             "-out",
-             out_file.name,
-             "-subset",
-             "-silent",
-         ]
-         subprocess.check_call(spice_cmd, cwd=os.path.dirname(os.path.abspath(__file__)))
-
-         # Read and process results
-         with open(out_file.name) as data_file:
-             results = json.load(data_file)
-         os.remove(in_file.name)
-         os.remove(out_file.name)
-
-         imgId_to_scores = {}
-         spice_scores = []
-         for item in results:
-             imgId_to_scores[item["image_id"]] = item["scores"]
-             spice_scores.append(self.float_convert(item["scores"]["All"]["f"]))
-         average_score = np.mean(np.array(spice_scores))
-         scores = []
-         for image_id in imgIds:
-             # Convert none to NaN before saving scores over subcategories
-             score_set = {}
-             for category, score_tuple in imgId_to_scores[image_id].iteritems():
-                 score_set[category] = {
-                     k: self.float_convert(v) for k, v in score_tuple.items()
-                 }
-             scores.append(score_set)
-         return average_score, scores
-
-     def method(self):
-         return "SPICE"
README.md CHANGED
@@ -5,22 +5,22 @@ tags:
  - metric
  description: "TODO: add a description here"
  sdk: gradio
- sdk_version: 3.19.1
+ sdk_version: 5.45.0
  app_file: app.py
  pinned: false
  ---

  # Metric Card for SPICE

- ***Module Card Instructions:*** *Fill out the following subsections. Feel free to take a look at existing metric cards if you'd like examples.*
+ ***Module Card Instructions:*** *This module calculates the SPICE metric for evaluating image captioning models.*

  ## Metric Description
- *Give a brief overview of this metric, including what task(s) it is usually used for, if any.*
+ *SPICE (Semantic Propositional Image Caption Evaluation) is a metric for evaluating the quality of image captions. It measures the semantic similarity between the generated captions and a set of reference captions by analyzing the underlying semantic propositions.*

  ## How to Use
- *Give general statement of how to use the metric*
+ *To use the SPICE metric, you need to provide a set of generated captions and a set of reference captions. The metric will then compute the SPICE score based on the semantic similarity between the two sets of captions.*

- *Provide simplest possible example for using the metric*
+ *Here is a simple example of using the SPICE metric:*

  ### Inputs
  *List all input arguments in the format below*
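(A minimal usage sketch for the "simple example" promised in the README above. It assumes the Space ID `sunhill/spice` used in `app.py`, a working Java runtime with the SPICE and Stanford CoreNLP jars in place, and the input format the updated `spice.py` expects: each prediction is a single-element list of captions, each reference a list of caption strings.)

```python
import evaluate

# Load the SPICE metric from this Space. Computing a score requires Java and the
# spice-1.0.jar / Stanford CoreNLP jars resolved by _download_and_prepare.
spice = evaluate.load("sunhill/spice")

results = spice.compute(
    predictions=[["a train traveling down a track in front of a road"]],
    references=[[
        "a train traveling down tracks next to lights",
        "a passenger train pulls into a train station",
    ]],
)
print(results)
```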
app.py CHANGED
@@ -3,4 +3,4 @@ from evaluate.utils import launch_gradio_widget


  module = evaluate.load("sunhill/spice")
- launch_gradio_widget(module)
+ launch_gradio_widget(module)
get_stanford_models.sh DELETED
@@ -1,23 +0,0 @@
- #!/usr/bin/env sh
- # This script downloads the Stanford CoreNLP models.
-
- CORENLP=stanford-corenlp-full-2015-12-09
- SPICELIB=lib
-
- DIR="$( cd "$(dirname "$0")" ; pwd -P )"
- cd $DIR
-
- echo "Downloading..."
-
- wget http://nlp.stanford.edu/software/$CORENLP.zip
-
- echo "Unzipping..."
-
- mkdir -p .tmp
- unzip $CORENLP.zip -d .tmp/
- mv .tmp/$CORENLP/stanford-corenlp-3.6.0.jar $SPICELIB/
- mv .tmp/$CORENLP/stanford-corenlp-3.6.0-models.jar $SPICELIB/
- rm -f stanford-corenlp-full-2015-12-09.zip
- rm -rf .tmp
-
- echo "Done."
spice.py CHANGED
@@ -1,68 +1,58 @@
- # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """TODO: Add a description here."""
+ """This module implements the SPICE metric."""
+
+ import os
+ import shutil
+ import subprocess
+ import json
+ import tempfile

  import evaluate
  import datasets
+ import numpy as np
+ from evaluate.utils.logging import get_logger
+
+ logger = get_logger(__name__)

+ CORENLP = "stanford-corenlp-full-2015-12-09"
+ SPICELIB = "lib"
+ SPICE_JAR = "spice-1.0.jar"

- # TODO: Add BibTeX citation
  _CITATION = """\
- @InProceedings{huggingface:module,
- title = {A great new module},
- authors={huggingface, Inc.},
- year={2020}
+ @inproceedings{spice2016,
+   title = {SPICE: Semantic Propositional Image Caption Evaluation},
+   author = {Peter Anderson and Basura Fernando and Mark Johnson and Stephen Gould},
+   year = {2016},
+   booktitle = {ECCV}
  }
  """

- # TODO: Add description of the module here
  _DESCRIPTION = """\
- This new module is designed to solve this great ML task and is crafted with a lot of care.
+ This module is designed to evaluate the quality of image captions using the SPICE metric.
+ It compares generated captions with reference captions to assess their semantic similarity.
  """

-
- # TODO: Add description of the arguments of the module here
  _KWARGS_DESCRIPTION = """
- Calculates how good are predictions given some references, using certain scores
+ Compute SPICE score.
  Args:
      predictions: list of predictions to score. Each predictions
-         should be a string with tokens separated by spaces.
+         should be a list containing a single caption string.
      references: list of reference for each prediction. Each
-         reference should be a string with tokens separated by spaces.
+         reference should be a list of reference caption strings.
  Returns:
-     accuracy: description of the first score,
-     another_score: description of the second score,
+     spice: average SPICE F-score over all images
  Examples:
-     Examples should be written in doctest format, and should illustrate how
-     to use the function.
-
-     >>> my_new_module = evaluate.load("my_new_module")
-     >>> results = my_new_module.compute(references=[0, 1], predictions=[0, 1])
+     >>> metric = evaluate.load("sunhill/spice")
+     >>> results = metric.compute(predictions=[["a cat on a mat"]], references=[["a cat is on the mat"]])
      >>> print(results)
-     {'accuracy': 1.0}
+     {'spice': 0.5}
  """

- # TODO: Define external resources urls if needed
- BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"
-

  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
  class SPICE(evaluate.Metric):
-     """TODO: Short description of my evaluation module."""
+     """This module implements the SPICE metric for evaluating image captioning models."""

      def _info(self):
-         # TODO: Specifies the evaluate.EvaluationModuleInfo object
          return evaluate.MetricInfo(
              # This is the description that will appear on the modules page.
              module_type="metric",
@@ -70,26 +60,115 @@ class SPICE(evaluate.Metric):
              citation=_CITATION,
              inputs_description=_KWARGS_DESCRIPTION,
              # This defines the format of each prediction and reference
-             features=datasets.Features({
-                 'predictions': datasets.Value('int64'),
-                 'references': datasets.Value('int64'),
-             }),
+             features=datasets.Features(
+                 {
+                     "predictions": datasets.List(datasets.Value("string")),
+                     "references": datasets.List(datasets.Value("string")),
+                 }
+             ),
              # Homepage of the module for documentation
-             homepage="http://module.homepage",
+             homepage="https://huggingface.co/spaces/sunhill/spice",
              # Additional links to the codebase or references
-             codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
-             reference_urls=["http://path.to.reference.url/new_module"]
+             codebase_urls=[
+                 "https://github.com/peteanderson80/SPICE",
+                 "https://github.com/EricWWWW/image-caption-metrics",
+             ],
+             reference_urls=["https://panderson.me/spice"],
          )

      def _download_and_prepare(self, dl_manager):
          """Optional: download external resources useful to compute the scores"""
-         # TODO: Download external resources if needed
-         pass
+         if os.path.exists("lib/stanford-corenlp-3.6.0-models.jar") and os.path.exists(
+             "lib/stanford-corenlp-3.6.0.jar"
+         ):
+             logger.info("`stanford-corenlp` already exists. Skip downloading.")
+             return
+         logger.info("Downloading `stanford-corenlp`...")
+         url = f"http://nlp.stanford.edu/software/{CORENLP}.zip"
+         extracted_path = dl_manager.download_and_extract(url)
+         tmp_path = os.path.join(extracted_path, CORENLP)
+         shutil.copyfile(
+             os.path.join(tmp_path, "stanford-corenlp-3.6.0-models.jar"),
+             os.path.join(SPICELIB, "stanford-corenlp-3.6.0-models.jar"),
+         )
+         shutil.copyfile(
+             os.path.join(tmp_path, "stanford-corenlp-3.6.0.jar"),
+             os.path.join(SPICELIB, "stanford-corenlp-3.6.0.jar"),
+         )
+         logger.info(f"`stanford-corenlp` has been downloaded to {SPICELIB}")
+
+     def float_convert(self, obj):
+         try:
+             return float(obj)
+         except (ValueError, TypeError):
+             return np.nan

      def _compute(self, predictions, references):
          """Returns the scores"""
-         # TODO: Compute the different scores of the module
-         accuracy = sum(i == j for i, j in zip(predictions, references)) / len(predictions)
-         return {
-             "accuracy": accuracy,
-         }
+         assert len(predictions) == len(references), (
+             "The number of predictions and references should be the same. "
+             f"Got {len(predictions)} predictions and {len(references)} references."
+         )
+         input_data = []
+         for i, (prediction, reference) in enumerate(zip(predictions, references)):
+             assert len(prediction) == 1 and len(reference) >= 1, (
+                 "SPICE expects exactly one prediction and at least one reference per image. "
+                 f"Got {len(prediction)} predictions and {len(reference)} references."
+             )
+             input_data.append({"image_id": i, "test": prediction[0], "refs": reference})
+
+         # Write the input records to a temp file for the SPICE jar (text mode so json.dump works).
+         in_file = tempfile.NamedTemporaryFile(mode="w", delete=False)
+         json.dump(input_data, in_file, indent=2)
+         in_file.close()
+
+         out_file = tempfile.NamedTemporaryFile(delete=False)
+         out_file.close()
+         with tempfile.TemporaryDirectory() as cache_dir:
+             spice_cmd = [
+                 "java",
+                 "-jar",
+                 "-Xmx8G",
+                 SPICE_JAR,
+                 in_file.name,
+                 "-cache",
+                 cache_dir,
+                 "-out",
+                 out_file.name,
+                 "-subset",
+                 "-silent",
+             ]
+             try:
+                 subprocess.run(
+                     spice_cmd,
+                     check=True,
+                     stdout=subprocess.PIPE,
+                     stderr=subprocess.PIPE,
+                 )
+             except subprocess.CalledProcessError as e:
+                 raise RuntimeError(
+                     f"SPICE command '{' '.join(spice_cmd)}' returned non-zero exit status {e.returncode}. "
+                     f"stderr: {e.stderr.decode('utf-8')}"
+                 ) from e
+
+         with open(out_file.name, "r") as f:
+             results = json.load(f)
+         os.remove(in_file.name)
+         os.remove(out_file.name)
+
+         img_id_to_scores = {}
+         spice_scores = []
+         for item in results:
+             img_id_to_scores[item["image_id"]] = item["scores"]
+             spice_scores.append(self.float_convert(item["scores"]["All"]["f"]))
+         average_score = np.mean(np.array(spice_scores))
+         scores = []
+         for image_id in range(len(predictions)):
+             # Convert None to NaN before saving scores over subcategories
+             score_set = {}
+             for category, score_tuple in img_id_to_scores[image_id].items():
+                 score_set[category] = {
+                     k: self.float_convert(v) for k, v in score_tuple.items()
+                 }
+             scores.append(score_set)
+         return {"spice": average_score, "scores": scores}
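(For reference, a standalone sketch of the JSON interchange that `_compute` sets up around `spice-1.0.jar`. The record fields `image_id`/`test`/`refs`, the jar flags, and the `scores["All"]["f"]` lookup are taken from the diff above; everything else, including the exact per-category keys in the jar's output, is an assumption. Running it requires Java plus `spice-1.0.jar` and the Stanford CoreNLP jars in `lib/`.)

```python
import json
import subprocess
import tempfile

# One record per image, mirroring the input_data built in SPICE._compute:
# "test" is the candidate caption, "refs" the list of reference captions.
records = [
    {
        "image_id": 0,
        "test": "a train traveling down a track in front of a road",
        "refs": ["a passenger train pulls into a train station"],
    }
]

in_file = tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False)
json.dump(records, in_file, indent=2)
in_file.close()
out_path = in_file.name + ".out.json"

with tempfile.TemporaryDirectory() as cache_dir:
    # Same invocation the module builds (flags copied from spice_cmd above).
    subprocess.run(
        [
            "java", "-jar", "-Xmx8G", "spice-1.0.jar", in_file.name,
            "-cache", cache_dir, "-out", out_path, "-subset", "-silent",
        ],
        check=True,
    )

with open(out_path) as f:
    results = json.load(f)

# The module's final score is the mean of the "All" F-score across images.
print([item["scores"]["All"]["f"] for item in results])
```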
tests.py CHANGED
@@ -1,17 +1,47 @@
+ import evaluate
+
+
  test_cases = [
      {
-         "predictions": [0, 0],
-         "references": [1, 1],
-         "result": {"metric_score": 0}
+         "predictions": [["train traveling down a track in front of a road"]],
+         "references": [
+             [
+                 "a train traveling down tracks next to lights",
+                 "a blue and silver train next to train station and trees",
+                 "a blue train is next to a sidewalk on the rails",
+                 "a passenger train pulls into a train station",
+                 "a train coming down the tracks arriving at a station",
+             ]
+         ],
+         "result": {"metric_score": 0},
      },
      {
-         "predictions": [1, 1],
-         "references": [1, 1],
-         "result": {"metric_score": 1}
-     },
-     {
-         "predictions": [1, 0],
-         "references": [1, 1],
-         "result": {"metric_score": 0.5}
+         "predictions": [
+             ["plane is flying through the sky"],
+             ["birthday cake sitting on top of a white plate"],
+         ],
+         "references": [
+             [
+                 "a large jetliner flying over a traffic filled street",
+                 "an airplane flies low in the sky over a city street",
+                 "an airplane flies over a street with many cars",
+                 "an airplane comes in to land over a road full of cars",
+                 "the plane is flying over top of the cars",
+             ],
+             ["a blue plate filled with marshmallows chocolate chips and banana"],
+         ],
+         "result": {"metric_score": 1},
      }
- ]
+ ]
+
+ metric = evaluate.load("./spice.py")
+ for i, test_case in enumerate(test_cases):
+     results = metric.compute(
+         predictions=test_case["predictions"], references=test_case["references"]
+     )
+     print(f"Test case {i+1}:")
+     print("Predictions:", test_case["predictions"])
+     print("References:", test_case["references"])
+     print("Results:", results)
+     print("Expected:", test_case["result"])
+     print()