Skip to content

Commit 28a1f48

Browse files
authored
Merge pull request #1580 from google/google_sync
Google sync
2 parents 0dc0956 + 5333429 commit 28a1f48

27 files changed

Lines changed: 780 additions & 220 deletions

CHANGELOG

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,24 @@
1+
Version 2024.02.09:
2+
3+
Updates:
4+
* Remove 'deep' and 'store_all_calls' options.
5+
* Remove duplicate pytype inputs and outputs.
6+
7+
Bug fixes:
8+
* Fix module resolution bug in load_pytd.
9+
* Pattern matching:
10+
* Fix a corner case in pattern matching where the first case is None.
11+
* Fix a corner case when comparing to Any in a case statement.
12+
* Fix a false redundant-match when matching instances of a nonexhaustive type.
13+
* Do not attempt to track matching if we don't recognise a CMP as an instance.
14+
* Do not attempt to track matches if the match variable contains an Any.
15+
* Rework the check for an out-of-order opcode in a match block.
16+
* Fix a crash when calling get() on a TypedDict instance.
17+
* Don't crash when inferring a type for an uncalled attrs.define.
18+
* Handle aliased imports in type stubs better.
19+
* Teach pytype that zip is actually a class.
20+
* Catch bad external types in type annotations.
21+
122
Version 2024.01.24:
223

324
Updates:

pylintrc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ disable=
1515
arguments-differ,
1616
arguments-out-of-order,
1717
assigning-non-slot,
18+
assignment-from-no-return,
1819
attribute-defined-outside-init,
1920
bad-mcs-classmethod-argument,
2021
bad-option-value,

pytype/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
# pylint: skip-file
2-
__version__ = '2024.01.24'
2+
__version__ = '2024.02.09'

pytype/blocks/process_blocks.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -128,20 +128,3 @@ def adjust_returns(code, block_returns):
128128
new_line = next(lines, None)
129129
if new_line:
130130
op.line = new_line
131-
132-
133-
def check_out_of_order(code):
134-
"""Check if a line of code is executed out of order."""
135-
# This sometimes happens due to compiler optimisations, and needs to be
136-
# recorded so that we don't trigger code that is only meant to execute when
137-
# the main flow of control reaches a certain line.
138-
last_line = []
139-
for block in code.order:
140-
for op in block:
141-
if not last_line or last_line[-1].line == op.line:
142-
last_line.append(op)
143-
else:
144-
if op.line < last_line[-1].line:
145-
for x in last_line:
146-
x.metadata.is_out_of_order = True
147-
last_line = [op]

pytype/pattern_matching.py

Lines changed: 51 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -252,19 +252,20 @@ class _Matches:
252252
"""Tracks branches of match statements."""
253253

254254
def __init__(self, ast_matches):
255-
self.start_to_end = {}
255+
self.start_to_end = {} # match_line : match_end_line
256256
self.end_to_starts = collections.defaultdict(list)
257-
self.match_cases = {}
258-
self.defaults = set()
259-
self.as_names = {}
260-
self.matches = []
257+
self.match_cases = {} # opcode_line : match_line
258+
self.defaults = set() # lines with defaults
259+
self.as_names = {} # case_end_line : case_as_name
260+
self.unseen_cases = {} # match_line : num_unseen_cases
261261

262262
for m in ast_matches.matches:
263263
self._add_match(m.start, m.end, m.cases)
264264

265265
def _add_match(self, start, end, cases):
266266
self.start_to_end[start] = end
267267
self.end_to_starts[end].append(start)
268+
self.unseen_cases[start] = len(cases)
268269
for c in cases:
269270
for i in range(c.start, c.end + 1):
270271
self.match_cases[i] = start
@@ -273,6 +274,10 @@ def _add_match(self, start, end, cases):
273274
if c.as_name:
274275
self.as_names[c.end] = c.as_name
275276

