|
|
"""This module implements the CIDEr metric for image captioning evaluation.""" |
|
|
|
|
|
import evaluate |
|
|
import datasets |
|
|
|
|
|
from .cider_scorer import CiderScorer |
|
|
|
|
|
# BibTeX entry for the CIDEr paper (Vedantam et al., CVPR 2015); exposed to
# users through ``evaluate.MetricInfo(citation=...)`` in ``CIDEr._info``.
_CITATION = """\
@InProceedings{Vedantam_2015_CVPR,
author = {Vedantam, Ramakrishna and Lawrence Zitnick, C. and Parikh, Devi},
title = {CIDEr: Consensus-Based Image Description Evaluation},
booktitle = {Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2015}
}
"""
|
|
|
|
|
_DESCRIPTION = """\ |
|
|
This is a metric to evaluate image captioning. It is based on the idea of |
|
|
measuring the consensus between a candidate image caption and a set of |
|
|
reference image captions written by humans. The CIDEr score is computed by |
|
|
comparing the n-grams of the candidate caption to the n-grams of the reference |
|
|
captions, and measuring how many n-grams are shared between the candidate and |
|
|
the references. The score is then normalized by the length of the candidate |
|
|
caption and the number of reference captions. |
|
|
""" |
|
|
|
|
|
|
|
|
_KWARGS_DESCRIPTION = """ |
|
|
CIDEr (Consensus-based Image Description Evaluation) is a metric for evaluating the quality of image captions. |
|
|
It measures how similar a generated caption is to a set of reference captions written by humans. |
|
|
Args: |
|
|
predictions: list of predictions to score. |
|
|
references: list of references for each prediction. |
|
|
Returns: |
|
|
score: CIDEr score. |
|
|
Examples: |
|
|
>>> metric = evaluate.load("sunhill/cider") |
|
|
>>> results = metric.compute( |
|
|
predictions=[['train traveling down a track in front of a road']], |
|
|
references=[ |
|
|
[ |
|
|
'a train traveling down tracks next to lights', |
|
|
'a blue and silver train next to train station and trees', |
|
|
'a blue train is next to a sidewalk on the rails', |
|
|
'a passenger train pulls into a train station', |
|
|
'a train coming down the tracks arriving at a station' |
|
|
] |
|
|
] |
|
|
) |
|
|
""" |
|
|
|
|
|
|
|
|
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class CIDEr(evaluate.Metric):
    """CIDEr metric for image captioning, computed with the local ``CiderScorer``."""

    def _info(self):
        """Return the ``MetricInfo`` describing inputs, documentation and provenance."""
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # Two accepted input schemas: a single reference string per
            # prediction, or a sequence of reference strings per prediction.
            features=[
                datasets.Features(
                    {
                        "predictions": datasets.Value("string"),
                        "references": datasets.Value("string"),
                    }
                ),
                datasets.Features(
                    {
                        "predictions": datasets.Value("string"),
                        "references": datasets.Sequence(datasets.Value("string")),
                    }
                ),
            ],
            homepage="https://huggingface.co/spaces/sunhill/cider",
            codebase_urls=[
                "https://github.com/ramavedantam/cider",
                "https://github.com/EricWWWW/image-caption-metrics",
            ],
            reference_urls=[
                (
                    "https://openaccess.thecvf.com/content_cvpr_2015/html/"
                    "Vedantam_CIDEr_Consensus-Based_Image_2015_CVPR_paper.html"
                )
            ],
        )

    def _compute(self, predictions, references, n=4, sigma=6.0):
        """Score ``predictions`` against ``references`` and return the CIDEr score.

        Args:
            predictions: candidate captions, one string per example.
            references: one entry per prediction; each entry is either a single
                reference string or a non-empty list/tuple of reference strings.
            n: n-gram order passed to ``CiderScorer`` (default 4, as before).
            sigma: passed unchanged to ``CiderScorer`` (default 6.0, as before).
                NOTE(review): presumably the std-dev of the CIDEr-D Gaussian
                length penalty -- confirm against cider_scorer.

        Returns:
            dict: ``{"cider_score": float}``.

        Raises:
            ValueError: if ``predictions`` and ``references`` differ in length,
                or an example has an empty reference list.
            TypeError: if a prediction is not a string, or a reference entry is
                not a string or a list/tuple of strings.
        """
        # Explicit exceptions instead of ``assert`` so validation survives ``python -O``.
        if len(predictions) != len(references):
            raise ValueError(
                "The number of predictions and references should be the same. "
                f"Got {len(predictions)} predictions and {len(references)} references."
            )

        cider_scorer = CiderScorer(n=n, sigma=sigma)
        for pred, ref in zip(predictions, references):
            if not isinstance(pred, str):
                raise TypeError(f"Each prediction should be a string. Got {type(pred)}.")
            cider_scorer += (pred, self._as_reference_list(ref))

        # Only the first return value is exposed; the second is discarded.
        score, _ = cider_scorer.compute_score()
        # ``float`` accepts NumPy scalars and plain Python floats alike, whereas
        # ``.item()`` exists only on NumPy objects.
        return {"cider_score": float(score)}

    @staticmethod
    def _as_reference_list(ref):
        """Normalize one example's references to a non-empty list of strings.

        Raises:
            TypeError: if ``ref`` is neither a string nor a list/tuple of strings.
            ValueError: if ``ref`` is an empty list/tuple.
        """
        if isinstance(ref, str):
            # Single-reference schema: wrap so the scorer always receives a list.
            return [ref]
        if not isinstance(ref, (list, tuple)):
            # Report only the container type: ``ref`` may not be iterable, so
            # listing its element types here could itself raise.
            raise TypeError(f"Each reference should be a list of strings. Got {type(ref)}.")
        if not all(isinstance(r, str) for r in ref):
            raise TypeError(
                "Each reference should be a list of strings. "
                f"Got {type(ref)} with elements of type {[type(r) for r in ref]}."
            )
        if not ref:
            # CIDEr averages similarity over an example's references; with no
            # references that average is undefined.
            raise ValueError(
                "Each prediction needs at least one reference caption; got an empty list."
            )
        return list(ref)
|
|
|