"""This module implements the CIDEr metric for image captioning evaluation."""
import evaluate
import datasets
from .cider_scorer import CiderScorer

_CITATION = """\
@InProceedings{Vedantam_2015_CVPR,
author = {Vedantam, Ramakrishna and Lawrence Zitnick, C. and Parikh, Devi},
title = {CIDEr: Consensus-Based Image Description Evaluation},
booktitle = {Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2015}
}
"""

_DESCRIPTION = """\
CIDEr is a metric for evaluating image captioning. It measures the consensus
between a candidate image caption and a set of reference image captions written
by humans. Each caption is represented by TF-IDF weighted n-gram vectors
(n = 1 to 4); the CIDEr score is the cosine similarity between the candidate
and each reference caption, averaged over the references and over the n-gram
lengths, so that n-grams common across the whole corpus are down-weighted
relative to rare, informative ones.
"""

_KWARGS_DESCRIPTION = """
CIDEr (Consensus-based Image Description Evaluation) is a metric for evaluating the quality of image captions.
It measures how similar a generated caption is to a set of reference captions written by humans.
Args:
    predictions: list of predictions to score. Each prediction should be a single caption string.
    references: list of references for each prediction. Each entry may be a single
        reference string or a list of reference strings.
Returns:
    cider_score: the corpus-level CIDEr score (higher is better).
Examples:
    >>> metric = evaluate.load("sunhill/cider")
    >>> results = metric.compute(
    ...     predictions=['train traveling down a track in front of a road'],
    ...     references=[
    ...         [
    ...             'a train traveling down tracks next to lights',
    ...             'a blue and silver train next to train station and trees',
    ...             'a blue train is next to a sidewalk on the rails',
    ...             'a passenger train pulls into a train station',
    ...             'a train coming down the tracks arriving at a station'
    ...         ]
    ...     ]
    ... )
"""
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class CIDEr(evaluate.Metric):
"""CIDEr metric."""
def _info(self):
return evaluate.MetricInfo(
# This is the description that will appear on the modules page.
module_type="metric",
description=_DESCRIPTION,
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
# This defines the format of each prediction and reference
features=[
datasets.Features(
{
"predictions": datasets.Value("string"),
"references": datasets.Value("string"),
}
),
datasets.Features(
{
"predictions": datasets.Value("string"),
"references": datasets.Sequence(datasets.Value("string")),
}
),
],
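            # Two input schemas are accepted: a single reference string per
            # prediction, or a list of reference strings per prediction;
            # _compute() below normalizes a bare string into a one-element list.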
# Homepage of the module for documentation
homepage="https://huggingface.co/spaces/sunhill/cider",
# Additional links to the codebase or references
codebase_urls=[
"https://github.com/ramavedantam/cider",
"https://github.com/EricWWWW/image-caption-metrics",
],
reference_urls=[
(
"https://openaccess.thecvf.com/content_cvpr_2015/html/"
"Vedantam_CIDEr_Consensus-Based_Image_2015_CVPR_paper.html"
)
],
)

    def _compute(self, predictions, references):
        """Compute the corpus-level CIDEr score for the given predictions and references."""
        assert len(predictions) == len(references), (
            "The number of predictions and references should be the same. "
            f"Got {len(predictions)} predictions and {len(references)} references."
        )
        # n=4: use 1- to 4-grams. sigma=6.0 is the standard width of the Gaussian
        # length-difference penalty used by CIDEr-D style scorers.
        cider_scorer = CiderScorer(n=4, sigma=6.0)
        for pred, ref in zip(predictions, references):
            assert isinstance(pred, str), (
                f"Each prediction should be a string. Got {type(pred)}."
            )
            # A single reference string is allowed; wrap it into a one-element list.
            if isinstance(ref, str):
                ref = [ref]
            assert isinstance(ref, list) and all(isinstance(r, str) for r in ref), (
                "Each reference should be a list of strings. "
                f"Got {type(ref)} with elements of type {[type(r) for r in ref]}."
            )
            cider_scorer += (pred, ref)
        # compute_score() returns the corpus-level mean and the per-caption scores.
        score, _ = cider_scorer.compute_score()
        return {"cider_score": float(score)}
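
# Note: CiderScorer.compute_score() also returns the per-caption scores that are
# discarded as `_` above. If those are ever needed, a hypothetical variant of
# _compute() could expose them alongside the corpus-level mean, e.g.:
#
#     score, per_caption = cider_scorer.compute_score()
#     return {
#         "cider_score": float(score),
#         "per_caption_scores": [float(s) for s in per_caption],
#     }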