Skip to content

Commit 889ad53

Browse files
committed
Fixed checking of variable assignments involving tuple unpacking
This also unified all variable checking across different assignment types (annotation assignment, augmented assignment and any other kind of assignment) Fixes #486.
1 parent 9a73eb0 commit 889ad53

5 files changed

Lines changed: 204 additions & 139 deletions

File tree

docs/versionhistory.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ This library adheres to
99
- Dropped Python 3.8 support
1010
- Changed the signature of ``typeguard_ignore()`` to be compatible with
1111
``typing.no_type_check()`` (PR by @jolaf)
12+
- Fixed checking of variable assignments involving tuple unpacking
13+
(`#486 <https://github.com/agronholm/typeguard/pull/486>`_)
1214

1315
**4.4.0** (2024-10-27)
1416

src/typeguard/_functions.py

Lines changed: 41 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import sys
44
import warnings
5+
from collections.abc import Sequence
56
from typing import Any, Callable, NoReturn, TypeVar, Union, overload
67

78
from . import _suppression
@@ -242,59 +243,53 @@ def check_yield_type(
242243

243244

244245
def check_variable_assignment(
245-
value: object, varname: str, annotation: Any, memo: TypeCheckMemo
246+
value: Any, targets: Sequence[list[tuple[str, Any]]], memo: TypeCheckMemo
246247
) -> Any:
247248
if _suppression.type_checks_suppressed:
248249
return value
249250

250-
try:
251-
check_type_internal(value, annotation, memo)
252-
except TypeCheckError as exc:
253-
qualname = qualified_name(value, add_class_prefix=True)
254-
exc.append_path_element(f"value assigned to {varname} ({qualname})")
255-
if memo.config.typecheck_fail_callback:
256-
memo.config.typecheck_fail_callback(exc, memo)
257-
else:
258-
raise
259-
260-
return value
261-
251+
value_to_return = value
252+
for target in targets:
253+
star_variable_index = next(
254+
(i for i, (varname, _) in enumerate(target) if varname.startswith("*")),
255+
None,
256+
)
257+
if star_variable_index is not None:
258+
value_to_return = list(value)
259+
remaining_vars = len(target) - 1 - star_variable_index
260+
end_index = len(value_to_return) - remaining_vars
261+
values_to_check = (
262+
value_to_return[:star_variable_index]
263+
+ [value_to_return[star_variable_index:end_index]]
264+
+ value_to_return[end_index:]
265+
)
266+
elif len(target) > 1:
267+
values_to_check = value_to_return = []
268+
iterator = iter(value)
269+
for _ in target:
270+
try:
271+
values_to_check.append(next(iterator))
272+
except StopIteration:
273+
raise ValueError(
274+
f"not enough values to unpack (expected {len(target)}, got "
275+
f"{len(values_to_check)})"
276+
) from None
262277

263-
def check_multi_variable_assignment(
264-
value: Any, targets: list[dict[str, Any]], memo: TypeCheckMemo
265-
) -> Any:
266-
if max(len(target) for target in targets) == 1:
267-
iterated_values = [value]
268-
else:
269-
iterated_values = list(value)
270-
271-
if not _suppression.type_checks_suppressed:
272-
for expected_types in targets:
273-
value_index = 0
274-
for ann_index, (varname, expected_type) in enumerate(
275-
expected_types.items()
276-
):
277-
if varname.startswith("*"):
278-
varname = varname[1:]
279-
keys_left = len(expected_types) - 1 - ann_index
280-
next_value_index = len(iterated_values) - keys_left
281-
obj: object = iterated_values[value_index:next_value_index]
282-
value_index = next_value_index
278+
else:
279+
values_to_check = [value]
280+
281+
for val, (varname, annotation) in zip(values_to_check, target):
282+
try:
283+
check_type_internal(val, annotation, memo)
284+
except TypeCheckError as exc:
285+
qualname = qualified_name(val, add_class_prefix=True)
286+
exc.append_path_element(f"value assigned to {varname} ({qualname})")
287+
if memo.config.typecheck_fail_callback:
288+
memo.config.typecheck_fail_callback(exc, memo)
283289
else:
284-
obj = iterated_values[value_index]
285-
value_index += 1
290+
raise
286291

287-
try:
288-
check_type_internal(obj, expected_type, memo)
289-
except TypeCheckError as exc:
290-
qualname = qualified_name(obj, add_class_prefix=True)
291-
exc.append_path_element(f"value assigned to {varname} ({qualname})")
292-
if memo.config.typecheck_fail_callback:
293-
memo.config.typecheck_fail_callback(exc, memo)
294-
else:
295-
raise
296-
297-
return iterated_values[0] if len(iterated_values) == 1 else iterated_values
292+
return value_to_return
298293

299294

300295
def warn_on_error(exc: TypeCheckError, memo: TypeCheckMemo) -> None:

src/typeguard/_transformer.py

Lines changed: 75 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
If,
2929
Import,
3030
ImportFrom,
31-
Index,
3231
List,
3332
Load,
3433
LShift,
@@ -389,9 +388,7 @@ def visit_BinOp(self, node: BinOp) -> Any:
389388
union_name = self.transformer._get_import("typing", "Union")
390389
return Subscript(
391390
value=union_name,
392-
slice=Index(
393-
Tuple(elts=[node.left, node.right], ctx=Load()), ctx=Load()
394-
),
391+
slice=Tuple(elts=[node.left, node.right], ctx=Load()),
395392
ctx=Load(),
396393
)
397394

@@ -410,24 +407,18 @@ def visit_Subscript(self, node: Subscript) -> Any:
410407
# The subscript of typing(_extensions).Literal can be any arbitrary string, so
411408
# don't try to evaluate it as code
412409
if node.slice:
413-
if isinstance(node.slice, Index):
414-
# Python 3.8
415-
slice_value = node.slice.value # type: ignore[attr-defined]
416-
else:
417-
slice_value = node.slice
418-
419-
if isinstance(slice_value, Tuple):
410+
if isinstance(node.slice, Tuple):
420411
if self._memo.name_matches(node.value, *annotated_names):
421412
# Only treat the first argument to typing.Annotated as a potential
422413
# forward reference
423414
items = cast(
424415
typing.List[expr],
425-
[self.visit(slice_value.elts[0])] + slice_value.elts[1:],
416+
[self.visit(node.slice.elts[0])] + node.slice.elts[1:],
426417
)
427418
else:
428419
items = cast(
429420
typing.List[expr],
430-
[self.visit(item) for item in slice_value.elts],
421+
[self.visit(item) for item in node.slice.elts],
431422
)
432423

433424
# If this is a Union and any of the items is Any, erase the entire
@@ -450,7 +441,7 @@ def visit_Subscript(self, node: Subscript) -> Any:
450441
if item is None:
451442
items[index] = self.transformer._get_import("typing", "Any")
452443

453-
slice_value.elts = items
444+
node.slice.elts = items
454445
else:
455446
self.generic_visit(node)
456447

@@ -542,18 +533,10 @@ def _use_memo(
542533
return_annotation, *generator_names
543534
):
544535
if isinstance(return_annotation, Subscript):
545-
annotation_slice = return_annotation.slice
546-
547-
# Python < 3.9
548-
if isinstance(annotation_slice, Index):
549-
annotation_slice = (
550-
annotation_slice.value # type: ignore[attr-defined]
551-
)
552-
553-
if isinstance(annotation_slice, Tuple):
554-
items = annotation_slice.elts
536+
if isinstance(return_annotation.slice, Tuple):
537+
items = return_annotation.slice.elts
555538
else:
556-
items = [annotation_slice]
539+
items = [return_annotation.slice]
557540

558541
if len(items) > 0:
559542
new_memo.yield_annotation = self._convert_annotation(
@@ -743,7 +726,7 @@ def visit_FunctionDef(
743726
annotation_ = self._convert_annotation(node.args.vararg.annotation)
744727
if annotation_:
745728
container = Name("tuple", ctx=Load())
746-
subscript_slice: Tuple | Index = Tuple(
729+
subscript_slice = Tuple(
747730
[
748731
annotation_,
749732
Constant(Ellipsis),
@@ -1024,12 +1007,25 @@ def visit_AnnAssign(self, node: AnnAssign) -> Any:
10241007
func_name = self._get_import(
10251008
"typeguard._functions", "check_variable_assignment"
10261009
)
1010+
targets_arg = List(
1011+
[
1012+
List(
1013+
[
1014+
Tuple(
1015+
[Constant(node.target.id), annotation],
1016+
ctx=Load(),
1017+
)
1018+
],
1019+
ctx=Load(),
1020+
)
1021+
],
1022+
ctx=Load(),
1023+
)
10271024
node.value = Call(
10281025
func_name,
10291026
[
10301027
node.value,
1031-
Constant(node.target.id),
1032-
annotation,
1028+
targets_arg,
10331029
self._memo.get_memo_name(),
10341030
],
10351031
[],
@@ -1047,7 +1043,7 @@ def visit_Assign(self, node: Assign) -> Any:
10471043

10481044
# Only instrument function-local assignments
10491045
if isinstance(self._memo.node, (FunctionDef, AsyncFunctionDef)):
1050-
targets: list[dict[Constant, expr | None]] = []
1046+
preliminary_targets: list[list[tuple[Constant, expr | None]]] = []
10511047
check_required = False
10521048
for target in node.targets:
10531049
elts: Sequence[expr]
@@ -1058,63 +1054,63 @@ def visit_Assign(self, node: Assign) -> Any:
10581054
else:
10591055
continue
10601056

1061-
annotations_: dict[Constant, expr | None] = {}
1057+
annotations_: list[tuple[Constant, expr | None]] = []
10621058
for exp in elts:
10631059
prefix = ""
10641060
if isinstance(exp, Starred):
10651061
exp = exp.value
10661062
prefix = "*"
10671063

1064+
path: list[str] = []
1065+
while isinstance(exp, Attribute):
1066+
path.insert(0, exp.attr)
1067+
exp = exp.value
1068+
10681069
if isinstance(exp, Name):
1069-
self._memo.ignored_names.add(exp.id)
1070-
name = prefix + exp.id
1070+
if not path:
1071+
self._memo.ignored_names.add(exp.id)
1072+
1073+
path.insert(0, exp.id)
1074+
name = prefix + ".".join(path)
10711075
annotation = self._memo.variable_annotations.get(exp.id)
10721076
if annotation:
1073-
annotations_[Constant(name)] = annotation
1077+
annotations_.append((Constant(name), annotation))
10741078
check_required = True
10751079
else:
1076-
annotations_[Constant(name)] = None
1080+
annotations_.append((Constant(name), None))
10771081

1078-
targets.append(annotations_)
1082+
preliminary_targets.append(annotations_)
10791083

10801084
if check_required:
10811085
# Replace missing annotations with typing.Any
1082-
for item in targets:
1083-
for key, expression in item.items():
1086+
targets: list[list[tuple[Constant, expr]]] = []
1087+
for items in preliminary_targets:
1088+
target_list: list[tuple[Constant, expr]] = []
1089+
targets.append(target_list)
1090+
for key, expression in items:
10841091
if expression is None:
1085-
item[key] = self._get_import("typing", "Any")
1092+
target_list.append((key, self._get_import("typing", "Any")))
1093+
else:
1094+
target_list.append((key, expression))
10861095

1087-
if len(targets) == 1 and len(targets[0]) == 1:
1088-
func_name = self._get_import(
1089-
"typeguard._functions", "check_variable_assignment"
1090-
)
1091-
target_varname = next(iter(targets[0]))
1092-
node.value = Call(
1093-
func_name,
1094-
[
1095-
node.value,
1096-
target_varname,
1097-
targets[0][target_varname],
1098-
self._memo.get_memo_name(),
1099-
],
1100-
[],
1101-
)
1102-
elif targets:
1103-
func_name = self._get_import(
1104-
"typeguard._functions", "check_multi_variable_assignment"
1105-
)
1106-
targets_arg = List(
1107-
[
1108-
Dict(keys=list(target), values=list(target.values()))
1109-
for target in targets
1110-
],
1111-
ctx=Load(),
1112-
)
1113-
node.value = Call(
1114-
func_name,
1115-
[node.value, targets_arg, self._memo.get_memo_name()],
1116-
[],
1117-
)
1096+
func_name = self._get_import(
1097+
"typeguard._functions", "check_variable_assignment"
1098+
)
1099+
targets_arg = List(
1100+
[
1101+
List(
1102+
[Tuple([name, ann], ctx=Load()) for name, ann in target],
1103+
ctx=Load(),
1104+
)
1105+
for target in targets
1106+
],
1107+
ctx=Load(),
1108+
)
1109+
node.value = Call(
1110+
func_name,
1111+
[node.value, targets_arg, self._memo.get_memo_name()],
1112+
[],
1113+
)
11181114

11191115
return node
11201116

@@ -1175,12 +1171,20 @@ def visit_AugAssign(self, node: AugAssign) -> Any:
11751171
operator_call = Call(
11761172
operator_func, [Name(node.target.id, ctx=Load()), node.value], []
11771173
)
1174+
targets_arg = List(
1175+
[
1176+
List(
1177+
[Tuple([Constant(node.target.id), annotation], ctx=Load())],
1178+
ctx=Load(),
1179+
)
1180+
],
1181+
ctx=Load(),
1182+
)
11781183
check_call = Call(
11791184
self._get_import("typeguard._functions", "check_variable_assignment"),
11801185
[
11811186
operator_call,
1182-
Constant(node.target.id),
1183-
annotation,
1187+
targets_arg,
11841188
self._memo.get_memo_name(),
11851189
],
11861190
[],

src/typeguard/_union_transformer.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,14 @@
88
from ast import (
99
BinOp,
1010
BitOr,
11-
Index,
1211
Load,
1312
Name,
1413
NodeTransformer,
1514
Subscript,
15+
Tuple,
1616
fix_missing_locations,
1717
parse,
1818
)
19-
from ast import Tuple as ASTTuple
2019
from types import CodeType
2120
from typing import Any
2221

@@ -30,9 +29,7 @@ def visit_BinOp(self, node: BinOp) -> Any:
3029
if isinstance(node.op, BitOr):
3130
return Subscript(
3231
value=self.union_name,
33-
slice=Index(
34-
ASTTuple(elts=[node.left, node.right], ctx=Load()), ctx=Load()
35-
),
32+
slice=Tuple(elts=[node.left, node.right], ctx=Load()),
3633
ctx=Load(),
3734
)
3835

0 commit comments

Comments
 (0)