277+
def register_case(self, match_line, case_line):
278+
assert self.match_cases[case_line] == match_line
279+
self.unseen_cases[match_line] -= 1
280+
276281
def __repr__(self):
277282
return f"""
278283
Matches: {sorted(self.start_to_end.items())}
@@ -301,10 +306,9 @@ def __init__(self, ast_matches, ctx):
301306
self.ctx = ctx
302307

303308
def _get_option_tracker(
304-
self, match_var: cfg.Variable, case_line: int
309+
self, match_var: cfg.Variable, match_line: int
305310
) -> _OptionTracker:
306311
"""Get the option tracker for a match line."""
307-
match_line = self.matches.match_cases[case_line]
308312
if (match_line not in self._option_tracker or
309313
match_var.id not in self._option_tracker[match_line]):
310314
self._option_tracker[match_line][match_var.id] = (
@@ -323,8 +327,16 @@ def _make_instance_for_match(self, node, types):
323327
ret.append(self.ctx.vm.init_class(node, cls))
324328
return self.ctx.join_variables(node, ret)
325329

330+
def _register_case_branch(self, op: opcodes.Opcode) -> Optional[int]:
331+
match_line = self.matches.match_cases.get(op.line)
332+
if match_line is None:
333+
return None
334+
self.matches.register_case(match_line, op.line)
335+
return match_line
336+
326337
def instantiate_case_var(self, op, match_var, node):
327-
tracker = self._get_option_tracker(match_var, op.line)
338+
match_line = self.matches.match_cases[op.line]
339+
tracker = self._get_option_tracker(match_var, match_line)
328340
if tracker.cases[op.line]:
329341
# We have matched on one or more classes in this case.
330342
types = [x.typ for x in tracker.cases[op.line]]
@@ -360,14 +372,16 @@ def register_match_type(self, op: opcodes.Opcode):
360372
self._match_types[match_line].add(_MatchTypes.make(op))
361373

362374
def add_none_branch(self, op: opcodes.Opcode, match_var: cfg.Variable):
363-
if op.line in self.matches.match_cases:
364-
tracker = self._get_option_tracker(match_var, op.line)
365-
tracker.cover_from_none(op.line)
366-
if not tracker.is_complete:
367-
return None
368-
else:
369-
# This is the last remaining case, and will always succeed.
370-
return True
375+
match_line = self._register_case_branch(op)
376+
if not match_line:
377+
return None
378+
tracker = self._get_option_tracker(match_var, match_line)
379+
tracker.cover_from_none(op.line)
380+
if not tracker.is_complete:
381+
return None
382+
else:
383+
# This is the last remaining case, and will always succeed.
384+
return True
371385

372386
def add_cmp_branch(
373387
self,
@@ -377,12 +391,13 @@ def add_cmp_branch(
377391
case_var: cfg.Variable
378392
) -> _MatchSuccessType:
379393
"""Add a compare-based match case branch to the tracker."""
380-
if cmp_type not in (slots.CMP_EQ, slots.CMP_IS):
394+
match_line = self._register_case_branch(op)
395+
if not match_line:
381396
return None
382397

383-
match_line = self.matches.match_cases.get(op.line)
384-
if not match_line:
398+
if cmp_type not in (slots.CMP_EQ, slots.CMP_IS):
385399
return None
400+
386401
match_type = self._match_types[match_line]
387402

388403
try:
@@ -403,7 +418,7 @@ def add_cmp_branch(
403418
# (enum or union of literals) that we are tracking.
404419
if not tracker:
405420
if _is_literal_match(match_var) or _is_enum_match(match_var, case_val):
406-
tracker = self._get_option_tracker(match_var, op.line)
421+
tracker = self._get_option_tracker(match_var, match_line)
407422

408423
# If none of the above apply we cannot do any sort of tracking.
409424
if not tracker:
@@ -425,32 +440,31 @@ def add_cmp_branch(
425440
def add_class_branch(self, op: opcodes.Opcode, match_var: cfg.Variable,
426441
case_var: cfg.Variable) -> _MatchSuccessType:
427442
"""Add a class-based match case branch to the tracker."""
428-
tracker = self._get_option_tracker(match_var, op.line)
443+
match_line = self._register_case_branch(op)
444+
if not match_line:
445+
return None
446+
tracker = self._get_option_tracker(match_var, match_line)
429447
tracker.cover(op.line, case_var)
430448
return tracker.is_complete or None
431449

432450
def add_default_branch(self, op: opcodes.Opcode) -> _MatchSuccessType:
433451
"""Add a default match case branch to the tracker."""
434-
match_line = self.matches.match_cases.get(op.line)
435-
if match_line is None:
436-
return None
437-
if match_line in self._option_tracker:
438-
for opt in self._option_tracker[match_line].values():
439-
# We no longer check for exhaustive or redundant matches once we hit a
440-
# default case.
441-
opt.invalidate()
442-
return True
443-
else:
452+
match_line = self._register_case_branch(op)
453+
if not match_line or match_line not in self._option_tracker:
444454
return None
445455

456+
for opt in self._option_tracker[match_line].values():
457+
# We no longer check for exhaustive or redundant matches once we hit a
458+
# default case.
459+
opt.invalidate()
460+
return True
461+
446462
def check_ending(
447463
self,
448464
op: opcodes.Opcode,
449465
implicit_return: bool = False
450466
) -> List[IncompleteMatch]:
451467
"""Check if we have ended a match statement with leftover cases."""
452-
if op.metadata.is_out_of_order:
453-
return []
454468
line = op.line
455469
if implicit_return:
456470
done = set()
@@ -464,6 +478,10 @@ def check_ending(
464478
ret = []
465479
for i in done:
466480
for start in self.matches.end_to_starts[i]:
481+
if self.matches.unseen_cases[start] > 0:
482+
# We have executed some opcode out of order and thus gone past the end
483+
# of the match block before seeing all case branches.
484+
continue
467485
trackers = self._option_tracker[start]
468486
for tracker in trackers.values():
469487
if tracker.is_valid:

pytype/pytd/visitors.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -548,7 +548,10 @@ def VisitNamedType(self, t):
548548
if (isinstance(item, pytd.Constant) and
549549
item.name == "typing_extensions.TypedDict"):
550550
return self.to_type(pytd.NamedType("typing.TypedDict"))
551-
return self.to_type(item)
551+
try:
552+
return self.to_type(item)
553+
except NotImplementedError as e:
554+
raise SymbolLookupError(f"{item} is not a type") from e
552555

553556
def VisitClassType(self, t):
554557
new_type = self.VisitNamedType(t)
@@ -820,6 +823,12 @@ def VisitNamedType(self, node):
820823
resolved_node = self.to_type(self._LookupItemRecursive(node.name))
821824
except KeyError:
822825
resolved_node = node # lookup failures are handled later
826+
except NotImplementedError as e:
827+
# to_type() can raise NotImplementedError, but _LookupItemRecursive
828+
# shouldn't return a pytd node that can't be turned into a type in
829+
# this specific case. As such, it's impossible to test this case.
830+
# But it's irresponsible to just crash on it, so here we are.
831+
raise SymbolLookupError(f"{node.name} is not a type") from e
823832
else:
824833
if isinstance(resolved_node, pytd.ClassType):
825834
resolved_node.name = node.name

pytype/rewrite/CMakeLists.txt

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ py_library(
1515
abstract
1616
SRCS
1717
abstract.py
18+
DEPS
19+
pytype.blocks.blocks
1820
)
1921

2022
py_test(
@@ -60,6 +62,7 @@ py_library(
6062
frame.py
6163
DEPS
6264
.abstract
65+
.stack
6366
pytype.blocks.blocks
6467
pytype.rewrite.flow.flow
6568
)
@@ -78,12 +81,34 @@ py_test(
7881
pytype.rewrite.tests.test_utils
7982
)
8083

84+
py_library(
85+
NAME
86+
stack
87+
SRCS
88+
stack.py
89+
DEPS
90+
.abstract
91+
pytype.rewrite.flow.flow
92+
)
93+
94+
py_test(
95+
NAME
96+
stack_test
97+
SRCS
98+
stack_test.py
99+
DEPS
100+
.abstract
101+
.stack
102+
pytype.rewrite.flow.flow
103+
)
104+
81105
py_library(
82106
NAME
83107
vm
84108
SRCS
85109
vm.py
86110
DEPS
111+
.abstract
87112
.frame
88113
pytype.blocks.blocks
89114
pytype.rewrite.flow.flow
@@ -95,6 +120,7 @@ py_test(
95120
SRCS
96121
vm_test.py
97122
DEPS
123+
.abstract
98124
.vm
99125
pytype.pyc.pyc
100126
pytype.rewrite.tests.test_utils

pytype/rewrite/abstract.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Abstract representations of Python values."""
22

