Skip to content

Commit cc5b862

Browse files
committed
rewrite: run function frames in vm.analyze_all_defs() and vm.infer_stub().
analyze_all_defs() runs all functions it can find. infer_stub() runs only the ones needed to infer module-level types. (We eventually need to handle things like a function creating a nested function and returning it, but we can figure that out later.) PiperOrigin-RevId: 605779001
1 parent 36bed9f commit cc5b862

2 files changed

Lines changed: 84 additions & 20 deletions

File tree

pytype/rewrite/vm.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,25 +18,40 @@ def __init__(
1818
):
1919
self._code = code
2020
self._initial_globals = initial_globals
21+
self._module_frame: frame.Frame = None
2122

22-
def _run(self):
23-
module_frame = frame.Frame(
23+
def _run_module(self) -> None:
24+
assert not self._module_frame
25+
self._module_frame = frame.Frame(
2426
name='__main__',
2527
code=self._code,
2628
initial_locals=self._initial_globals,
2729
initial_globals=self._initial_globals,
2830
)
29-
module_frame.run()
30-
return module_frame
31+
self._module_frame.run()
32+
33+
def _run_function(self, func: abstract.Function) -> frame.Frame:
34+
assert self._module_frame
35+
func_frame = frame.Frame(
36+
name=func.name,
37+
code=func.code,
38+
initial_locals={},
39+
initial_globals=self._module_frame.final_locals,
40+
)
41+
func_frame.run()
42+
return func_frame
3143

3244
def analyze_all_defs(self):
33-
module_frame = self._run()
34-
for func in module_frame.functions:
35-
del func
36-
raise NotImplementedError('Function analysis not implemented yet')
45+
self._run_module()
46+
functions = list(self._module_frame.functions)
47+
while functions:
48+
func = functions.pop(0)
49+
func_frame = self._run_function(func)
50+
functions.extend(func_frame.functions)
3751

3852
def infer_stub(self):
39-
module_frame = self._run()
40-
for name, var in module_frame.final_locals:
41-
del name, var
42-
raise NotImplementedError('Pytd generation not implemented yet')
53+
self._run_module()
54+
for var in self._module_frame.final_locals.values():
55+
for value in var.values:
56+
if isinstance(value, abstract.Function):
57+
self._run_function(value)

pytype/rewrite/vm_test.py

Lines changed: 57 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import cast
1+
from typing import Type, TypeVar
22

33
from pytype.pyc import opcodes
44
from pytype.rewrite import abstract
@@ -7,19 +7,28 @@
77

88
import unittest
99

10+
_T = TypeVar('_T')
11+
1012

1113
def _make_vm(src: str) -> vm_lib.VirtualMachine:
1214
return vm_lib.VirtualMachine(test_utils.parse(src), {})
1315

1416

17+
def _get(typ: Type[_T], var) -> _T:
18+
v = var.get_atomic_value()
19+
assert isinstance(v, typ)
20+
return v
21+
22+
1523
class VmTest(unittest.TestCase):
1624

1725
def test_run_module_frame(self):
1826
block = [opcodes.LOAD_CONST(0, 0, 0, None), opcodes.RETURN_VALUE(0, 0)]
1927
code = test_utils.FakeOrderedCode([block], [None])
2028
vm = vm_lib.VirtualMachine(code.Seal(), {})
21-
module_frame = vm._run()
22-
self.assertIsNotNone(module_frame)
29+
self.assertIsNone(vm._module_frame)
30+
vm._run_module()
31+
self.assertIsNotNone(vm._module_frame)
2332

2433
def test_globals(self):
2534
vm = _make_vm("""
@@ -33,18 +42,58 @@ def g():
3342
g()
3443
f()
3544
""")
36-
module_frame = vm._run()
45+
vm._run_module()
3746

3847
def get_const(var):
39-
return cast(abstract.PythonConstant, var.get_atomic_value()).constant
48+
return _get(abstract.PythonConstant, var).constant
4049

41-
x = get_const(module_frame.load_global('x'))
42-
y = get_const(module_frame.load_global('y'))
43-
z = get_const(module_frame.load_global('z'))
50+
x = get_const(vm._module_frame.load_global('x'))
51+
y = get_const(vm._module_frame.load_global('y'))
52+
z = get_const(vm._module_frame.load_global('z'))
4453
self.assertEqual(x, 42)
4554
self.assertIsNone(y)
4655
self.assertEqual(z, 42)
4756

57+
def test_analyze_functions(self):
58+
# Just make sure this doesn't crash.
59+
vm = _make_vm("""
60+
def f():
61+
def g():
62+
pass
63+
""")
64+
vm.analyze_all_defs()
65+
66+
def test_infer_stub(self):
67+
# Just make sure this doesn't crash.
68+
vm = _make_vm("""
69+
def f():
70+
def g():
71+
pass
72+
""")
73+
vm.infer_stub()
74+
75+
def test_run_function(self):
76+
vm = _make_vm("""
77+
x = None
78+
79+
def f():
80+
global x
81+
x = 42
82+
83+
def g():
84+
y = x
85+
""")
86+
vm._run_module()
87+
f = _get(abstract.Function, vm._module_frame.final_locals['f'])
88+
g = _get(abstract.Function, vm._module_frame.final_locals['g'])
89+
f_frame = vm._run_function(f)
90+
g_frame = vm._run_function(g)
91+
92+
self.assertEqual(f_frame.load_global('x').get_atomic_value(),
93+
abstract.PythonConstant(42))
94+
self.assertEqual(g_frame.load_local('y').get_atomic_value(),
95+
abstract.PythonConstant(None))
96+
4897

4998
if __name__ == '__main__':
5099
unittest.main()

0 commit comments

Comments
 (0)