11import re
2-
2+ import sys
3+ import unicodedata
34from lmms_eval .api .filter import Filter
45
56
class WhitespaceFilter(Filter):
    """Filter that removes one leading space character from each response."""

    def __init__(self) -> None:
        pass

    def apply(self, resps, docs):
        """
        Strip a single leading " " from every response string.

        `resps` is a list in which each element is itself a list of response
        strings; the nesting is preserved in the returned value. Only the first
        character is dropped, and only when it is a plain space. `docs` is
        accepted for interface compatibility and is not used.
        """

        def drop_leading_space(text):
            return text[1:] if text.startswith(" ") else text

        return [[drop_leading_space(text) for text in response_set] for response_set in resps]
27+
28+
629class RegexFilter (Filter ):
730 """ """
831
9- def __init__ (self , regex_pattern : str = r"#### (\-?[0-9\.\,]+)" , fallback : str = "[invalid]" ) -> None :
32+ def __init__ (
33+ self ,
34+ regex_pattern : str = r"#### (\-?[0-9\.\,]+)" ,
35+ group_select = 0 ,
36+ fallback : str = "[invalid]" ,
37+ ) -> None :
1038 """
1139 pass a string `regex` to run `re.compile(r"regex")` on.
1240 `fallback` defines the output returned if no matches for the regex are located.
1341 """
1442 self .regex_pattern = regex_pattern
1543 self .regex = re .compile (regex_pattern )
44+ self .group_select = group_select
1645 self .fallback = fallback
1746
1847 def apply (self , resps , docs ):
@@ -23,9 +52,12 @@ def apply(self, resps, docs):
2352 def filter_set (inst ):
2453 filtered = []
2554 for resp in inst :
26- match = self .regex .search (resp )
55+ match = self .regex .findall (resp )
2756 if match :
28- match = match .group (1 ).strip ()
57+ match = match [self .group_select ]
58+ if isinstance (match , tuple ):
59+ match = [m for m in match if m ][0 ]
60+ match = match .strip ()
2961 else :
3062 match = self .fallback
3163 filtered .append (match )
@@ -38,23 +70,145 @@ def filter_set(inst):
3870 return filtered_resps
3971
4072
41- class WhitespaceFilter (Filter ):
42- """ """
class MultiChoiceRegexFilter(RegexFilter):
    r"""
    A filter used to extract a model's answer on multiple choice questions with
    letter answers. Assumes each document has a "choices" field
    containing the list of answer choices and that the answer label symbols
    are of the form (A), (B), (C), ... or A, B, C.

    Extraction is attempted in this order, stopping at the first non-empty hit:
      1. the configured `regex_pattern` on the raw response;
      2. the literal text of any choice (after the ignore rules are applied to
         both the choice and the response), mapped to its "(X)" label;
      3. a bare letter following a colon (`:[\s]*(A|B|...)`), mapped to "(X)";
      4. `fallback`.
    """

    # Class-wide cache of the punctuation-deletion table. Built on first use
    # because scanning the whole Unicode range is expensive and is only needed
    # when `ignore_punctuation` is enabled.
    _punct_tbl = None

    def __init__(
        self,
        regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
        group_select=0,
        fallback: str = "[invalid]",
        ignore_case=False,
        ignore_punctuation=False,
        regexes_to_ignore=None,
    ) -> None:
        r"""
        regex_pattern: The basic regex pattern to use. If it fails to match, the customized match procedure is used
                        - step 1 : We parse the choices between ([A-Z])s then try to find these choices in the response.
                        - step 2 : We parse the choice with regex `:[\s]*([A-?])`, where ? varies by number of choices.
        group_select: Selects the (group_select)th match from the findall result.
        fallback: Value emitted when none of the extraction steps yields an answer.
        ignore_case: Ignores the case during step 1 matching
        ignore_punctuation: Remove the punctuation during step 1 matching
        regexes_to_ignore: Remove these regexes during step 1 matching
        """
        super().__init__(regex_pattern, group_select, fallback)
        self.ignore_case = ignore_case
        self.ignore_punctuation = ignore_punctuation
        self.regexes_to_ignore = regexes_to_ignore

    @staticmethod
    def _punctuation_table():
        """Return (building it once) a `str.translate` table that deletes Unicode punctuation."""
        if MultiChoiceRegexFilter._punct_tbl is None:
            # https://stackoverflow.com/a/266162
            MultiChoiceRegexFilter._punct_tbl = dict.fromkeys(
                i for i in range(sys.maxunicode + 1) if unicodedata.category(chr(i)).startswith("P")
            )
        return MultiChoiceRegexFilter._punct_tbl

    def _filter_ignores(self, st):
        """Apply `regexes_to_ignore`, case folding and punctuation removal to `st`, in that order."""
        if self.regexes_to_ignore is not None:
            for s in self.regexes_to_ignore:
                st = re.sub(s, "", st)

        if self.ignore_case:
            st = st.lower()

        if self.ignore_punctuation:
            st = st.translate(self._punctuation_table())
        return st

    def _find_match(self, regex, resp, convert_dict=None):
        """
        Return the `group_select`-th `findall` hit of `regex` in `resp`.

        The hit is stripped and, when present in `convert_dict`, replaced by the
        mapped value. A falsy value (empty list or empty string) is returned
        when there is nothing usable, so callers can test with `if not match`.
        """
        matches = regex.findall(resp)
        if not matches:
            return matches
        try:
            selected = matches[self.group_select]
        except IndexError:
            # The response has fewer matches than `group_select` asks for;
            # report "no match" instead of aborting the whole filter run.
            return []
        if isinstance(selected, tuple):
            # Several capture groups: keep the first non-empty one. All groups
            # can be empty (e.g. optional groups), hence the "" default.
            selected = next((group for group in selected if group), "")
        selected = selected.strip()
        if selected and convert_dict is not None and selected in convert_dict:
            selected = convert_dict[selected]
        return selected

    def apply(self, resps, docs):
        """
        Extract one answer per response.

        Here, we assume we have a list, in which each element is a list of
        model responses for some particular input/target pair, so we process
        each of these (same input/target response sets) independently (and keep
        them a list). `docs` is iterated in lockstep with `resps`; each doc must
        provide a "choices" entry holding the answer strings.
        """
        filtered_resps = []

        for r, doc in zip(resps, docs):
            choice_patterns = []
            choice_to_alpha = {}
            letter_patterns = []
            without_paren_to_target = {}
            next_alpha = "A"

            for c in doc["choices"]:
                m = self._filter_ignores(c.strip())
                choice_patterns.append(re.escape(m))
                choice_to_alpha[m] = f"({next_alpha})"

                # Escaped so labels past "Z" (e.g. "[") cannot corrupt the regex.
                letter_patterns.append(re.escape(next_alpha))
                without_paren_to_target[next_alpha] = f"({next_alpha})"

                next_alpha = chr(ord(next_alpha) + 1)

            fallback_regex = re.compile("|".join(choice_patterns))
            letter_alternatives = "|".join(letter_patterns)
            without_paren_fallback_regex = re.compile(rf":[\s]*({letter_alternatives})")

            filtered = []
            for resp in r:
                match = self._find_match(self.regex, resp)
                if not match:
                    match = self._find_match(fallback_regex, self._filter_ignores(resp), choice_to_alpha)
                if not match:
                    match = self._find_match(without_paren_fallback_regex, resp, without_paren_to_target)
                if not match:
                    match = self.fallback
                filtered.append(match)
            filtered_resps.append(filtered)

        return filtered_resps
