Skip to content

Commit bb0736a

Browse files
committed
Report an error when a TypedDict key has an ambiguous type.
* Reports an [invalid-annotation] error when a TypedDict key has an ambiguous type. * Simplifies some of the TypedDict code by assuming that keys have unambiguous types. Fixes #1566. PiperOrigin-RevId: 599320573
1 parent c699398 commit bb0736a

4 files changed

Lines changed: 49 additions & 36 deletions

File tree

pytype/matcher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1466,7 +1466,7 @@ def _match_dict_against_typed_dict(
14661466
for k, v in left.pyval.items():
14671467
if k not in fields:
14681468
continue
1469-
typ = abstract_utils.get_atomic_value(fields[k])
1469+
typ = fields[k]
14701470
match_result = self.compute_one_match(v, typ)
14711471
if not match_result.success:
14721472
bad.append((k, match_result.bad_matches))

pytype/output.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,10 +1016,8 @@ def _typed_dict_to_def(self, node, v, name):
10161016
keywords.append(("total", pytd.Literal(False)))
10171017
bases = (pytd.NamedType("typing.TypedDict"),)
10181018
constants = []
1019-
for k, var in v.props.fields.items():
1020-
typ = pytd_utils.JoinTypes(
1021-
self.value_instance_to_pytd_type(node, p, None, set(), {})
1022-
for p in var.data)
1019+
for k, val in v.props.fields.items():
1020+
typ = self.value_instance_to_pytd_type(node, val, None, set(), {})
10231021
if v.props.total and k not in v.props.required:
10241022
typ = pytd.GenericType(pytd.NamedType("typing.NotRequired"), (typ,))
10251023
elif not v.props.total and k in v.props.required:

pytype/overlays/typed_dict.py

Lines changed: 28 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import dataclasses
44

5-
from typing import Any, Dict, Optional, Set
5+
from typing import Dict, Optional, Set
66

77
from pytype.abstract import abstract
88
from pytype.abstract import abstract_utils
@@ -34,7 +34,7 @@ class TypedDictProperties:
3434
"""Collection of typed dict properties passed between various stages."""
3535

3636
name: str
37-
fields: Dict[str, Any]
37+
fields: Dict[str, abstract.BaseValue]
3838
required: Set[str]
3939
total: bool
4040

@@ -48,27 +48,15 @@ def optional(self):
4848

4949
def add(self, k, v, total):
5050
"""Adds key and value."""
51-
values = []
52-
all_requiredness = set()
53-
for value in v.data:
54-
req = _is_required(value)
55-
if req is None:
56-
values.append(value)
57-
all_requiredness.add(None)
58-
elif isinstance(value, abstract.ParameterizedClass):
59-
values.append(value.formal_type_parameters[abstract_utils.T])
60-
all_requiredness.add(req)
61-
else:
62-
values.append(value.ctx.convert.unsolvable)
63-
all_requiredness.add(req)
64-
if (len(all_requiredness) == 1 and
65-
(requiredness := next(iter(all_requiredness))) is not None):
66-
final_v = v.program.NewVariable(values, [], v.program.entrypoint)
67-
required = requiredness
51+
req = _is_required(v)
52+
if req is None:
53+
value = v
54+
elif isinstance(v, abstract.ParameterizedClass):
55+
value = v.formal_type_parameters[abstract_utils.T]
6856
else:
69-
final_v = v
70-
required = total
71-
self.fields[k] = final_v # pylint: disable=unsupported-assignment-operation
57+
value = v.ctx.convert.unsolvable
58+
required = total if req is None else req
59+
self.fields[k] = value # pylint: disable=unsupported-assignment-operation
7260
if required:
7361
self.required.add(k)
7462

@@ -122,7 +110,12 @@ def _extract_args(self, args):
122110
name=name, fields={}, required=set(), total=total)
123111
# Force Required/NotRequired evaluation
124112
for k, v in fields.items():
125-
props.add(k, v, total)
113+
try:
114+
value = abstract_utils.get_atomic_value(v)
115+
except abstract_utils.ConversionError:
116+
self.ctx.errorlog.ambiguous_annotation(self.ctx.vm.frames, v.data, k)
117+
value = self.ctx.convert.unsolvable
118+
props.add(k, value, total)
126119
return props
127120

