Skip to content

Commit a4d7003

Browse files
authored
Merge pull request #1600 from google/google_sync
Add 'default' field to pytd.TypeParameter.
2 parents b899682 + bf32914 commit a4d7003

6 files changed

Lines changed: 48 additions & 10 deletions

File tree

pytype/pyi/definitions.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -468,11 +468,13 @@ def add_type_variable(self, name, tvar):
468468
raise _ParseError(f"{tvar.kind} name needs to be {tvar.name!r} "
469469
f"(not {name!r})")
470470
bound = tvar.bound
471-
if isinstance(bound, str):
472-
bound = pytd.NamedType(bound)
473471
constraints = tuple(tvar.constraints) if tvar.constraints else ()
472+
if isinstance(tvar.default, list):
473+
default = tuple(tvar.default)
474+
else:
475+
default = tvar.default
474476
self.type_params.append(pytd_type(
475-
name=name, constraints=constraints, bound=bound))
477+
name=name, constraints=constraints, bound=bound, default=default))
476478

477479
def add_import(self, from_package, import_list):
478480
"""Add an import.

pytype/pyi/parser.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import re
99
import sys
1010
import tokenize
11-
from typing import Any, List, Optional, Tuple, cast
11+
from typing import Any, List, Optional, Tuple, Union, cast
1212

1313
from pytype.ast import debug
1414
from pytype.pyi import conditions
@@ -61,6 +61,7 @@ class _TypeVariable:
6161
name: str
6262
bound: Optional[pytd.Type]
6363
constraints: List[pytd.Type]
64+
default: Optional[Union[pytd.Type, List[pytd.Type]]]
6465

6566
@classmethod
6667
def from_call(cls, kind: str, node: astlib.Call):
@@ -72,18 +73,20 @@ def from_call(cls, kind: str, node: astlib.Call):
7273
if not types.Pyval.is_str(name):
7374
raise ParseError(f"Bad arguments to {kind}")
7475
bound = None
75-
# 'bound' is the only keyword argument we currently use.
76+
default = None
7677
# TODO(rechen): We should enforce the PEP 484 guideline that
7778
# len(constraints) != 1. However, this guideline is currently violated
7879
# in typeshed (see https://github.com/python/typeshed/pull/806).
7980
kws = {x.arg for x in node.keywords}
80-
extra = kws - {"bound", "covariant", "contravariant"}
81+
extra = kws - {"bound", "covariant", "contravariant", "default"}
8182
if extra:
8283
raise ParseError(f"Unrecognized keyword(s): {', '.join(extra)}")
8384
for kw in node.keywords:
8485
if kw.arg == "bound":
8586
bound = kw.value
86-
return cls(kind, name.value, bound, constraints)
87+
elif kw.arg == "default":
88+
default = kw.value
89+
return cls(kind, name.value, bound, constraints, default)
8790

8891
#------------------------------------------------------
8992
# Main tree visitor and generator code
@@ -674,6 +677,8 @@ def _convert_typevar_args(self, node: astlib.Call):
674677
for kw in node.keywords:
675678
if kw.arg == "bound":
676679
kw.value = self.annotation_visitor.visit(kw.value)
680+
elif kw.arg == "default":
681+
kw.value = self.annotation_visitor.visit(kw.value)
677682

678683
def _convert_typed_dict_args(self, node: astlib.Call):
679684
for fields in node.args[1:]:
@@ -682,7 +687,8 @@ def _convert_typed_dict_args(self, node: astlib.Call):
682687
def enter_Call(self, node):
683688
node.func = self.annotation_visitor.visit(node.func)
684689
func = node.func.name or ""
685-
if self.defs.matches_type(func, ("typing.TypeVar", "typing.ParamSpec")):
690+
if self.defs.matches_type(func, ("typing.TypeVar", "typing.ParamSpec",
691+
"typing.TypeVarTuple")):
686692
self._convert_typevar_args(node)
687693
elif self.defs.matches_type(func, "typing.NamedTuple"):
688694
self._convert_typing_namedtuple_args(node)

pytype/pyi/parser_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3526,5 +3526,30 @@ def f(x: Tuple[Any, ...]) -> Any: ...
35263526
""")
35273527

35283528

3529+
class TypeParameterDefaultTest(parser_test_base.ParserTestBase):
3530+
3531+
def test_typevar(self):
3532+
self.check("""
3533+
from typing_extensions import TypeVar
3534+
3535+
T = TypeVar('T', default=int)
3536+
""")
3537+
3538+
def test_paramspec(self):
3539+
self.check("""
3540+
from typing_extensions import ParamSpec
3541+
3542+
P = ParamSpec('P', default=[str, int])
3543+
""")
3544+
3545+
def test_typevartuple(self):
3546+
self.check("""
3547+
from typing_extensions import TypeVarTuple, Unpack
3548+
Ts = TypeVarTuple('Ts', default=Unpack[tuple[str, int]])
3549+
""", """
3550+
from typing_extensions import TypeVarTuple, TypeVarTuple as Ts, Unpack
3551+
""")
3552+
3553+
35293554
if __name__ == "__main__":
35303555
unittest.main()

pytype/pytd/printer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,11 @@ def _FormatTypeParams(self, type_params):
209209
args += [self.Print(c) for c in t.constraints]
210210
if t.bound:
211211
args.append(f"bound={self.Print(t.bound)}")
212+
if isinstance(t.default, tuple):
213+
args.append(
214+
f"default=[{', '.join(self.Print(d) for d in t.default)}]")
215+
elif t.default:
216+
args.append(f"default={self.Print(t.default)}")
212217
if isinstance(t, pytd.ParamSpec):
213218
typename = self._LookupTypingMember("ParamSpec")
214219
else:

pytype/pytd/pytd.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,7 @@ def f(x: T) -> T
344344
name: str
345345
constraints: Tuple[TypeU, ...] = ()
346346
bound: Optional[TypeU] = None
347+
default: Optional[Union[TypeU, Tuple[TypeU, ...]]] = None
347348
scope: Optional[str] = None
348349

349350
def __lt__(self, other):

pytype/pytd/visitors_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -552,8 +552,7 @@ class A(Dict[T, T], Generic[T]): pass
552552
""")
553553
a = ast.Lookup("A")
554554
self.assertEqual(
555-
(pytd.TemplateItem(pytd.TypeParameter("T", (), None, "A")),),
556-
a.template)
555+
(pytd.TemplateItem(pytd.TypeParameter("T", scope="A")),), a.template)
557556

558557
def test_adjust_type_parameters_with_duplicates_in_generic(self):
559558
src = textwrap.dedent("""

0 commit comments

Comments
 (0)