3+
from pytype.blocks import blocks
4+
35

46
class BaseValue:
57
pass
@@ -15,3 +17,20 @@ def __repr__(self):
1517

1618
def __eq__(self, other):
1719
return type(self) == type(other) and self.constant == other.constant # pylint: disable=unidiomatic-typecheck
20+
21+
22+
class Function(BaseValue):
23+
24+
def __init__(self, name: str, code: blocks.OrderedCode):
25+
self.name = name
26+
self.code = code
27+
28+
def __repr__(self):
29+
return f'Function({self.name})'
30+
31+
32+
class _Null(BaseValue):
33+
pass
34+
35+
36+
NULL = _Null()

pytype/rewrite/analyze.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,9 @@ def check_types(
4444
init_maximum_depth: int = _INIT_MAXIMUM_DEPTH,
4545
maximum_depth: int = _MAXIMUM_DEPTH,
4646
) -> Analysis:
47-
"""Check types for the given source code."""
48-
_analyze(src, options, loader, init_maximum_depth, maximum_depth)
47+
"""Checks types for the given source code."""
48+
vm = _make_vm(src, options, loader, init_maximum_depth, maximum_depth)
49+
vm.analyze_all_defs()
4950
return Analysis(Context(), None, None)
5051

5152

@@ -56,28 +57,27 @@ def infer_types(
5657
init_maximum_depth: int = _INIT_MAXIMUM_DEPTH,
5758
maximum_depth: int = _MAXIMUM_DEPTH,
5859
) -> Analysis:
59-
"""Infer types for the given source code."""
60-
_analyze(src, options, loader, init_maximum_depth, maximum_depth)
60+
"""Infers types for the given source code."""
61+
vm = _make_vm(src, options, loader, init_maximum_depth, maximum_depth)
62+
vm.infer_stub()
6163
ast = pytd.TypeDeclUnit('inferred + unknowns', (), (), (), (), ())
6264
deps = pytd.TypeDeclUnit('<all>', (), (), (), (), ())
6365
return Analysis(Context(), ast, deps)
6466

6567

66-
def _analyze(
68+
def _make_vm(
6769
src: str,
6870
options: config.Options,
6971
loader: load_pytd.Loader,
7072
init_maximum_depth: int,
7173
maximum_depth: int,
72-
) -> None:
73-
"""Analyze the given source code."""
74+
) -> vm_lib.VirtualMachine:
75+
"""Creates abstract virtual machine for given source code."""
7476
del loader, init_maximum_depth, maximum_depth
7577
code = _get_bytecode(src, options)
7678
# TODO(b/241479600): Populate globals from builtins.
77-
globals_ = {}
78-
vm = vm_lib.VirtualMachine(code, globals_)
79-
vm.run()
80-
# TODO(b/241479600): Analyze classes and functions.
79+
initial_globals = {}
80+
return vm_lib.VirtualMachine(code, initial_globals)
8181

8282

8383
def _get_bytecode(src: str, options: config.Options) -> blocks.OrderedCode:

0 commit comments

Comments
 (0)