Skip to content

Commit 6835ea3

Browse files
authored
fix: Don't turn items annotated as InitVar into dataclass members
PR-252: #252
1 parent c88b484 commit 6835ea3

File tree

2 files changed

+42
-0
lines changed

2 files changed

+42
-0
lines changed

src/griffe/extensions/dataclasses.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,13 +176,22 @@ def _set_dataclass_init(class_: Class) -> None:
176176
class_.set_member("__init__", init)
177177

178178

179+
def _del_members_annotated_as_initvar(class_: Class) -> None:
180+
# Definitions annotated as InitVar are not class members
181+
attributes = [member for member in class_.members.values() if isinstance(member, Attribute)]
182+
for attribute in attributes:
183+
if isinstance(attribute.annotation, Expr) and attribute.annotation.canonical_path == "dataclasses.InitVar":
184+
class_.del_member(attribute.name)
185+
186+
179187
def _apply_recursively(mod_cls: Module | Class, processed: set[str]) -> None:
180188
if mod_cls.canonical_path in processed:
181189
return
182190
processed.add(mod_cls.canonical_path)
183191
if isinstance(mod_cls, Class):
184192
if "__init__" not in mod_cls.members:
185193
_set_dataclass_init(mod_cls)
194+
_del_members_annotated_as_initvar(mod_cls)
186195
for member in mod_cls.members.values():
187196
if not member.is_alias and member.is_class:
188197
_apply_recursively(member, processed) # type: ignore[arg-type]

tests/test_dataclasses.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,3 +276,36 @@ class Reordered(Base):
276276
assert [p.name for p in params_base] == ["self", "a", "b"]
277277
assert [p.name for p in params_reordered] == ["self", "b", "c", "a"]
278278
assert str(params_reordered["b"].annotation) == "float"
279+
280+
281+
def test_parameters_annotated_as_initvar() -> None:
282+
"""Don't return InitVar annotated fields as class members.
283+
284+
But if __init__ is defined, InitVar has no effect.
285+
"""
286+
code = """
287+
from dataclasses import dataclass, InitVar
288+
289+
@dataclass
290+
class PointA:
291+
x: float
292+
y: float
293+
z: InitVar[float]
294+
295+
@dataclass
296+
class PointB:
297+
x: float
298+
y: float
299+
z: InitVar[float]
300+
301+
def __init__(self, r: float): ...
302+
"""
303+
304+
with temporary_visited_package("package", {"__init__.py": code}) as module:
305+
point_a = module["PointA"]
306+
assert ["self", "x", "y", "z"] == [p.name for p in point_a.parameters]
307+
assert ["x", "y", "__init__"] == list(point_a.members)
308+
309+
point_b = module["PointB"]
310+
assert ["self", "r"] == [p.name for p in point_b.parameters]
311+
assert ["x", "y", "z", "__init__"] == list(point_b.members)

0 commit comments

Comments
 (0)