|
6 | 6 | from collections.abc import Mapping, Sequence |
7 | 7 | from dataclasses import is_dataclass |
8 | 8 | from enum import Enum |
9 | | -from typing import Any, cast, get_args, get_origin |
| 9 | +from typing import Any, TypeVar, cast, get_args, get_origin |
10 | 10 |
|
11 | 11 | from pydantic import BaseModel, Json, RootModel, Secret |
12 | 12 | from pydantic._internal._utils import is_model_class |
@@ -40,11 +40,49 @@ def parse_env_vars( |
40 | 40 | } |
41 | 41 |
|
42 | 42 |
|
| 43 | +def _substitute_typevars(tp: Any, param_map: dict[Any, Any]) -> Any: |
| 44 | + """Substitute TypeVars in a type annotation with concrete types from param_map.""" |
| 45 | + if isinstance(tp, TypeVar) and tp in param_map: |
| 46 | + return param_map[tp] |
| 47 | + args = get_args(tp) |
| 48 | + if not args: |
| 49 | + return tp |
| 50 | + new_args = tuple(_substitute_typevars(arg, param_map) for arg in args) |
| 51 | + if new_args == args: |
| 52 | + return tp |
| 53 | + origin = get_origin(tp) |
| 54 | + if origin is not None: |
| 55 | + try: |
| 56 | + return origin[new_args] |
| 57 | + except TypeError: |
| 58 | + # types.UnionType and similar are not directly subscriptable, |
| 59 | + # reconstruct using | operator |
| 60 | + import functools |
| 61 | + import operator |
| 62 | + |
| 63 | + return functools.reduce(operator.or_, new_args) |
| 64 | + return tp |
| 65 | + |
| 66 | + |
| 67 | +def _resolve_type_alias(annotation: Any) -> Any: |
| 68 | + """Resolve a TypeAliasType to its underlying value, substituting type params if parameterized.""" |
| 69 | + if typing_objects.is_typealiastype(annotation): |
| 70 | + return annotation.__value__ |
| 71 | + origin = get_origin(annotation) |
| 72 | + if typing_objects.is_typealiastype(origin): |
| 73 | + type_params = getattr(origin, '__type_params__', ()) |
| 74 | + type_args = get_args(annotation) |
| 75 | + value = origin.__value__ |
| 76 | + if type_params and type_args: |
| 77 | + return _substitute_typevars(value, dict(zip(type_params, type_args))) |
| 78 | + return value |
| 79 | + return annotation |
| 80 | + |
| 81 | + |
43 | 82 | def _annotation_is_complex(annotation: Any, metadata: list[Any]) -> bool: |
44 | 83 | # If the model is a root model, the root annotation should be used to |
45 | 84 | # evaluate the complexity. |
46 | | - if typing_objects.is_typealiastype(annotation) or typing_objects.is_typealiastype(get_origin(annotation)): |
47 | | - annotation = annotation.__value__ |
| 85 | + annotation = _resolve_type_alias(annotation) |
48 | 86 | if annotation is not None and _lenient_issubclass(annotation, RootModel) and annotation is not RootModel: |
49 | 87 | annotation = cast('type[RootModel[Any]]', annotation) |
50 | 88 | root_annotation = annotation.model_fields['root'].annotation |
@@ -74,10 +112,8 @@ def _annotation_is_complex(annotation: Any, metadata: list[Any]) -> bool: |
74 | 112 |
|
75 | 113 |
|
76 | 114 | def _get_field_metadata(field: FieldInfo) -> list[Any]: |
77 | | - annotation = field.annotation |
| 115 | + annotation = _resolve_type_alias(field.annotation) |
78 | 116 | metadata = field.metadata |
79 | | - if typing_objects.is_typealiastype(annotation) or typing_objects.is_typealiastype(get_origin(annotation)): |
80 | | - annotation = annotation.__value__ # type: ignore[union-attr] |
81 | 117 | origin = get_origin(annotation) |
82 | 118 | if typing_objects.is_annotated(origin): |
83 | 119 | _, *meta = get_args(annotation) |
@@ -240,6 +276,7 @@ def _is_function(obj: Any) -> bool: |
240 | 276 | '_get_model_fields', |
241 | 277 | '_is_function', |
242 | 278 | '_parse_env_none_str', |
| 279 | + '_resolve_type_alias', |
243 | 280 | '_strip_annotated', |
244 | 281 | '_union_is_complex', |
245 | 282 | 'parse_env_vars', |
|
0 commit comments