Skip to content

Commit c699398

Browse files
committed
Fix [invalid-function-definition] error caused by ignoring dataclasses.KW_ONLY.
* Respects KW_ONLY for dataclass attributes that don't use dataclasses.field. * Removes unnecessary version check for applying kwonly-ness: it's impossible for a user to even set KW_ONLY or kw_only unless they're using a version where the field/parameter is available, so once it's been set, we can safely assume we're in a high-enough version. PiperOrigin-RevId: 597970055
1 parent 15befd9 commit c699398

3 files changed

Lines changed: 51 additions & 3 deletions

File tree

pytype/overlays/dataclass_overlay.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def decorate(self, node, cls):
9797
continue
9898
kind = ""
9999
init = True
100-
kw_only = False
100+
kw_only = sticky_kwonly
101101
assert typ
102102
if match_classvar(typ):
103103
continue
@@ -112,8 +112,8 @@ def decorate(self, node, cls):
112112
field = orig.data[0]
113113
orig = field.default
114114
init = field.init
115-
if self.ctx.python_version >= (3, 10):
116-
kw_only = sticky_kwonly if field.kw_only is None else field.kw_only
115+
if field.kw_only is not None:
116+
kw_only = field.kw_only
117117

118118
if orig and orig.data == [self.ctx.convert.none]:
119119
# vm._apply_annotation mostly takes care of checking that the default

pytype/tests/test_dataclasses.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,26 @@ class A:
761761
def __init__(self, a1: int, a3: int, *, a2: int = ...) -> None: ...
762762
""")
763763

764+
@test_utils.skipBeforePy((3, 10), "KW_ONLY is new in 3.10")
765+
def test_kwonly_and_nonfield_default(self):
766+
ty = self.Infer("""
767+
import dataclasses
768+
@dataclasses.dataclass
769+
class C:
770+
_: dataclasses.KW_ONLY
771+
x: int = 0
772+
y: str
773+
""")
774+
self.assertTypesMatchPytd(ty, """
775+
import dataclasses
776+
@dataclasses.dataclass
777+
class C:
778+
x: int = ...
779+
y: str
780+
_: dataclasses.KW_ONLY
781+
def __init__(self, *, x: int = ..., y: str) -> None: ...
782+
""")
783+
764784
def test_star_import(self):
765785
with self.DepTree([("foo.pyi", """
766786
import dataclasses

pytype/tests/test_flax_overlay.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,34 @@ def __init__(
266266
def replace(self: _TBaz, **kwargs) -> _TBaz: ...
267267
""")
268268

269+
@test_utils.skipBeforePy((3, 10), "KW_ONLY is new in 3.10")
270+
def test_kwonly(self):
271+
with test_utils.Tempdir() as d:
272+
self._setup_linen_pyi(d)
273+
ty = self.Infer("""
274+
import dataclasses
275+
from flax import linen as nn
276+
class C(nn.Module):
277+
_: dataclasses.KW_ONLY
278+
x: int = 0
279+
y: str
280+
""", pythonpath=[d.path])
281+
self.assertTypesMatchPytd(ty, """
282+
import dataclasses
283+
from flax import linen as nn
284+
from typing import Any, TypeVar
285+
286+
_TC = TypeVar('_TC', bound=C)
287+
288+
@dataclasses.dataclass
289+
class C(nn.module.Module):
290+
x: int = ...
291+
y: str
292+
_: dataclasses.KW_ONLY
293+
def __init__(self, *, x: int = ..., y: str, name: str = ..., parent: Any = ...) -> None: ...
294+
def replace(self: _TC, **kwargs) -> _TC: ...
295+
""")
296+
269297

270298
if __name__ == "__main__":
271299
test_base.main()

0 commit comments

Comments
 (0)