Skip to content

Commit a54d66a

Browse files
martindemellorchen152
authored andcommitted
rewrite: Support unpacking of concrete iterables and dicts in function calls.
PiperOrigin-RevId: 623931080
1 parent 7d4ce24 commit a54d66a

7 files changed

Lines changed: 78 additions & 8 deletions

File tree

pytype/rewrite/abstract/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ py_library(
6868
containers.py
6969
DEPS
7070
.base
71+
.internal
72+
.utils
7173
)
7274

7375
py_test(

pytype/rewrite/abstract/containers.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from typing import Dict as _Dict, List as _List, Set as _Set, Tuple as _Tuple
66

77
from pytype.rewrite.abstract import base
8+
from pytype.rewrite.abstract import internal
9+
from pytype.rewrite.abstract import utils
810

911
log = logging.getLogger(__name__)
1012

@@ -25,6 +27,19 @@ def __repr__(self):
2527
def append(self, val: _Variable):
2628
self.constant.append(val)
2729

30+
def extend(self, val: _Variable):
31+
try:
32+
const = utils.get_atomic_constant(val)
33+
if not isinstance(const, list):
34+
const = None
35+
except ValueError:
36+
const = None
37+
38+
if const:
39+
self.constant.extend(const)
40+
else:
41+
self.constant.append(internal.Splat(self._ctx, val).to_variable())
42+
2843

2944
class Dict(base.PythonConstant[_Dict[_Variable, _Variable]]):
3045
"""Representation of a Python dict."""
@@ -34,13 +49,27 @@ def __init__(
3449
):
3550
assert isinstance(constant, dict), constant
3651
super().__init__(ctx, constant)
52+
self.indefinite = False
3753

3854
def __repr__(self):
3955
return f'Dict({self.constant!r})'
4056

4157
def setitem(self, key, val):
4258
self.constant[key] = val
4359

60+
def update(self, val: _Variable):
61+
try:
62+
const = utils.get_atomic_constant(val)
63+
if not isinstance(const, dict):
64+
const = None
65+
except ValueError:
66+
const = None
67+
68+
if const:
69+
self.constant.update(const)
70+
else:
71+
self.indefinite = True
72+
4473

4574
class Set(base.PythonConstant[_Set[_Variable]]):
4675
"""Representation of a Python set."""

pytype/rewrite/abstract/functions.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ class Args(Generic[_FrameT]):
5757
"""Arguments to one function call."""
5858
posargs: Tuple[base.AbstractVariableType, ...] = ()
5959
kwargs: Mapping[str, base.AbstractVariableType] = _EMPTY_MAP
60+
starargs: Optional[base.AbstractVariableType] = None
61+
starstarargs: Optional[base.AbstractVariableType] = None
6062
frame: Optional[_FrameT] = None
6163

6264

pytype/rewrite/abstract/internal.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Abstract types used internally by pytype."""
22

3-
from typing import Dict, Sequence, Tuple
3+
from typing import Dict, Tuple
44

55
import immutabledict
66

@@ -54,9 +54,9 @@ class Splat(base.BaseValue):
5454
(x, *ys, z) in starargs) and let the function arg matcher unpack them.
5555
"""
5656

57-
def __init__(self, ctx: base.ContextType, iterable: Sequence[_Variable]):
57+
def __init__(self, ctx: base.ContextType, iterable: _Variable):
5858
super().__init__(ctx)
59-
self.iterable = tuple(iterable)
59+
self.iterable = iterable
6060

6161
def __repr__(self):
6262
return f"splat({self.iterable!r})"

pytype/rewrite/abstract/internal_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@ class SplatTest(test_utils.ContextfulTestBase):
2121

2222
def test_basic(self):
2323
# Basic smoke test, remove when we have some real functionality to test.
24-
seq = [self.ctx.consts[i].to_variable() for i in range(3)]
24+
cls = self.ctx.abstract_loader.load_raw_type(tuple)
25+
seq = cls.instantiate().to_variable()
2526
x = internal.Splat(self.ctx, seq)
26-
self.assertEqual(x.iterable, tuple(seq))
27+
self.assertEqual(x.iterable, seq)
2728

2829

2930
if __name__ == '__main__':

pytype/rewrite/frame.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,7 @@ def byte_CALL_FUNCTION_EX(self, opcode):
610610
posargs = self._unpack_starargs(starargs).constant
611611
func = self._stack.pop()
612612
if self._code.python_version >= (3, 11):
613+
# the compiler puts a NULL on the stack before function calls
613614
self._stack.pop_and_discard()
614615
callargs = abstract.Args(posargs=posargs, kwargs=kwargs, frame=self)
615616
self._call_function(func, callargs)
@@ -707,6 +708,30 @@ def byte_MAP_ADD(self, opcode):
707708
target = target_var.get_atomic_value()
708709
target.setitem(key, val)
709710

711+
def byte_LIST_EXTEND(self, opcode):
712+
count = opcode.arg
713+
val = self._stack.pop()
714+
target_var = self._stack.peek(count)
715+
target = target_var.get_atomic_value()
716+
target.extend(val)
717+
718+
def byte_DICT_MERGE(self, opcode):
719+
# DICT_MERGE is like DICT_UPDATE but raises an exception for duplicate keys.
720+
return self.byte_DICT_UPDATE(opcode)
721+
722+
def byte_DICT_UPDATE(self, opcode):
723+
count = opcode.arg
724+
val = self._stack.pop()
725+
target_var = self._stack.peek(count)
726+
target = target_var.get_atomic_value()
727+
target.update(val)
728+
729+
def byte_LIST_TO_TUPLE(self, opcode):
730+
target_var = self._stack.pop()
731+
target = abstract.get_atomic_constant(target_var, list)
732+
ret = abstract.Tuple(self._ctx, tuple(target)).to_variable()
733+
self._stack.push(ret)
734+
710735
# ---------------------------------------------------------------
711736
# Branches and jumps
712737

pytype/rewrite/frame_test.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Mapping, Type, TypeVar, cast, get_origin
22

3+
from unittest import mock
4+
35
from pytype.pyc import opcodes
46
from pytype.rewrite import frame as frame_lib
57
from pytype.rewrite.abstract import abstract
@@ -636,14 +638,23 @@ def f(x, *, y):
636638
self.assertConstantVar(callargs.kwargs['y'], 2)
637639

638640
@test_utils.skipBeforePy((3, 11), 'Relies on 3.11+ bytecode')
639-
def test_call_function_ex_no_crash(self):
641+
def test_call_function_ex_callargs(self):
642+
"""Test unpacking of concrete *args and **args."""
640643
frame = self._make_frame("""
641644
def f(x, y, z):
642645
pass
643646
a = (1, 2)
644-
f(*a, z=3)
647+
kw = {'z': 3}
648+
f(*a, **kw)
645649
""")
646-
frame.run()
650+
with mock.patch.object(
651+
frame_lib.Frame, '_call_function', wraps=frame._call_function
652+
) as mock_call:
653+
frame.run()
654+
(_, callargs), _ = mock_call.call_args_list[0]
655+
self.assertConstantVar(callargs.posargs[0], 1)
656+
self.assertConstantVar(callargs.posargs[1], 2)
657+
self.assertConstantVar(callargs.kwargs['z'], 3)
647658

648659

649660
if __name__ == '__main__':

0 commit comments

Comments
 (0)