@@ -252,19 +252,20 @@ class _Matches:
252252 """Tracks branches of match statements."""
253253
254254 def __init__ (self , ast_matches ):
255- self .start_to_end = {}
255+ self .start_to_end = {} # match_line : match_end_line
256256 self .end_to_starts = collections .defaultdict (list )
257- self .match_cases = {}
258- self .defaults = set ()
259- self .as_names = {}
260- self .matches = []
257+ self .match_cases = {} # opcode_line : match_line
258+ self .defaults = set () # lines with defaults
259+ self .as_names = {} # case_end_line : case_as_name
260+ self .unseen_cases = {} # match_line : num_unseen_cases
261261
262262 for m in ast_matches .matches :
263263 self ._add_match (m .start , m .end , m .cases )
264264
265265 def _add_match (self , start , end , cases ):
266266 self .start_to_end [start ] = end
267267 self .end_to_starts [end ].append (start )
268+ self .unseen_cases [start ] = len (cases )
268269 for c in cases :
269270 for i in range (c .start , c .end + 1 ):
270271 self .match_cases [i ] = start
@@ -273,6 +274,10 @@ def _add_match(self, start, end, cases):
273274 if c .as_name :
274275 self .as_names [c .end ] = c .as_name
275276
277+ def register_case (self , match_line , case_line ):
278+ assert self .match_cases [case_line ] == match_line
279+ self .unseen_cases [match_line ] -= 1
280+
276281 def __repr__ (self ):
277282 return f"""
278283 Matches: { sorted (self .start_to_end .items ())}
@@ -301,10 +306,9 @@ def __init__(self, ast_matches, ctx):
301306 self .ctx = ctx
302307
303308 def _get_option_tracker (
304- self , match_var : cfg .Variable , case_line : int
309+ self , match_var : cfg .Variable , match_line : int
305310 ) -> _OptionTracker :
306311 """Get the option tracker for a match line."""
307- match_line = self .matches .match_cases [case_line ]
308312 if (match_line not in self ._option_tracker or
309313 match_var .id not in self ._option_tracker [match_line ]):
310314 self ._option_tracker [match_line ][match_var .id ] = (
@@ -323,8 +327,16 @@ def _make_instance_for_match(self, node, types):
323327 ret .append (self .ctx .vm .init_class (node , cls ))
324328 return self .ctx .join_variables (node , ret )
325329
330+ def _register_case_branch (self , op : opcodes .Opcode ) -> Optional [int ]:
331+ match_line = self .matches .match_cases .get (op .line )
332+ if match_line is None :
333+ return None
334+ self .matches .register_case (match_line , op .line )
335+ return match_line
336+
326337 def instantiate_case_var (self , op , match_var , node ):
327- tracker = self ._get_option_tracker (match_var , op .line )
338+ match_line = self .matches .match_cases [op .line ]
339+ tracker = self ._get_option_tracker (match_var , match_line )
328340 if tracker .cases [op .line ]:
329341 # We have matched on one or more classes in this case.
330342 types = [x .typ for x in tracker .cases [op .line ]]
@@ -360,14 +372,16 @@ def register_match_type(self, op: opcodes.Opcode):
360372 self ._match_types [match_line ].add (_MatchTypes .make (op ))
361373
362374 def add_none_branch (self , op : opcodes .Opcode , match_var : cfg .Variable ):
363- if op .line in self .matches .match_cases :
364- tracker = self ._get_option_tracker (match_var , op .line )
365- tracker .cover_from_none (op .line )
366- if not tracker .is_complete :
367- return None
368- else :
369- # This is the last remaining case, and will always succeed.
370- return True
375+ match_line = self ._register_case_branch (op )
376+ if not match_line :
377+ return None
378+ tracker = self ._get_option_tracker (match_var , match_line )
379+ tracker .cover_from_none (op .line )
380+ if not tracker .is_complete :
381+ return None
382+ else :
383+ # This is the last remaining case, and will always succeed.
384+ return True
371385
372386 def add_cmp_branch (
373387 self ,
@@ -377,12 +391,13 @@ def add_cmp_branch(
377391 case_var : cfg .Variable
378392 ) -> _MatchSuccessType :
379393 """Add a compare-based match case branch to the tracker."""
380- if cmp_type not in (slots .CMP_EQ , slots .CMP_IS ):
394+ match_line = self ._register_case_branch (op )
395+ if not match_line :
381396 return None
382397
383- match_line = self .matches .match_cases .get (op .line )
384- if not match_line :
398+ if cmp_type not in (slots .CMP_EQ , slots .CMP_IS ):
385399 return None
400+
386401 match_type = self ._match_types [match_line ]
387402
388403 try :
@@ -403,7 +418,7 @@ def add_cmp_branch(
403418 # (enum or union of literals) that we are tracking.
404419 if not tracker :
405420 if _is_literal_match (match_var ) or _is_enum_match (match_var , case_val ):
406- tracker = self ._get_option_tracker (match_var , op . line )
421+ tracker = self ._get_option_tracker (match_var , match_line )
407422
408423 # If none of the above apply we cannot do any sort of tracking.
409424 if not tracker :
@@ -425,32 +440,31 @@ def add_cmp_branch(
425440 def add_class_branch (self , op : opcodes .Opcode , match_var : cfg .Variable ,
426441 case_var : cfg .Variable ) -> _MatchSuccessType :
427442 """Add a class-based match case branch to the tracker."""
428- tracker = self ._get_option_tracker (match_var , op .line )
443+ match_line = self ._register_case_branch (op )
444+ if not match_line :
445+ return None
446+ tracker = self ._get_option_tracker (match_var , match_line )
429447 tracker .cover (op .line , case_var )
430448 return tracker .is_complete or None
431449
432450 def add_default_branch (self , op : opcodes .Opcode ) -> _MatchSuccessType :
433451 """Add a default match case branch to the tracker."""
434- match_line = self .matches .match_cases .get (op .line )
435- if match_line is None :
436- return None
437- if match_line in self ._option_tracker :
438- for opt in self ._option_tracker [match_line ].values ():
439- # We no longer check for exhaustive or redundant matches once we hit a
440- # default case.
441- opt .invalidate ()
442- return True
443- else :
452+ match_line = self ._register_case_branch (op )
453+ if not match_line or match_line not in self ._option_tracker :
444454 return None
445455
456+ for opt in self ._option_tracker [match_line ].values ():
457+ # We no longer check for exhaustive or redundant matches once we hit a
458+ # default case.
459+ opt .invalidate ()
460+ return True
461+
446462 def check_ending (
447463 self ,
448464 op : opcodes .Opcode ,
449465 implicit_return : bool = False
450466 ) -> List [IncompleteMatch ]:
451467 """Check if we have ended a match statement with leftover cases."""
452- if op .metadata .is_out_of_order :
453- return []
454468 line = op .line
455469 if implicit_return :
456470 done = set ()
@@ -464,6 +478,10 @@ def check_ending(
464478 ret = []
465479 for i in done :
466480 for start in self .matches .end_to_starts [i ]:
481+ if self .matches .unseen_cases [start ] > 0 :
482+ # We have executed some opcode out of order and thus gone past the end
483+ # of the match block before seeing all case branches.
484+ continue
467485 trackers = self ._option_tracker [start ]
468486 for tracker in trackers .values ():
469487 if tracker .is_valid :
0 commit comments