Merged
27 commits
d32107f
Add utils related to namespace management
Viicos Oct 1, 2024
e5c4207
Add temporary xfailing tests that we will hopefully fix
Viicos Oct 1, 2024
9ef1965
Improve correctness in `get_cls_type_hints_lenient`
Viicos Oct 1, 2024
1da3698
Improve annotations resolving for functions
Viicos Oct 1, 2024
f4f2ae8
Adapt `collect_model_fields` to handle the parent namespace
Viicos Oct 1, 2024
bb7ffff
Adapt the `GenerateSchema` class to use the `NsResolver` class
Viicos Oct 1, 2024
2835c7f
Update `ModelMetaclass` to be compatible with the previous changes
Viicos Oct 1, 2024
9b77d05
Change `BaseModel.model_rebuild` logic to match the proposed spec
Viicos Oct 1, 2024
132c06c
Make the dataclass logic compatible with the proposed changes
Viicos Oct 1, 2024
b40468e
Refactor how we eval fields and generate a schema for dataclasses
Viicos Oct 1, 2024
cee89ad
Change `rebuild_dataclass` logic to match the proposed spec
Viicos Oct 1, 2024
24338c9
Change `validate_call` namespace logic
Viicos Oct 1, 2024
6a86ac8
Adapt `TypeAdapter` namespace logic
Viicos Oct 1, 2024
93bd021
Misc. changes to adapt with the new structures
Viicos Oct 1, 2024
7c3bf4f
Process some feedback
Viicos Oct 4, 2024
dd7e44c
Fix comment
Viicos Oct 7, 2024
7ae989e
WIP refactor to be more backwards compatible regarding parent ns
Viicos Oct 7, 2024
d2d549d
WIP fix remaining tests
Viicos Oct 8, 2024
e182033
Almost finished
Viicos Oct 8, 2024
89199ba
First round of feedback
Viicos Oct 8, 2024
127aefe
Cleanup xfail tests, last fixes
Viicos Oct 9, 2024
66c024e
Add back `eval_type_lenient` and deprecate it
Viicos Oct 9, 2024
2a371ca
lint
Viicos Oct 9, 2024
8d57b83
Fix pipeline example
Viicos Oct 9, 2024
7861286
Compat fixes, optimize `get_cls_type_hints`
Viicos Oct 9, 2024
08691ec
Update outdated comment
Viicos Oct 9, 2024
86fe4cc
Feedback
Viicos Oct 10, 2024
9 changes: 2 additions & 7 deletions docs/concepts/experimental.md
@@ -51,7 +51,7 @@ from datetime import datetime
from typing_extensions import Annotated

from pydantic import BaseModel
from pydantic.experimental.pipeline import validate_as, validate_as_deferred
from pydantic.experimental.pipeline import validate_as


class User(BaseModel):
@@ -71,10 +71,6 @@ class User(BaseModel):
),
]
friends: Annotated[list[User], validate_as(...).len(0, 100)] # (6)!
family: Annotated[ # (7)!
list[User],
validate_as_deferred(lambda: list[User]).transform(lambda x: x[1:]),
]
Comment on lines -74 to -77
Contributor

CC @adriangb I think this is broken by this PR but maybe worth breaking. I think it's possible to fix by fiddling with the __get_pydantic_core_schema__ on the Pipeline and making sure we plumb through the namespace stuff, but maybe harder than it's worth doing this second

Member

I think @Viicos decided this was worth breaking. As long as there's clear documentation showing the alternative path forward I'm okay with that.

Member Author

Copy pasting what I added on Slack so that it's not lost:

the TL;DR is that using lambdas with references to other symbols in annotations opens the door to a lot of weird behaviors.


This is the (simplified) test failing on the PR:

```python
from typing_extensions import Annotated

from pydantic import BaseModel
from pydantic.experimental.pipeline import validate_as_deferred

class User(BaseModel):
    family: 'Annotated[list[User], validate_as_deferred(lambda: list[User])]'

# The `family` annotation is successfully evaluated, but at some point during
# core schema generation, the pipeline API logic is triggered and when the lambda
# gets called, we end up with:
# NameError: name 'User' is not defined
```

On main, when we evaluate (read: Python will call `eval`) the annotation for `family`, we use the following as globals:

```python
{'User': __main__.User, 'Annotated': ..., 'BaseModel': ..., ...}
```

and locals are empty.

On this PR, we cleaned up how globals and locals were mixed up before. This means that we now use the following as globals:

```python
{'Annotated': ..., 'BaseModel': ..., ...}
```

and locals:

```python
{'User': __main__.User, ...}
```

And the issue comes from what could be considered a limitation of `eval`. Consider this example:

```python
def func():
    A = int

    works = lambda: list[A]
    fails = eval('lambda: list[A]', globals(), locals())

    works()
    # list[int]
    fails()
    # NameError: name 'A' is not defined
```

The `eval` limitation is that it does not have access to the non-locals of the lambda environment (which is a new scope, like with a `def` statement). Even though `A` is present in locals, it won't be used to resolve `A`, so `eval` will look up in the globals instead (that's why it works on main: `User` was added to the globals for the `eval` call).

This limitation is documented in this open CPython PR.
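As a side note, a common workaround sketch (not part of this PR) is to pass `eval` a single mapping that merges locals into globals, so the names the lambda needs end up in its global scope:

```python
def func():
    A = int

    # Merging locals into the mapping used as eval's globals makes `A`
    # resolvable from inside the lambda's new scope at call time.
    fixed = eval('lambda: list[A]', {**globals(), **locals()})
    return fixed()  # list[int]

print(func())
```

This mirrors why the annotation evaluation worked on main: `User` happened to be present in the globals mapping passed to `eval`.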

bio: Annotated[
datetime,
validate_as(int)
@@ -89,8 +85,7 @@ class User(BaseModel):
4. You can also use the lower level transform, constrain and predicate methods.
5. Use the `|` or `&` operators to combine steps (like a logical OR or AND).
6. Calling `validate_as(...)` with `Ellipsis` (`...`) as the first positional argument implies `validate_as(<field type>)`. Use `validate_as(Any)` to accept any type.
7. For recursive types you can use `validate_as_deferred` to reference the type itself before it's defined.
8. You can call `validate_as()` before or after other steps to do pre or post processing.
7. You can call `validate_as()` before or after other steps to do pre or post processing.

### Mapping from `BeforeValidator`, `AfterValidator` and `WrapValidator`

29 changes: 16 additions & 13 deletions pydantic/_internal/_dataclasses.py
@@ -19,20 +19,23 @@
from ..errors import PydanticUndefinedAnnotation
from ..plugin._schema_validator import PluggableSchemaValidator, create_schema_validator
from ..warnings import PydanticDeprecatedSince20
from . import _config, _decorators, _typing_extra
from . import _config, _decorators
from ._fields import collect_dataclass_fields
from ._generate_schema import GenerateSchema
from ._generics import get_standard_typevars_map
from ._mock_val_ser import set_dataclass_mocks
from ._namespace_utils import NsResolver
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
from ._signature import generate_pydantic_signature

if typing.TYPE_CHECKING:
from dataclasses import Field

from ..config import ConfigDict
from ..fields import FieldInfo

class StandardDataclass(typing.Protocol):
__dataclass_fields__: ClassVar[dict[str, Any]]
__dataclass_fields__: ClassVar[dict[str, Field[Any]]]
__dataclass_params__: ClassVar[Any] # in reality `dataclasses._DataclassParams`
__post_init__: ClassVar[Callable[..., None]]

@@ -68,18 +71,20 @@ class PydanticDataclass(StandardDataclass, typing.Protocol):

def set_dataclass_fields(
cls: type[StandardDataclass],
types_namespace: dict[str, Any] | None = None,
ns_resolver: NsResolver | None = None,
config_wrapper: _config.ConfigWrapper | None = None,
) -> None:
"""Collect and set `cls.__pydantic_fields__`.

Args:
cls: The class.
types_namespace: The types namespace, defaults to `None`.
ns_resolver: Namespace resolver to use when getting dataclass annotations.
config_wrapper: The config wrapper instance, defaults to `None`.
"""
typevars_map = get_standard_typevars_map(cls)
fields = collect_dataclass_fields(cls, types_namespace, typevars_map=typevars_map, config_wrapper=config_wrapper)
fields = collect_dataclass_fields(
cls, ns_resolver=ns_resolver, typevars_map=typevars_map, config_wrapper=config_wrapper
)

cls.__pydantic_fields__ = fields # type: ignore

@@ -89,7 +94,7 @@ def complete_dataclass(
config_wrapper: _config.ConfigWrapper,
*,
raise_errors: bool = True,
types_namespace: dict[str, Any] | None,
ns_resolver: NsResolver | None = None,
_force_build: bool = False,
) -> bool:
"""Finish building a pydantic dataclass.
@@ -102,7 +107,8 @@
cls: The class.
config_wrapper: The config wrapper instance.
raise_errors: Whether to raise errors, defaults to `True`.
types_namespace: The types namespace.
ns_resolver: The namespace resolver instance to use when collecting dataclass fields
and during schema building.
_force_build: Whether to force building the dataclass, no matter if
[`defer_build`][pydantic.config.ConfigDict.defer_build] is set.

@@ -135,16 +141,13 @@ def __init__(__dataclass_self__: PydanticDataclass, *args: Any, **kwargs: Any) -
'Support for `__post_init_post_parse__` has been dropped, the method will not be called', DeprecationWarning
)

if types_namespace is None:
types_namespace = _typing_extra.merge_cls_and_parent_ns(cls)
sydney-runkle marked this conversation as resolved.

set_dataclass_fields(cls, types_namespace, config_wrapper=config_wrapper)
set_dataclass_fields(cls, ns_resolver, config_wrapper=config_wrapper)

typevars_map = get_standard_typevars_map(cls)
gen_schema = GenerateSchema(
config_wrapper,
types_namespace,
typevars_map,
ns_resolver=ns_resolver,
typevars_map=typevars_map,
)

# This needs to be called before we change the __init__
14 changes: 11 additions & 3 deletions pydantic/_internal/_decorators.py
@@ -15,6 +15,7 @@
from ..errors import PydanticUserError
from ._core_utils import get_type_ref
from ._internal_dataclass import slots_true
from ._namespace_utils import GlobalsNamespace, MappingNamespace
from ._typing_extra import get_function_type_hints

if TYPE_CHECKING:
@@ -752,7 +753,10 @@ def unwrap_wrapped_function(


def get_function_return_type(
func: Any, explicit_return_type: Any, types_namespace: dict[str, Any] | None = None
func: Any,
explicit_return_type: Any,
globalns: GlobalsNamespace | None = None,
localns: MappingNamespace | None = None,
) -> Any:
"""Get the function return type.

@@ -762,15 +766,19 @@ def get_function_return_type(
Args:
func: The function to get its return type.
explicit_return_type: The explicit return type.
types_namespace: The types namespace, defaults to `None`.
globalns: The globals namespace to use during type annotation evaluation.
localns: The locals namespace to use during type annotation evaluation.

Returns:
The function return type.
"""
if explicit_return_type is PydanticUndefined:
# try to get it from the type annotation
hints = get_function_type_hints(
unwrap_wrapped_function(func), include_keys={'return'}, types_namespace=types_namespace
unwrap_wrapped_function(func),
include_keys={'return'},
globalns=globalns,
localns=localns,
)
return hints.get('return', PydanticUndefined)
else:
101 changes: 59 additions & 42 deletions pydantic/_internal/_fields.py
@@ -3,7 +3,6 @@
from __future__ import annotations as _annotations

import dataclasses
import sys
import warnings
from copy import copy
from functools import lru_cache
@@ -17,8 +16,9 @@
from ._config import ConfigWrapper
from ._docs_extraction import extract_docstrings_from_cls
from ._import_utils import import_cached_base_model, import_cached_field_info
from ._namespace_utils import NsResolver
from ._repr import Representation
from ._typing_extra import get_cls_type_hints_lenient, is_classvar, is_finalvar
from ._typing_extra import get_cls_type_hints, is_classvar, is_finalvar

if TYPE_CHECKING:
from annotated_types import BaseMetadata
@@ -73,7 +73,7 @@ def collect_model_fields( # noqa: C901
cls: type[BaseModel],
bases: tuple[type[Any], ...],
config_wrapper: ConfigWrapper,
types_namespace: dict[str, Any] | None,
ns_resolver: NsResolver | None,
*,
typevars_map: dict[Any, Any] | None = None,
) -> tuple[dict[str, FieldInfo], set[str]]:
@@ -87,7 +87,7 @@
cls: BaseModel or dataclass.
bases: Parents of the class, generally `cls.__bases__`.
config_wrapper: The config wrapper instance.
types_namespace: Optional extra namespace to look for types in.
ns_resolver: Namespace resolver to use when getting model annotations.
typevars_map: A dictionary mapping type variables to their concrete types.

Returns:
@@ -107,7 +107,7 @@
if model_fields := getattr(base, '__pydantic_fields__', None):
parent_fields_lookup.update(model_fields)

type_hints = get_cls_type_hints_lenient(cls, types_namespace)
type_hints = get_cls_type_hints(cls, ns_resolver=ns_resolver, lenient=True)

# https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older
# annotations is only used for finding fields in parent classes
@@ -232,7 +232,7 @@

if typevars_map:
for field in fields.values():
field.apply_typevars_map(typevars_map, types_namespace)
field.apply_typevars_map(typevars_map)

_update_fields_from_docstrings(cls, fields, config_wrapper)
return fields, class_vars
@@ -269,16 +269,17 @@

def collect_dataclass_fields(
cls: type[StandardDataclass],
types_namespace: dict[str, Any] | None,
*,
ns_resolver: NsResolver | None = None,
typevars_map: dict[Any, Any] | None = None,
config_wrapper: ConfigWrapper | None = None,
) -> dict[str, FieldInfo]:
"""Collect the fields of a dataclass.

Args:
cls: dataclass.
types_namespace: Optional extra namespace to look for types in.
ns_resolver: Namespace resolver to use when getting dataclass annotations.
Defaults to an empty instance.
typevars_map: A dictionary mapping type variables to their concrete types.
config_wrapper: The config wrapper instance.

@@ -288,50 +289,66 @@
FieldInfo_ = import_cached_field_info()

fields: dict[str, FieldInfo] = {}
dataclass_fields: dict[str, dataclasses.Field] = cls.__dataclass_fields__
cls_localns = dict(vars(cls)) # this matches get_cls_type_hints_lenient, but all tests pass with `= None` instead
ns_resolver = ns_resolver or NsResolver()
dataclass_fields = cls.__dataclass_fields__

# The logic here is similar to `_typing_extra.get_cls_type_hints`,
# although we do it manually as stdlib dataclasses already have annotations
# collected in each class:
for base in reversed(cls.__mro__):
if not _typing_extra.is_dataclass(base):
continue

source_module = sys.modules.get(cls.__module__)
if source_module is not None:
types_namespace = {**source_module.__dict__, **(types_namespace or {})}
with ns_resolver.push(base):
for ann_name, dataclass_field in dataclass_fields.items():
if ann_name not in base.__dict__.get('__annotations__', {}):
# `__dataclass_fields__` contains every field, even the ones from base classes.
# Only collect the ones defined on `base`.
continue

for ann_name, dataclass_field in dataclass_fields.items():
ann_type = _typing_extra.eval_type_lenient(dataclass_field.type, types_namespace, cls_localns)
if is_classvar(ann_type):
continue
globalns, localns = ns_resolver.types_namespace
ann_type = _typing_extra.eval_type(dataclass_field.type, globalns, localns, lenient=True)

if (
not dataclass_field.init
and dataclass_field.default == dataclasses.MISSING
and dataclass_field.default_factory == dataclasses.MISSING
):
# TODO: We should probably do something with this so that validate_assignment behaves properly
# Issue: https://github.com/pydantic/pydantic/issues/5470
continue
if is_classvar(ann_type):
continue

if isinstance(dataclass_field.default, FieldInfo_):
if dataclass_field.default.init_var:
if dataclass_field.default.init is False:
raise PydanticUserError(
f'Dataclass field {ann_name} has init=False and init_var=True, but these are mutually exclusive.',
code='clashing-init-and-init-var',
)
if (
not dataclass_field.init
and dataclass_field.default == dataclasses.MISSING
and dataclass_field.default_factory == dataclasses.MISSING
):
# TODO: We should probably do something with this so that validate_assignment behaves properly
# Issue: https://github.com/pydantic/pydantic/issues/5470
continue

# TODO: same note as above re validate_assignment
continue
field_info = FieldInfo_.from_annotated_attribute(ann_type, dataclass_field.default)
else:
field_info = FieldInfo_.from_annotated_attribute(ann_type, dataclass_field)
if isinstance(dataclass_field.default, FieldInfo_):
if dataclass_field.default.init_var:
if dataclass_field.default.init is False:
raise PydanticUserError(
f'Dataclass field {ann_name} has init=False and init_var=True, but these are mutually exclusive.',
code='clashing-init-and-init-var',
)

fields[ann_name] = field_info
# TODO: same note as above re validate_assignment
continue
field_info = FieldInfo_.from_annotated_attribute(ann_type, dataclass_field.default)
else:
field_info = FieldInfo_.from_annotated_attribute(ann_type, dataclass_field)

fields[ann_name] = field_info

if field_info.default is not PydanticUndefined and isinstance(getattr(cls, ann_name, field_info), FieldInfo_):
# We need this to fix the default when the "default" from __dataclass_fields__ is a pydantic.FieldInfo
setattr(cls, ann_name, field_info.default)
if field_info.default is not PydanticUndefined and isinstance(
getattr(cls, ann_name, field_info), FieldInfo_
):
# We need this to fix the default when the "default" from __dataclass_fields__ is a pydantic.FieldInfo
setattr(cls, ann_name, field_info.default)

if typevars_map:
for field in fields.values():
field.apply_typevars_map(typevars_map, types_namespace)
# We don't pass any ns, as `field.annotation`
# was already evaluated. TODO: is this method relevant?
# Can't we just use `_generics.replace_types`?
field.apply_typevars_map(typevars_map)

if config_wrapper is not None:
_update_fields_from_docstrings(cls, fields, config_wrapper)
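The `ns_resolver.push(...)` / `ns_resolver.types_namespace` pattern used throughout the diff can be pictured with a minimal stack-based resolver. This is an illustrative sketch only — the names and behavior here are assumptions, and pydantic's real `NsResolver` in `_namespace_utils.py` handles parent namespaces and more:

```python
import sys
from contextlib import contextmanager

class MiniNsResolver:
    """Illustrative stack-based namespace resolver (hypothetical,
    not pydantic's actual implementation)."""

    def __init__(self):
        self._stack = []

    @contextmanager
    def push(self, cls):
        # Globals come from the class's defining module, locals from
        # the class body itself.
        globalns = getattr(sys.modules.get(cls.__module__), '__dict__', {})
        self._stack.append((globalns, dict(vars(cls))))
        try:
            yield
        finally:
            self._stack.pop()

    @property
    def types_namespace(self):
        # The (globalns, localns) pair for the class currently being processed.
        if not self._stack:
            return {}, {}
        return self._stack[-1]
```

With a resolver like this, iterating `for base in reversed(cls.__mro__)` and wrapping each base's field collection in `with ns_resolver.push(base):` gives every base's annotations the globals and locals of the module and class body where they were written.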