173+
174+
class ExtendedRegexFilter(RegexFilter):
    """
    RegexFilter variant that adds text-normalisation and match-selection helpers.

    `filter_ignores` normalises a string according to the ignore settings and
    `find_match` picks one `findall` hit; this class does not define its own
    `apply`.
    """

    # `str.translate` table mapping every Unicode punctuation code point
    # (general category starting with "P") to None, i.e. deleting it.
    punct_tbl = dict.fromkeys(i for i in range(sys.maxunicode + 1) if unicodedata.category(chr(i)).startswith("P"))

    def __init__(
        self,
        regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
        group_select=0,
        fallback: str = "[invalid]",
        ignore_case=False,
        ignore_punctuation=False,
        regexes_to_ignore=None,
    ) -> None:
        """
        regex_pattern: pattern forwarded to the parent constructor.
        group_select: index of the `findall` hit that `find_match` selects.
        fallback: forwarded to the parent constructor.
        ignore_case: lower-case text in `filter_ignores`.
        ignore_punctuation: delete Unicode punctuation in `filter_ignores`.
        regexes_to_ignore: patterns removed from text in `filter_ignores`, or None.
        """
        super().__init__(regex_pattern, group_select, fallback)
        self.ignore_case = ignore_case
        self.ignore_punctuation = ignore_punctuation
        self.regexes_to_ignore = regexes_to_ignore

    def filter_ignores(self, st):
        """Apply `regexes_to_ignore`, case folding and punctuation removal to `st`, in that order."""
        if self.regexes_to_ignore is not None:
            for s in self.regexes_to_ignore:
                st = re.sub(s, "", st)

        if self.ignore_case:
            st = st.lower()

        if self.ignore_punctuation:
            # https://stackoverflow.com/a/266162
            st = st.translate(self.punct_tbl)
        return st

    def find_match(self, regex, resp, convert_dict=None):
        """
        Return the `group_select`-th `findall` hit of `regex` in `resp`.

        The hit is stripped and, when present in `convert_dict`, replaced by the
        mapped value. A falsy value (empty list or empty string) is returned
        when there is nothing usable, so callers can test with `if not match`.
        """
        matches = regex.findall(resp)
        if not matches:
            return matches
        try:
            selected = matches[self.group_select]
        except IndexError:
            # The response has fewer matches than `group_select` asks for;
            # report "no match" instead of raising.
            return []
        if isinstance(selected, tuple):
            # Several capture groups: keep the first non-empty one. All groups
            # can be empty (e.g. optional groups), hence the "" default.
            selected = next((group for group in selected if group), "")
        selected = selected.strip()
        if selected and convert_dict is not None and selected in convert_dict:
            selected = convert_dict[selected]
        return selected
0 commit comments