128121
def _validate_bases(self, cls_name, bases):
@@ -182,8 +175,14 @@ def make_class(self, node, bases, f_locals, total):
182175
ordering=classgen.Ordering.FIRST_ANNOTATE,
183176
ctx=self.ctx)
184177
for k, local in cls_locals.items():
185-
assert local.typ
186-
props.add(k, local.typ, total)
178+
var = local.typ
179+
assert var
180+
try:
181+
typ = abstract_utils.get_atomic_value(var)
182+
except abstract_utils.ConversionError:
183+
self.ctx.errorlog.ambiguous_annotation(self.ctx.vm.frames, var.data, k)
184+
typ = self.ctx.convert.unsolvable
185+
props.add(k, typ, total)
187186

188187
# Process base classes and generate the __init__ signature.
189188
self._validate_bases(cls_name, bases)
@@ -207,7 +206,7 @@ def make_class_from_pyi(self, cls_name, pytd_cls):
207206
name=name, fields={}, required=set(), total=total)
208207

209208
for c in pytd_cls.constants:
210-
typ = self.ctx.convert.constant_to_var(c.type)
209+
typ = self.ctx.convert.constant_to_value(c.type)
211210
props.add(c.name, typ, total)
212211

213212
# Process base classes and generate the __init__ signature.
@@ -239,8 +238,7 @@ def _make_init(self, props):
239238
sig = function.Signature.from_param_names(
240239
f"{props.name}.__init__", props.fields.keys(),
241240
kind=pytd.ParameterKind.KWONLY)
242-
sig.annotations = {k: abstract_utils.get_atomic_value(v)
243-
for k, v in props.fields.items()}
241+
sig.annotations = dict(props.fields)
244242
sig.defaults = {k: self.ctx.new_unsolvable(self.ctx.root_node)
245243
for k in props.optional}
246244
return abstract.SimpleFunction(sig, self.ctx)
@@ -256,8 +254,7 @@ def _new_instance(self, container, node, args):
256254
def instantiate_value(self, node, container):
257255
args = function.Args(())
258256
for name, typ in self.props.fields.items():
259-
args.namedargs[name] = self.ctx.join_variables(
260-
node, [t.instantiate(node) for t in typ.data])
257+
args.namedargs[name] = typ.instantiate(node)
261258
return self._new_instance(container, node, args)
262259

263260
def instantiate(self, node, container=None):
@@ -301,7 +298,7 @@ def _check_str_key(self, name):
301298

302299
def _check_str_key_value(self, node, name, value_var):
303300
self._check_str_key(name)
304-
typ = abstract_utils.get_atomic_value(self.fields[name])
301+
typ = self.fields[name]
305302
bad = self.ctx.matcher(node).compute_one_match(value_var, typ).bad_matches
306303
for match in bad:
307304
self.ctx.errorlog.annotation_type_mismatch(

pytype/tests/test_typed_dict.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,14 @@ def f() -> TD:
385385
return __any_object__
386386
""")
387387

388+
def test_duplicate_key(self):
389+
self.CheckWithErrors("""
390+
from typing_extensions import TypedDict
391+
class TD(TypedDict): # invalid-annotation
392+
x: int
393+
x: str
394+
""")
395+
388396

389397
class TypedDictFunctionalTest(test_base.BaseTest):
390398
"""Tests for typing.TypedDict functional constructor."""
@@ -458,6 +466,16 @@ class X(TypedDict, total=False):
458466
name: str
459467
""")
460468

469+
def test_ambiguous_field_type(self):
470+
self.CheckWithErrors("""
471+
from typing_extensions import TypedDict
472+
if __random__:
473+
v = str
474+
else:
475+
v = int
476+
X = TypedDict('X', {'k': v}) # invalid-annotation
477+
""")
478+
461479

462480
_SINGLE = """
463481
from typing import TypedDict

0 commit comments

Comments
 (0)