Skip to content

Commit cee7671

Browse files

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+2214
-129
lines changed

README.md

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -201,14 +201,21 @@ We also provide the raw data exported from Weights & Biases for the detailed res
201201
- OKVQA Validation 2014 (ok_vqa_val2014)
202202
- POPE (pope)
203203
- RefCOCO (refcoco)
204-
- refcoco_seg_test
205-
- refcoco_seg_val
206-
- refcoco_seg_testA
207-
- refcoco_seg_testB
208-
- refcoco_bbox_test
209-
- refcoco_bbox_val
210-
- refcoco_bbox_testA
211-
- refcoco_bbox_testB
204+
- refcoco_seg
205+
- refcoco_seg_test
206+
- refcoco_seg_val
207+
- refcoco_seg_testA
208+
- refcoco_seg_testB
209+
- refcoco_bbox
210+
- refcoco_bbox_test
211+
- refcoco_bbox_val
212+
- refcoco_bbox_testA
213+
- refcoco_bbox_testB
214+
- refcoco_bbox_rec
215+
- refcoco_bbox_rec_test
216+
- refcoco_bbox_rec_val
217+
- refcoco_bbox_rec_testA
218+
- refcoco_bbox_rec_testB
212219
- RefCOCO+ (refcoco+)
213220
- refcoco+_seg
214221
- refcoco+_seg_val
@@ -218,11 +225,20 @@ We also provide the raw data exported from Weights & Biases for the detailed res
218225
- refcoco+_bbox_val
219226
- refcoco+_bbox_testA
220227
- refcoco+_bbox_testB
228+
- refcoco+_bbox_rec
229+
- refcoco+_bbox_rec_val
230+
- refcoco+_bbox_rec_testA
231+
- refcoco+_bbox_rec_testB
221232
- RefCOCOg (refcocog)
222-
- refcocog_seg_test
223-
- refcocog_seg_val
224-
- refcocog_bbox_test
225-
- refcocog_bbox_val
233+
- refcocog_seg
234+
- refcocog_seg_test
235+
- refcocog_seg_val
236+
- refcocog_bbox
237+
- refcocog_bbox_test
238+
- refcocog_bbox_val
239+
- refcocog_bbox_rec
240+
- refcocog_bbox_rec_test
241+
- refcocog_bbox_rec_val
226242
- ScienceQA (scienceqa_full)
227243
- ScienceQA Full (scienceqa)
228244
- ScienceQA IMG (scienceqa_img)

lmms_eval/api/task.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,22 @@ def _prepare_metric_and_aggregation(self):
678678

679679
@retry(stop=stop_after_attempt(5), wait=wait_fixed(2))
680680
def download(self, dataset_kwargs=None) -> None:
681+
# If the dataset is a video dataset,
682+
# Recursively search whether there is a zip and unzip it to the huggingface home
683+
if dataset_kwargs is not None and "video" in dataset_kwargs and dataset_kwargs["video"]:
684+
hf_home = os.environ["HF_HOME"]
685+
cache_dir = dataset_kwargs["cache_dir"]
686+
dataset_kwargs.pop("cache_dir")
687+
cache_dir = os.path.join(hf_home, cache_dir)
688+
cache_path = snapshot_download(repo_id=self.DATASET_PATH, repo_type="dataset")
689+
zip_files = glob(os.path.join(cache_path, "**/*.zip"), recursive=True)
690+
if not os.path.exists(cache_dir):
691+
for zip_file in zip_files:
692+
shutil.unpack_archive(zip_file, cache_dir)
693+
builder_script = dataset_kwargs["builder_script"]
694+
self.DATASET_PATH = os.path.join(cache_path, builder_script)
695+
dataset_kwargs.pop("video")
696+
dataset_kwargs.pop("builder_script")
681697
download_config = DownloadConfig()
682698
download_config.max_retries = dataset_kwargs.get("max_retries", 3) if dataset_kwargs is not None else 3
683699
download_config.num_proc = dataset_kwargs.get("num_proc", 8) if dataset_kwargs is not None else 8
@@ -687,12 +703,15 @@ def download(self, dataset_kwargs=None) -> None:
687703
download_mode=datasets.DownloadMode.REUSE_DATASET_IF_EXISTS,
688704
**dataset_kwargs if dataset_kwargs is not None else {},
689705
)
690-
self.dataset_no_image = datasets.load_dataset(
691-
path=self.DATASET_PATH,
692-
name=self.DATASET_NAME,
693-
download_mode=datasets.DownloadMode.REUSE_DATASET_IF_EXISTS,
694-
**dataset_kwargs if dataset_kwargs is not None else {},
695-
)
706+
if self.config.process_docs is not None:
707+
for split in self.dataset:
708+
if split in [
709+
self.config.training_split, self.config.validation_split, self.config.test_split, self.config.fewshot_split
710+
]:
711+
self.dataset[split] = self.config.process_docs(self.dataset[split])
712+
713+
# copy dataset, remove image features
714+
self.dataset_no_image = self.dataset.copy()
696715
for doc_name in self.dataset_no_image:
697716
remove_cols = []
698717
features = self.dataset_no_image[doc_name].features
@@ -725,20 +744,14 @@ def has_test_docs(self) -> bool:
725744

