File size: 1,727 Bytes
ecd5dc6
 
 
9fa3156
 
df29cfd
ecd5dc6
 
 
 
 
 
 
 
df29cfd
 
 
 
 
 
 
9fa3156
 
ecd5dc6
df29cfd
 
ecd5dc6
 
 
 
 
 
 
 
 
 
df29cfd
9fa3156
ecd5dc6
 
df29cfd
ecd5dc6
 
 
 
df29cfd
ecd5dc6
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import evaluate


# Captions that appear in more than one test case are named once so the
# cases cannot drift apart. Strings are immutable, so sharing them is safe.
_CAKE_PREDICTION = "birthday cake sitting on top of a white plate"
_PLATE_REFERENCE = "a blue plate filled with marshmallows chocolate chips and banana"

# Reference captions for the train prediction.
_TRAIN_REFERENCES = [
    "a train traveling down tracks next to lights",
    "a blue and silver train next to train station and trees",
    "a blue train is next to a sidewalk on the rails",
    "a passenger train pulls into a train station",
    "a train coming down the tracks arriving at a station",
]

# Reference captions for the plane prediction.
_PLANE_REFERENCES = [
    "a large jetliner flying over a traffic filled street",
    "an airplane flies low in the sky over a city street",
    "an airplane flies over a street with many cars",
    "an airplane comes in to land over a road full of cars",
    "the plane is flying over top of the cars",
]

# Each entry holds the keyword arguments later handed to ``metric.compute``.
test_cases = [
    # One prediction with a nested list of five reference captions.
    dict(
        predictions=["train traveling down a track in front of a road"],
        references=[_TRAIN_REFERENCES],
    ),
    # One prediction whose references are a flat list of strings (not nested).
    # NOTE(review): presumably exercises the metric's un-nested reference
    # input form - confirm against cider.py.
    dict(
        predictions=[_CAKE_PREDICTION],
        references=[_PLATE_REFERENCE],
    ),
    # Two predictions, each paired with its own list of reference captions
    # (five for the first, one for the second).
    dict(
        predictions=["plane is flying through the sky", _CAKE_PREDICTION],
        references=[_PLANE_REFERENCES, [_PLATE_REFERENCE]],
    ),
]

# Load the metric from the local script at the relative path ./cider.py.
metric = evaluate.load("./cider.py")

# Score each test case in order and print its inputs followed by the
# metric's output. Cases are numbered from 1 in the printed header.
for case_number, case in enumerate(test_cases, start=1):
    predictions = case["predictions"]
    references = case["references"]
    results = metric.compute(predictions=predictions, references=references)
    print(f"Test case {case_number}:")
    print("Predictions:", predictions)
    print("References:", references)
    print(results)