Commit 7cc2890

Merge pull request #113 from teowu/main
Q-Bench, Q-Bench2, A-Bench
2 parents 4bc7224 + ea14cd4 commit 7cc2890

File tree

6 files changed: +329 −5 lines changed


lmms_eval/models/phi3v.py

Lines changed: 11 additions & 5 deletions
@@ -185,9 +185,16 @@ def _collate(x):
             contexts = list(contexts)
             for i in range(len(contexts)):
                 if "<image>" in contexts[i]:
-                    query = contexts[i].replace("<image>", "<|image_1|>")
+                    query = "" + contexts[i]
+                    img_placeholder_count = 1
+                    while "<image>" in query:
+                        query = query.replace("<image>", f"<|image_{img_placeholder_count}|>", 1)
+                        img_placeholder_count += 1
                 else:
-                    query = f"<|image_1|>\n{contexts[i]}"
+                    query = ""
+                    for placeholder_id in range(len(visuals)):
+                        query += f"<|image_{placeholder_id+1}|>\n"
+                    query += contexts[i]
             messages = [
                 {"role": "user", "content": query}
             ]
@@ -196,12 +203,11 @@ def _collate(x):
                 tokenize=False,
                 add_generation_prompt=True)
             assert len(contexts) == 1
-            # We always pass a single image given that the model only accepts one image (as of 5/21/24).
+            #
             context = contexts[0]
-            pil_image = visuals[0]
             input_ids = self._processor(
                 text=context,
-                images=[pil_image],
+                images=visuals,
                 return_tensors="pt").to(self._device, self.model.dtype)
             # Setting default parameters.
             if "max_new_tokens" not in gen_kwargs:
Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
dataset_path: q-future/A-Bench-HF
task: "abench_dev"
test_split: dev
output_type: generate_until
doc_to_visual: !function utils.q_bench_doc_to_visual
doc_to_text: !function utils.q_bench_doc_to_text
doc_to_target: "correct_choice"
generation_kwargs:
  max_new_tokens: 32
  temperature: 0
  do_sample: False
process_results: !function utils.a_bench_process_results
metric_list:
  - metric: abench_acc
    aggregation: !function utils.a_bench_aggregate_results
    higher_is_better: true

model_specific_prompt_kwargs:
  default:
    pre_prompt: ""
    post_prompt: "Answer with the option's letter from the given choices directly.\n"
Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
dataset_path: q-future/Q-Bench2-HF
task: "qbench2_dev"
test_split: dev
output_type: generate_until
doc_to_visual: !function utils.q_bench_doc_to_visual
doc_to_text: !function utils.q_bench_doc_to_text
doc_to_target: "correct_choice"
generation_kwargs:
  max_new_tokens: 32
  temperature: 0
  do_sample: False
process_results: !function utils.q_bench_process_results
metric_list:
  - metric: qbench_acc
    aggregation: !function utils.q_bench_aggregate_results
    higher_is_better: true

model_specific_prompt_kwargs:
  default:
    pre_prompt: ""
    post_prompt: "Answer with the option's letter from the given choices directly.\n"
Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
dataset_path: q-future/Q-Bench-HF
task: "qbench_dev"
test_split: dev
output_type: generate_until
doc_to_visual: !function utils.q_bench_doc_to_visual
doc_to_text: !function utils.q_bench_doc_to_text
doc_to_target: "correct_choice"
generation_kwargs:
  max_new_tokens: 32
  temperature: 0
  do_sample: False
process_results: !function utils.q_bench_process_results
metric_list:
  - metric: qbench_acc
    aggregation: !function utils.q_bench_aggregate_results
    higher_is_better: true

model_specific_prompt_kwargs:
  default:
    pre_prompt: ""
    post_prompt: "Answer with the option's letter from the given choices directly.\n"
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
group: qbenchs_dev
task:
  - qbench_dev
  - qbench2_dev
  - abench_dev

lmms_eval/tasks/qbench/utils.py

Lines changed: 247 additions & 0 deletions
@@ -0,0 +1,247 @@
import json
import logging
import random  # used for the random fallback in parse_multi_choice_response
import re
from collections import Counter, defaultdict

import numpy as np  # used by np.argmax in parse_multi_choice_response

from lmms_eval.tasks._task_utils.file_utils import generate_submission_file


def q_bench_doc_to_text(doc, model_specific_prompt_kwargs):
    candidates = []
    for i in range(4):
        candidate = doc.get(f"option{i}")
        if candidate != "N/A":
            candidates.append(candidate)

    question = doc["question"] + "\n" + "\n".join([". ".join([chr(ord("A") + i), candidate]) for i, candidate in enumerate(candidates)])
    pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
    post_prompt = model_specific_prompt_kwargs["post_prompt"]
    return f"{pre_prompt}{question}\n{post_prompt}"


def q_bench_doc_to_visual(doc):
    if "image2" not in doc:
        return [doc["image"].convert("RGB")]
    else:
        return [doc["image1"].convert("RGB"), doc["image2"].convert("RGB")]


def get_multi_choice_info(options):
    """
    Given the list of options for a multiple-choice question,
    return index2ans and all_choices.
    https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/data_utils.py#L54
    """

    start_chr = "A"
    all_choices = []
    index2ans = {}
    for i, option in enumerate(options):
        index2ans[chr(ord(start_chr) + i)] = option
        all_choices.append(chr(ord(start_chr) + i))

    return index2ans, all_choices


def parse_multi_choice_response(response, all_choices, index2ans):
    """
    Parse the prediction from the generated response.
    Return the predicted index, e.g., A, B, C, D.
    https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L10
    """
    for char in [",", ".", "!", "?", ";", ":", "'"]:
        response = response.strip(char)
    response = " " + response + " "  # add spaces to avoid partial matches

    index_ans = True
    ans_with_brack = False
    candidates = []
    for choice in all_choices:  # e.g., (A) (B) (C) (D)
        if f"({choice})" in response:
            candidates.append(choice)
            ans_with_brack = True

    if len(candidates) == 0:
        for choice in all_choices:  # e.g., A B C D
            if f"{choice} " in response:
                candidates.append(choice)

    if len(candidates) == 0:
        for choice in all_choices:  # e.g., A. B. C. D.
            if f"{choice}." in response:
                candidates.append(choice)

    # if none of the above yields candidates, check whether the response is longer than
    # 5 tokens and try to match the option content instead
    if len(candidates) == 0 and len(response.split()) > 5:
        for index, ans in index2ans.items():
            if ans.lower() in response.lower():
                candidates.append(index)
                index_ans = False  # it is a content answer

    if len(candidates) == 0:  # still no answer, randomly choose one
        pred_index = random.choice(all_choices)
    elif len(candidates) > 1:
        start_indexes = []
        if index_ans:
            if ans_with_brack:
                for can in candidates:
                    index = response.rfind(f"({can})")
                    start_indexes.append(index)  # -1 will be ignored anyway
                # start_indexes = [generated_response.index(f'({can})') for can in candidates]
            else:
                for can in candidates:
                    index = response.rfind(f" {can} ")
                    start_indexes.append(index)
        else:
            for can in candidates:
                index = response.lower().rfind(index2ans[can].lower())
                start_indexes.append(index)
        # keep the candidate mentioned last
        pred_index = candidates[np.argmax(start_indexes)]
    else:  # if only one candidate, use it
        pred_index = candidates[0]

    return pred_index


def evaluate_q_bench(samples):
    pred_correct = 0
    judge_dict = dict()
    for sample in samples:
        gold_i = sample["answer"]
        pred_i = sample["parsed_pred"]
        correct = eval_multi_choice(gold_i, pred_i)

        if correct:
            judge_dict[sample["id"]] = "Correct"
            pred_correct += 1
        else:
            judge_dict[sample["id"]] = "Wrong"

    if len(samples) == 0:
        return judge_dict, {"acc": 0}
    return judge_dict, {"acc": pred_correct / len(samples)}


def eval_multi_choice(gold_i, pred_i):
    correct = False
    # only an exact match counts as correct
    if isinstance(gold_i, list):
        for answer in gold_i:
            if answer == pred_i:
                correct = True
                break
    else:  # gold_i is a string
        if gold_i == pred_i:
            correct = True
    return correct


def calculate_ins_level_acc(results):
    """Calculate the instruction-level accuracy for the given subset results.
    https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L246
    """
    acc = 0
    ins_num = 0
    for cat_results in results.values():
        acc += cat_results["acc"] * cat_results["num_example"]
        ins_num += cat_results["num_example"]
    if ins_num == 0:
        return 0
    return acc / ins_num


def q_bench_process_results(doc, results):
    pred = results[0]
    all_choices = []
    index2ans = {}
    for i in range(4):
        option = doc.get(f"option{i}")
        if option == "N/A":
            break
        index2ans[chr(ord("A") + i)] = option
        all_choices.append(chr(ord("A") + i))

    parsed_pred = parse_multi_choice_response(pred, all_choices, index2ans)
    id = doc["id"]
    qbench_acc = {"id": id, "question_concern": doc["question_concern"], "question_type": doc["question_type"], "answer": doc["correct_choice"], "parsed_pred": parsed_pred}
    return {
        "qbench_acc": qbench_acc,
        "submission": {
            id: pred,
        },
    }


concern_list = ["Global Distortion", "Global Others", "Local Distortion", "Local Others"]
question_list = ["Yes/No", "How", "What"]


def q_bench_aggregate_results(results):
    evaluation_result = {}
    subset_to_eval_samples = defaultdict(list)
    for result in results:
        subset_to_eval_samples[concern_list[result["question_concern"]]].append(result)
        subset_to_eval_samples[question_list[result["question_type"]]].append(result)
    for subset, sub_eval_samples in subset_to_eval_samples.items():
        judge_dict, metric_dict = evaluate_q_bench(sub_eval_samples)
        metric_dict.update({"num_example": len(sub_eval_samples)})
        evaluation_result[subset] = metric_dict
    printable_results = {}

    for cat_name, cat_results in evaluation_result.items():
        printable_results[cat_name] = {
            "num": int(cat_results["num_example"]),
            "acc": round(cat_results["acc"], 5),
        }
    all_ins_acc = calculate_ins_level_acc(evaluation_result)
    printable_results["Overall"] = {
        "num": sum([cat_results["num_example"] for cat_results in evaluation_result.values()]),
        "acc": round(all_ins_acc, 5),
    }
    print(printable_results)
    return printable_results["Overall"]["acc"]


def a_bench_process_results(doc, results):
    pred = results[0]
    all_choices = []
    index2ans = {}
    for i in range(4):
        option = doc.get(f"option{i}")
        if option == "N/A":
            break
        index2ans[chr(ord("A") + i)] = option
        all_choices.append(chr(ord("A") + i))

    parsed_pred = parse_multi_choice_response(pred, all_choices, index2ans)
    id = doc["id"]
    abench_acc = {"id": id, "category": doc["category"], "answer": doc["correct_choice"], "parsed_pred": parsed_pred}
    return {
        "abench_acc": abench_acc,
        "submission": {
            id: pred,
        },
    }


def a_bench_aggregate_results(results):
    evaluation_result = {}
    subset_to_eval_samples = defaultdict(list)
    for result in results:
        subset_to_eval_samples[result["category"]].append(result)
    for subset, sub_eval_samples in subset_to_eval_samples.items():
        judge_dict, metric_dict = evaluate_q_bench(sub_eval_samples)
        metric_dict.update({"num_example": len(sub_eval_samples)})
        evaluation_result[subset] = metric_dict
    printable_results = {}

    for cat_name, cat_results in evaluation_result.items():
        printable_results[cat_name] = {
            "num": int(cat_results["num_example"]),
            "acc": round(cat_results["acc"], 5),
        }
    all_ins_acc = calculate_ins_level_acc(evaluation_result)
    printable_results["Overall"] = {
        "num": sum([cat_results["num_example"] for cat_results in evaluation_result.values()]),
        "acc": round(all_ins_acc, 5),
    }
    print(printable_results)
    return printable_results["Overall"]["acc"]
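
To make the helpers above concrete, here is a short usage sketch; the doc dict and the model response string are hypothetical, only the functions come from this file.

# Hypothetical document, prompt kwargs, and model response, for illustration only.
doc = {
    "question": "How is the clarity of this image?",
    "option0": "Good",
    "option1": "Fair",
    "option2": "Poor",
    "option3": "N/A",
}
prompt_kwargs = {"pre_prompt": "", "post_prompt": "Answer with the option's letter from the given choices directly.\n"}

print(q_bench_doc_to_text(doc, prompt_kwargs))
# How is the clarity of this image?
# A. Good
# B. Fair
# C. Poor
# Answer with the option's letter from the given choices directly.

index2ans, all_choices = get_multi_choice_info(["Good", "Fair", "Poor"])
print(parse_multi_choice_response("The answer is (B).", all_choices, index2ans))
# B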