726745
def training_docs(self) -> datasets.Dataset:
727746
if self.has_training_docs():
728-
if self.config.process_docs is not None:
729-
return self.config.process_docs(self.dataset[self.config.training_split])
730747
return self.dataset[self.config.training_split]
731748

732749
def validation_docs(self) -> datasets.Dataset:
733750
if self.has_validation_docs():
734-
if self.config.process_docs is not None:
735-
return self.config.process_docs(self.dataset[self.config.validation_split])
736751
return self.dataset[self.config.validation_split]
737752

738753
def test_docs(self) -> datasets.Dataset:
739754
if self.has_test_docs():
740-
if self.config.process_docs is not None:
741-
return self.config.process_docs(self.dataset[self.config.test_split])
742755
return self.dataset[self.config.test_split]
743756

744757
def fewshot_docs(self):
@@ -973,6 +986,8 @@ def construct_requests(self, doc_id: int, ctx: str, **kwargs) -> Union[List[Inst
973986
return Instance(request_type=self.OUTPUT_TYPE, arguments=arguments, idx=0, **kwargs)
974987

975988
def process_results(self, doc, results):
989+
if self.OUTPUT_TYPE == "generate_until":
990+
results[0] = results[0].strip()
976991
if callable(self.config.process_results):
977992
return self.config.process_results(doc, results)
978993

lmms_eval/filters/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from lmms_eval.api.filter import FilterEnsemble
1+
from lmms_eval.api.filter import FilterEnsemble, Filter
22
from . import selection
33
from . import extraction
44
from . import transformation
@@ -13,6 +13,7 @@
1313
"lowercase": transformation.LowercaseFilter,
1414
"uppercase": transformation.UppercaseFilter,
1515
"map": transformation.MapFilter,
16+
"multi_choice_regex": extraction.MultiChoiceRegexFilter,
1617
# TODO: implement this filter. either it should take in an arbitrary "scoring"/reward function
1718
# that takes an input and returns a scalar and then should select the max reward,
1819
# or should implement different filters for different ways of handling a reward model's inference.

lmms_eval/filters/extraction.py

Lines changed: 170 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,47 @@
11
import re
2-
2+
import sys
3+
import unicodedata
34
from lmms_eval.api.filter import Filter
45

56

7+
class WhitespaceFilter(Filter):
    """Filter that drops one leading space character from each response."""

    def __init__(self) -> None:
        # Stateless filter: there is nothing to configure.
        pass

    def apply(self, resps, docs):
        """Return `resps` with a single leading " " removed from each string.

        `resps` is iterated as a collection of response sets, each of which is
        itself an iterable of strings. Only one literal space is removed per
        response; other leading whitespace is left untouched. `docs` is
        accepted but not used.
        """
        return [
            [text[1:] if text.startswith(" ") else text for text in response_set]
            for response_set in resps
        ]
27+
28+
629
class RegexFilter(Filter):
730
""" """
831

9-
def __init__(self, regex_pattern: str = r"#### (\-?[0-9\.\,]+)", fallback: str = "[invalid]") -> None:
32+
def __init__(
33+
self,
34+
regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
35+
group_select=0,
36+
fallback: str = "[invalid]",
37+
) -> None:
1038
"""
1139
pass a string `regex` to run `re.compile(r"regex")` on.
1240
`fallback` defines the output returned if no matches for the regex are located.
1341
"""
1442
self.regex_pattern = regex_pattern
1543
self.regex = re.compile(regex_pattern)
44+
self.group_select = group_select
1645
self.fallback = fallback
1746

1847
def apply(self, resps, docs):
@@ -23,9 +52,12 @@ def apply(self, resps, docs):
2352
def filter_set(inst):
2453
filtered = []
2554
for resp in inst:
26-
match = self.regex.search(resp)
55+
match = self.regex.findall(resp)
2756
if match:
28-
match = match.group(1).strip()
57+
match = match[self.group_select]
58+
if isinstance(match, tuple):
59+
match = [m for m in match if m][0]
60+
match = match.strip()
2961
else:
3062
match = self.fallback
3163
filtered.append(match)
@@ -38,23 +70,145 @@ def filter_set(inst):
3870
return filtered_resps
3971

4072

41-
class WhitespaceFilter(Filter):
42-
""" """
73+
class MultiChoiceRegexFilter(RegexFilter):
74+
"""
75+
A filter used to extract a model's answer on multiple choice questions with
76+
letter answers. assumes each document has a "choices" field
77+
containing the list of answer choices and that the answer label symbols
78+
are of the form (A), (B), (C), ... or A, B, C.
79+
"""
4380

44-
def __init__(self) -> None:
45-
pass
81+
def __init__(
82+
self,
83+
regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
84+
group_select=0,
85+
fallback: str = "[invalid]",
86+
ignore_case=False,
87+
ignore_punctuation=False,
88+
regexes_to_ignore=None,
89+
) -> None:
90+
"""
91+
regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure
92+
- step 1 : We parse the choices between ([A-Z])s then try to find these choices in the response.
93+
- step 2 : We parse the choice with regex :[\s]*([A-?]), where ? varies by number of choices.
94+
group_select: Selects the (group_select)th match from the findall result.
95+
ignore_case: Ignores the case during step 1 matching
96+
ignore_punctuation: Remove the punctuation during step 1 matching
97+
regexes_to_ignore: Remove these regexes during step 1 matching
98+
"""
99+
super().__init__(regex_pattern, group_select, fallback)
100+
self.ignore_case = ignore_case
101+
self.ignore_punctuation = ignore_punctuation
102+
self.regexes_to_ignore = regexes_to_ignore
46103

47104
def apply(self, resps, docs):
48-
def filter_set(inst):
49-
filtered_resp = []
50-
for resp in inst:
51-
if resp.startswith(" "):
52-
resp = resp[1:]
105+
# here, we assume we have a list, in which each element is
106+
# a list of model responses for some particular input/target pair.
107+
# so we process each of these (same input/target response sets)
108+
# independently (and keep them a list.)
53109

54-
filtered_resp.append(resp)
110+
def find_match(regex, resp, convert_dict={}):
111+
match = regex.findall(resp)
112+
if match:
113+
match = match[self.group_select]
114+
if isinstance(match, tuple):
115+
match = [m for m in match if m][0]
116+
match = match.strip()
117+
if match and match in convert_dict:
118+
match = convert_dict[match]
119+
return match
55120

56-
return filtered_resp
121+
punct_tbl = dict.fromkeys(i for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith("P"))
57122

58-
filtered_resps = [filter_set(resp) for resp in resps]
123+
def filter_ignores(st):
124+
if self.regexes_to_ignore is not None:
125+
for s in self.regexes_to_ignore:
126+
st = re.sub(s, "", st)
127+
128+
if self.ignore_case:
129+
st = st.lower()
130+
131+
if self.ignore_punctuation:
132+
# https://stackoverflow.com/a/266162
133+
st = st.translate(punct_tbl)
134+
return st
135+
136+
filtered_resps = []
137+
138+
for r, doc in zip(resps, docs):
139+
fallback_regexes = []
140+
choice_to_alpha = {}
141+
next_alpha = "A"
142+
143+
without_paren_fallback_regexes = []
144+
without_paren_to_target = {}
145+
146+
choices = doc["choices"]
147+
for c in choices:
148+
m = filter_ignores(c.strip())
149+
fallback_regexes.append(f"{re.escape(m)}")
150+
choice_to_alpha[m] = f"({next_alpha})"
151+
152+
without_paren_fallback_regexes.append(next_alpha)
153+
without_paren_to_target[next_alpha] = f"({next_alpha})"
154+
155+
next_alpha = chr(ord(next_alpha) + 1)
156+
fallback_regex = re.compile("|".join(fallback_regexes))
157+
without_paren_fallback_regex = "|".join(without_paren_fallback_regexes)
158+
without_paren_fallback_regex = re.compile(f":[\s]*({without_paren_fallback_regex})")
159+
160+
filtered = []
161+
for resp in r:
162+
match = find_match(self.regex, resp)
163+
if not match:
164+
match = find_match(fallback_regex, filter_ignores(resp), choice_to_alpha)
165+
if not match:
166+
match = find_match(without_paren_fallback_regex, resp, without_paren_to_target)
167+
if not match:
168+
match = self.fallback
169+
filtered.append(match)
170+
filtered_resps.append(filtered)
59171

60172
return filtered_resps
173+
174+
175+
class ExtendedRegexFilter(RegexFilter):
    """RegexFilter extended with text-normalisation and match-lookup helpers.

    On top of the arguments forwarded to ``RegexFilter.__init__``
    (``regex_pattern``, ``group_select``, ``fallback``) this class stores three
    options that drive :meth:`filter_ignores`:

    - ``ignore_case``: lower-case the text.
    - ``ignore_punctuation``: delete every Unicode punctuation character.
    - ``regexes_to_ignore``: iterable of regex patterns whose matches are
      removed from the text, or ``None`` to skip this step.
    """

    # Translation table for ``str.translate``: every code point whose Unicode
    # category starts with "P" maps to None, i.e. is deleted.
    # ``range`` excludes its upper bound, hence the ``+ 1`` so that
    # ``sys.maxunicode`` itself is also inspected.
    punct_tbl = dict.fromkeys(
        i for i in range(sys.maxunicode + 1) if unicodedata.category(chr(i)).startswith("P")
    )

    def __init__(
        self,
        regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
        group_select=0,
        fallback: str = "[invalid]",
        ignore_case=False,
        ignore_punctuation=False,
        regexes_to_ignore=None,
    ) -> None:
        super().__init__(regex_pattern, group_select, fallback)
        self.ignore_case = ignore_case
        self.ignore_punctuation = ignore_punctuation
        self.regexes_to_ignore = regexes_to_ignore

    def filter_ignores(self, st):
        """Return `st` normalised according to the configured ignore options.

        The steps run in a fixed order: strip `regexes_to_ignore` matches,
        then lower-case, then delete punctuation. Each step is skipped when
        its option is unset.
        """
        if self.regexes_to_ignore is not None:
            for s in self.regexes_to_ignore:
                st = re.sub(s, "", st)

        if self.ignore_case:
            st = st.lower()

        if self.ignore_punctuation:
            # https://stackoverflow.com/a/266162
            st = st.translate(self.punct_tbl)
        return st

    def find_match(self, regex, resp, convert_dict=None):
        """Pick one ``regex.findall(resp)`` hit and optionally translate it.

        Args:
            regex: compiled pattern (anything exposing ``findall``).
            resp: text to search.
            convert_dict: optional mapping; when the stripped match is a key,
                the mapped value is returned instead. Defaults to no mapping.

        Returns:
            The stripped match (or its mapped value). When nothing matches,
            the empty ``findall`` result is returned unchanged; when the
            selected match has only empty capture groups, ``""`` is returned.
            Both are falsy, so callers can test ``if not match``.
        """
        match = regex.findall(resp)
        if not match:
            return match

        match = match[self.group_select]
        if isinstance(match, tuple):
            # With several capture groups findall yields tuples; take the
            # first non-empty group. Fall back to "" instead of raising
            # IndexError when every group is empty.
            match = next((m for m in match if m), "")
        match = match.strip()
        if match and convert_dict is not None and match in convert_dict:
            match = convert_dict[match]
        return match

lmms_eval/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
AVAILABLE_MODELS = {
44
"llava": "Llava",
5+
"llava_hf": "LlavaHf",
6+
"llava_sglang": "LlavaSglang",
57
"qwen_vl": "Qwen_VL",
68
"fuyu": "Fuyu",
79
"gpt4v": "GPT4V",

0 commit comments

Comments
 (0)