Commit 64533fa

init include vcr

1 parent a11d13f commit 64533fa
File tree: 6 files changed, +370 −0 lines changed
Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
dataset_path: vcr-org/VCR-wiki-en-hard
dataset_kwargs:
  token: True
task: "vcr_wiki_en_hard"
test_split: test
output_type: generate_until
doc_to_visual: !function utils.vcr_doc_to_visual
doc_to_text: !function utils.vcr_doc_to_text
doc_to_target: "answer"
generation_kwargs:
  max_new_tokens: 120
  temperature: 0
  top_p: 0
  num_beams: 1
  do_sample: false
# The return value of process_results will be used by metrics
process_results: !function utils.vcr_en_process_results
# Note that the metric name can be either a registered metric function (such as the case for GQA) or a key name returned by process_results
metric_list:
  - metric: mme_percetion_score
    aggregation: !function utils.vcr_en_process_results
    higher_is_better: true
  - metric: mme_cognition_score
    aggregation: !function utils.vcr_en_process_results
    higher_is_better: true
model_specific_prompt_kwargs:
  default:
    pre_prompt: ""
    post_prompt: "What is the covered texts in the image? Please restore the covered texts without outputting the explanations."
metadata:
  - version: 0.0.1
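The pre_prompt / post_prompt pair under model_specific_prompt_kwargs is consumed by utils.vcr_doc_to_text (added in this commit, below). A minimal sketch of the prompt assembly it performs, with the kwargs dict written out by hand rather than parsed from the YAML (build_prompt is a hypothetical stand-in mirroring vcr_doc_to_text):

# Minimal sketch of how the YAML prompt kwargs become the text prompt.
# build_prompt is a hand-rolled stand-in for utils.vcr_doc_to_text; the
# kwargs dict below is copied from the config rather than loaded from it.
def build_prompt(model_specific_prompt_kwargs):
    pre_prompt = model_specific_prompt_kwargs.get("pre_prompt", "")
    post_prompt = model_specific_prompt_kwargs.get("post_prompt", "")
    return f"{pre_prompt}{post_prompt}"

kwargs = {
    "pre_prompt": "",
    "post_prompt": (
        "What is the covered texts in the image? Please restore the covered "
        "texts without outputting the explanations."
    ),
}
print(build_prompt(kwargs))  # the question string sent to the model, verbatim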

lmms_eval/tasks/vcr/utils.py

Lines changed: 215 additions & 0 deletions
@@ -0,0 +1,215 @@
from collections import defaultdict
import os
from difflib import SequenceMatcher as SM
import datetime
import json
from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
import evaluate
import logging
import spacy
from spacy.cli import download
from nltk.util import ngrams
from functools import partial

# Download the English and Chinese models
download("en_core_web_sm")
download("zh_core_web_sm")

eval_logger = logging.getLogger("lmms-eval")

dir_name = os.path.dirname(os.path.abspath(__file__))

rouge = evaluate.load("rouge")
nlp_en = spacy.load("en_core_web_sm")
nlp_zh = spacy.load("zh_core_web_sm")
nlp = {"en": nlp_en, "zh": nlp_zh}

aggregate_results_template = {
    "max_sim_val": 0,
    "precision": 0,
    "recall": 0,
    "f1": 0,
    "jaccard": 0,
    "rouge1": 0,
}


def vcr_doc_to_visual(doc):
    return [doc["stacked_image"].convert("RGB"), doc["only_it_image"].convert("RGB")]


def vcr_doc_to_text(doc, model_specific_prompt_kwargs=None):
    # Default to empty strings so a missing key does not raise a NameError.
    pre_prompt, post_prompt = "", ""
    if "pre_prompt" in model_specific_prompt_kwargs:
        pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
    if "post_prompt" in model_specific_prompt_kwargs:
        post_prompt = model_specific_prompt_kwargs["post_prompt"]
    return f"{pre_prompt}{post_prompt}"


def tokenize(text, language):
    """
    Tokenize the text and return the tokens.

    Parameters:
    text (str): The text to tokenize.
    language (str): The language of the text.

    Returns:
    list: The list of tokens.
    """
    assert language in ["en", "zh"]
    nlp_lang = nlp[language]
    processed_text = nlp_lang(text)
    return [token.text for token in processed_text]


def vcr_process_results_single(doc, result, language):
    """
    Args:
        doc: an instance of the eval dataset
        results: [pred]
    Returns:
        a dictionary with key: metric name, value: metric value
    """
    assert language in ["en", "zh"], f"Language {language} is not supported."
    crossed_text = doc["crossed_text"]
    tokens_result = tokenize(result, language)
    tokens_crossed_text = tokenize(crossed_text, language)

    splitter = " " if language == "en" else ""
    ngrams_ = ngrams(tokens_result, len(tokens_crossed_text))
    max_sim_val = 0
    max_sim_string = ""
    max_sim_ngram = []
    tokens_crossed_text_set = set(tokens_crossed_text)
    ngrams_hasjoint = [
        ngram for ngram in ngrams_ if not set(ngram).isdisjoint(tokens_crossed_text_set)
    ]

    for ngram in ngrams_hasjoint:
        result_ngram = splitter.join(ngram)
        similarity = SM(None, result_ngram, crossed_text).ratio()
        if similarity > max_sim_val:
            max_sim_val = similarity
            max_sim_string = result_ngram
            max_sim_ngram = ngram

    # Evaluate
    if len(max_sim_ngram) == 0:
        return {
            "crossed_text": crossed_text,
            "max_sim_val": 0,
            "max_sim_string": "",
            "precision": 0,
            "recall": 0,
            "f1": 0,
            "jaccard": 0,
            "rouge1": 0,
            "exact_match": 0,
        }
    pred_set = set(max_sim_ngram)
    ref_set = set(tokens_crossed_text)
    correct_tokens = pred_set.intersection(ref_set)
    len_correct_tokens = len(correct_tokens)

    precision = len_correct_tokens / len(pred_set)
    recall = len_correct_tokens / len(ref_set)
    if (precision + recall) == 0:
        f1 = 0
    else:
        f1 = 2 * precision * recall / (precision + recall)
    union = pred_set.union(ref_set)
    jaccard = len_correct_tokens / len(union) if len(union) > 0 else 0
    rouge_1 = rouge.compute(
        predictions=[max_sim_string],
        references=[crossed_text],
        tokenizer=partial(tokenize, language=language),
        rouge_types=["rouge1"],
    )["rouge1"]
    exact_match = float(list(max_sim_ngram) == list(tokens_crossed_text))
    out = {
        "crossed_text": crossed_text,
        "max_sim_string": max_sim_string,
        "max_sim_val": max_sim_val,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "jaccard": jaccard,
        "rouge1": rouge_1,
        "exact_match": exact_match,
    }
    return out
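vcr_process_results_single above slides a window of len(reference tokens) over the prediction's tokens, keeps the window whose joined text has the highest SequenceMatcher ratio against the crossed-out reference, and then scores that window with token-set precision, recall, F1, and Jaccard (plus ROUGE-1 and exact match). A minimal sketch of that matching step on toy strings, assuming plain whitespace tokenization instead of spaCy and omitting the ROUGE-1 call and the disjoint-n-gram filter:

# Toy data; whitespace tokenization stands in for spaCy here.
from difflib import SequenceMatcher as SM
from nltk.util import ngrams

crossed_text = "quick brown fox"
prediction = "the quick brown dog jumps"

ref_tokens = crossed_text.split()
pred_tokens = prediction.split()

best_val, best_ngram = 0.0, ()
for gram in ngrams(pred_tokens, len(ref_tokens)):        # 3-token windows
    sim = SM(None, " ".join(gram), crossed_text).ratio()
    if sim > best_val:
        best_val, best_ngram = sim, gram                  # -> ("quick", "brown", "dog")

pred_set, ref_set = set(best_ngram), set(ref_tokens)
correct = pred_set & ref_set                              # {"quick", "brown"}
precision = len(correct) / len(pred_set)                  # 2/3
recall = len(correct) / len(ref_set)                      # 2/3
f1 = 2 * precision * recall / (precision + recall)        # 2/3
jaccard = len(correct) / len(pred_set | ref_set)          # 2/4 = 0.5
print(best_ngram, round(best_val, 3), precision, recall, f1, jaccard)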
def vcr_en_process_results(doc, results):
    """
    Args:
        doc: an instance of the eval dataset
        results: [pred]
    Returns:
        a dictionary with key: metric name, value: metric value
    """
    assert len(results) == 2, f"Expected 2 results, got {len(results)}"
    output = {
        "res_stacked_image": vcr_process_results_single(doc, results[0], "en"),
        "res_only_it_image": vcr_process_results_single(doc, results[1], "en"),
    }
    return output


def vcr_zh_process_results(doc, results):
    """
    Args:
        doc: an instance of the eval dataset
        results: [pred]
    Returns:
        a dictionary with key: metric name, value: metric value
    """
    assert len(results) == 2, f"Expected 2 results, got {len(results)}"
    output = {
        "res_stacked_image": vcr_process_results_single(doc, results[0], "zh"),
        "res_only_it_image": vcr_process_results_single(doc, results[1], "zh"),
    }
    return output


def vcr_aggregate_results(results):
    """
    Args:
        results: a list of values returned by process_results
    Returns:
        A dictionary of dictionaries of floats, where the outer dictionary has keys "res_stacked_image" and "res_only_it_image"
    """

    output = {
        "res_stacked_image": {
            "max_sim_val": 0,
            "precision": 0,
            "recall": 0,
            "f1": 0,
            "jaccard": 0,
            "rouge1": 0,
        },
        "res_only_it_image": {
            "max_sim_val": 0,
            "precision": 0,
            "recall": 0,
            "f1": 0,
            "jaccard": 0,
            "rouge1": 0,
        },
    }
    for target_domain in output.keys():
        for target_metric_name in output[target_domain].keys():
            score = 0
            count = 0
            for inner_dict in results:
                for inner_key, inner_value in inner_dict.items():
                    if inner_key == target_domain:
                        for blank_id, blank_metrics in inner_value.items():
                            for metric_name, metric_value in blank_metrics.items():
                                if metric_name == target_metric_name:
                                    score += metric_value
                                    count += 1
            output[target_domain][target_metric_name] = score / count
    return output
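Note that the inner loops of vcr_aggregate_results unpack each domain's value as blank_id -> metric dict, one level deeper than the flat metric dict that vcr_en_process_results / vcr_zh_process_results return above. A toy illustration of the nesting the aggregation loop walks, with made-up values:

# Made-up per-document results in the shape the aggregation loop expects:
# results -> per-doc dict -> domain -> blank_id -> metric dict.
results = [
    {
        "res_stacked_image": {0: {"max_sim_val": 0.8, "f1": 0.5}},
        "res_only_it_image": {0: {"max_sim_val": 0.6, "f1": 0.4}},
    },
    {
        "res_stacked_image": {0: {"max_sim_val": 1.0, "f1": 1.0}},
        "res_only_it_image": {0: {"max_sim_val": 0.2, "f1": 0.0}},
    },
]

# Averaging "f1" over the stacked-image domain, mirroring the nested loops:
vals = [
    metric_value
    for doc_result in results
    for blank_metrics in doc_result["res_stacked_image"].values()
    for metric_name, metric_value in blank_metrics.items()
    if metric_name == "f1"
]
print(sum(vals) / len(vals))  # 0.75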
Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
dataset_path: vcr-org/VCR-wiki-en-easy
dataset_kwargs:
  token: True
task: "vcr_wiki_en_easy"
test_split: test
output_type: generate_until
doc_to_visual: !function utils.vcr_doc_to_visual
doc_to_text: !function utils.vcr_doc_to_text
doc_to_target: "answer"
generation_kwargs:
  max_new_tokens: 120
  temperature: 0
  top_p: 0
  num_beams: 1
  do_sample: false
# The return value of process_results will be used by metrics
process_results: !function utils.vcr_en_process_results
# Note that the metric name can be either a registered metric function (such as the case for GQA) or a key name returned by process_results
metric_list:
  - metric: mme_percetion_score
    aggregation: !function utils.vcr_en_process_results
    higher_is_better: true
  - metric: mme_cognition_score
    aggregation: !function utils.vcr_en_process_results
    higher_is_better: true
model_specific_prompt_kwargs:
  default:
    pre_prompt: ""
    post_prompt: "What is the covered texts in the image? Please restore the covered texts without outputting the explanations."
metadata:
  - version: 0.0.1
Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
dataset_path: vcr-org/VCR-wiki-en-hard
dataset_kwargs:
  token: True
task: "vcr_wiki_en_hard"
test_split: test
output_type: generate_until
doc_to_visual: !function utils.vcr_doc_to_visual
doc_to_text: !function utils.vcr_doc_to_text
doc_to_target: "answer"
generation_kwargs:
  max_new_tokens: 120
  temperature: 0
  top_p: 0
  num_beams: 1
  do_sample: false
# The return value of process_results will be used by metrics
process_results: !function utils.vcr_en_process_results
# Note that the metric name can be either a registered metric function (such as the case for GQA) or a key name returned by process_results
metric_list:
  - metric: mme_percetion_score
    aggregation: !function utils.vcr_en_process_results
    higher_is_better: true
  - metric: mme_cognition_score
    aggregation: !function utils.vcr_en_process_results
    higher_is_better: true
model_specific_prompt_kwargs:
  default:
    pre_prompt: ""
    post_prompt: "What is the covered texts in the image? Please restore the covered texts without outputting the explanations."
metadata:
  - version: 0.0.1
Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
dataset_path: vcr-org/VCR-wiki-zh-easy
dataset_kwargs:
  token: True
task: "vcr_wiki_zh_easy"
test_split: test
output_type: generate_until
doc_to_visual: !function utils.vcr_doc_to_visual
doc_to_text: !function utils.vcr_doc_to_text
doc_to_target: "answer"
generation_kwargs:
  max_new_tokens: 120
  temperature: 0
  top_p: 0
  num_beams: 1
  do_sample: false
# The return value of process_results will be used by metrics
process_results: !function utils.vcr_zh_process_results
# Note that the metric name can be either a registered metric function (such as the case for GQA) or a key name returned by process_results
metric_list:
  - metric: mme_percetion_score
    aggregation: !function utils.vcr_zh_process_results
    higher_is_better: true
  - metric: mme_cognition_score
    aggregation: !function utils.vcr_zh_process_results
    higher_is_better: true
model_specific_prompt_kwargs:
  default:
    pre_prompt: ""
    # English gloss: "What is the covered text in the image? Please restore the covered text without outputting explanations."
    post_prompt: "图像中被覆盖的文本是什么?请在不输出解释的情况下还原被覆盖的文本。"
metadata:
  - version: 0.0.1
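The zh configs route through utils.vcr_zh_process_results, which tokenizes with spaCy's zh_core_web_sm and re-joins candidate n-grams with an empty string rather than a space. A tiny sketch of that splitter difference, using made-up token lists:

# Made-up token lists illustrating the language-dependent splitter in
# vcr_process_results_single: Chinese n-grams are re-joined with no
# separator, English n-grams with a single space.
zh_tokens = ["深度", "学习", "模型"]
en_tokens = ["deep", "learning", "model"]
print("".join(zh_tokens))   # 深度学习模型
print(" ".join(en_tokens))  # deep learning model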
Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
dataset_path: vcr-org/VCR-wiki-zh-hard
dataset_kwargs:
  token: True
task: "vcr_wiki_zh_hard"
test_split: test
output_type: generate_until
doc_to_visual: !function utils.vcr_doc_to_visual
doc_to_text: !function utils.vcr_doc_to_text
doc_to_target: "answer"
generation_kwargs:
  max_new_tokens: 120
  temperature: 0
  top_p: 0
  num_beams: 1
  do_sample: false
# The return value of process_results will be used by metrics
process_results: !function utils.vcr_zh_process_results
# Note that the metric name can be either a registered metric function (such as the case for GQA) or a key name returned by process_results
metric_list:
  - metric: mme_percetion_score
    aggregation: !function utils.vcr_zh_process_results
    higher_is_better: true
  - metric: mme_cognition_score
    aggregation: !function utils.vcr_zh_process_results
    higher_is_better: true
model_specific_prompt_kwargs:
  default:
    pre_prompt: ""
    # English gloss: "What is the covered text in the image? Please restore the covered text without outputting explanations."
    post_prompt: "图像中被覆盖的文本是什么?请在不输出解释的情况下还原被覆盖的文本。"
metadata:
  - version: 0.0.1
