Spaces:
Runtime error
Runtime error
added handling of multi dimension
Browse files- jaccard_similarity.py +18 -12
jaccard_similarity.py
CHANGED
|
@@ -77,17 +77,10 @@ class JaccardSimilarity(evaluate.Metric):
|
|
| 77 |
description=_DESCRIPTION,
|
| 78 |
citation=_CITATION,
|
| 79 |
inputs_description=_KWARGS_DESCRIPTION,
|
| 80 |
-
features=datasets.Features(
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
}
|
| 85 |
-
if self.config_name == "multilabel"
|
| 86 |
-
else {
|
| 87 |
-
"predictions": datasets.Value("int32"),
|
| 88 |
-
"references": datasets.Value("int32"),
|
| 89 |
-
}
|
| 90 |
-
),
|
| 91 |
reference_urls=[
|
| 92 |
"https://scikit-learn.org/stable/modules/generated/sklearn.metrics.jaccard_score.html",
|
| 93 |
"https://en.wikipedia.org/wiki/Jaccard_index"
|
|
@@ -95,7 +88,20 @@ class JaccardSimilarity(evaluate.Metric):
|
|
| 95 |
)
|
| 96 |
|
| 97 |
def _compute(self, predictions, references, labels=None, pos_label=1, average='binary', sample_weight=None, zero_division='warn'):
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
return {
|
| 100 |
"jaccard_similarity": jaccard_score(
|
| 101 |
references,
|
|
|
|
| 77 |
description=_DESCRIPTION,
|
| 78 |
citation=_CITATION,
|
| 79 |
inputs_description=_KWARGS_DESCRIPTION,
|
| 80 |
+
features=datasets.Features({
|
| 81 |
+
"predictions": datasets.Sequence(datasets.Value("int32")),
|
| 82 |
+
"references": datasets.Sequence(datasets.Value("int32")),
|
| 83 |
+
}),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
reference_urls=[
|
| 85 |
"https://scikit-learn.org/stable/modules/generated/sklearn.metrics.jaccard_score.html",
|
| 86 |
"https://en.wikipedia.org/wiki/Jaccard_index"
|
|
|
|
| 88 |
)
|
| 89 |
|
| 90 |
def _compute(self, predictions, references, labels=None, pos_label=1, average='binary', sample_weight=None, zero_division='warn'):
|
| 91 |
+
predictions = np.array(predictions)
|
| 92 |
+
references = np.array(references)
|
| 93 |
+
|
| 94 |
+
# Handle different input shapes
|
| 95 |
+
if predictions.ndim == 1 and references.ndim == 1:
|
| 96 |
+
# Binary or multiclass case
|
| 97 |
+
pass
|
| 98 |
+
elif predictions.ndim == 2 and references.ndim == 2:
|
| 99 |
+
# Multilabel case
|
| 100 |
+
if average == 'binary':
|
| 101 |
+
average = 'micro' # 'binary' doesn't make sense for multilabel
|
| 102 |
+
else:
|
| 103 |
+
raise ValueError("Predictions and references should have the same shape")
|
| 104 |
+
|
| 105 |
return {
|
| 106 |
"jaccard_similarity": jaccard_score(
|
| 107 |
references,
|