Skip to content

Commit cf55bad

Browse files
authored
♻️ Simplify internals, remove Pydantic v1 only logic, no longer needed (#14857)
1 parent ac8362c commit cf55bad

5 files changed

Lines changed: 19 additions & 112 deletions

File tree

fastapi/_compat/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from .v2 import create_body_model as create_body_model
2626
from .v2 import evaluate_forwardref as evaluate_forwardref
2727
from .v2 import get_cached_model_fields as get_cached_model_fields
28-
from .v2 import get_compat_model_name_map as get_compat_model_name_map
2928
from .v2 import get_definitions as get_definitions
3029
from .v2 import get_missing_field_error as get_missing_field_error
3130
from .v2 import get_schema_from_model_field as get_schema_from_model_field

fastapi/_compat/v2.py

Lines changed: 10 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import re
22
import warnings
33
from collections.abc import Sequence
4-
from copy import copy, deepcopy
4+
from copy import copy
55
from dataclasses import dataclass, is_dataclass
66
from enum import Enum
77
from functools import lru_cache
@@ -169,11 +169,11 @@ def validate(
169169
values: dict[str, Any] = {}, # noqa: B006
170170
*,
171171
loc: tuple[Union[int, str], ...] = (),
172-
) -> tuple[Any, Union[list[dict[str, Any]], None]]:
172+
) -> tuple[Any, list[dict[str, Any]]]:
173173
try:
174174
return (
175175
self._type_adapter.validate_python(value, from_attributes=True),
176-
None,
176+
[],
177177
)
178178
except ValidationError as exc:
179179
return None, _regenerate_error_with_loc(
@@ -305,94 +305,12 @@ def get_definitions(
305305
if "description" in item_def:
306306
item_description = cast(str, item_def["description"]).split("\f")[0]
307307
item_def["description"] = item_description
308-
new_mapping, new_definitions = _remap_definitions_and_field_mappings(
309-
model_name_map=model_name_map,
310-
definitions=definitions, # type: ignore[arg-type]
311-
field_mapping=field_mapping,
312-
)
313-
return new_mapping, new_definitions
314-
315-
316-
def _replace_refs(
317-
*,
318-
schema: dict[str, Any],
319-
old_name_to_new_name_map: dict[str, str],
320-
) -> dict[str, Any]:
321-
new_schema = deepcopy(schema)
322-
for key, value in new_schema.items():
323-
if key == "$ref":
324-
value = schema["$ref"]
325-
if isinstance(value, str):
326-
ref_name = schema["$ref"].split("/")[-1]
327-
if ref_name in old_name_to_new_name_map:
328-
new_name = old_name_to_new_name_map[ref_name]
329-
new_schema["$ref"] = REF_TEMPLATE.format(model=new_name)
330-
continue
331-
if isinstance(value, dict):
332-
new_schema[key] = _replace_refs(
333-
schema=value,
334-
old_name_to_new_name_map=old_name_to_new_name_map,
335-
)
336-
elif isinstance(value, list):
337-
new_value = []
338-
for item in value:
339-
if isinstance(item, dict):
340-
new_item = _replace_refs(
341-
schema=item,
342-
old_name_to_new_name_map=old_name_to_new_name_map,
343-
)
344-
new_value.append(new_item)
345-
346-
else:
347-
new_value.append(item)
348-
new_schema[key] = new_value
349-
return new_schema
350-
351-
352-
def _remap_definitions_and_field_mappings(
353-
*,
354-
model_name_map: ModelNameMap,
355-
definitions: dict[str, Any],
356-
field_mapping: dict[
357-
tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
358-
],
359-
) -> tuple[
360-
dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue],
361-
dict[str, Any],
362-
]:
363-
old_name_to_new_name_map = {}
364-
for field_key, schema in field_mapping.items():
365-
model = field_key[0].type_
366-
if model not in model_name_map or "$ref" not in schema:
367-
continue
368-
new_name = model_name_map[model]
369-
old_name = schema["$ref"].split("/")[-1]
370-
if old_name in {f"{new_name}-Input", f"{new_name}-Output"}:
371-
continue
372-
old_name_to_new_name_map[old_name] = new_name
373-
374-
new_field_mapping: dict[
375-
tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
376-
] = {}
377-
for field_key, schema in field_mapping.items():
378-
new_schema = _replace_refs(
379-
schema=schema,
380-
old_name_to_new_name_map=old_name_to_new_name_map,
381-
)
382-
new_field_mapping[field_key] = new_schema
383-
384-
new_definitions = {}
385-
for key, value in definitions.items():
386-
if key in old_name_to_new_name_map:
387-
new_key = old_name_to_new_name_map[key]
388-
else:
389-
new_key = key
390-
new_value = _replace_refs(
391-
schema=value,
392-
old_name_to_new_name_map=old_name_to_new_name_map,
393-
)
394-
new_definitions[new_key] = new_value
395-
return new_field_mapping, new_definitions
308+
# definitions: dict[DefsRef, dict[str, Any]]
309+
# but mypy complains about general str in other places that are not declared as
310+
# DefsRef, although DefsRef is just str:
311+
# DefsRef = NewType('DefsRef', str)
312+
# So, a cast to simplify the types here
313+
return field_mapping, cast(dict[str, dict[str, Any]], definitions)
396314

397315

398316
def is_scalar_field(field: ModelField) -> bool:
@@ -441,7 +359,7 @@ def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
441359
return shared.sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return,index]
442360

443361

444-
def get_missing_field_error(loc: tuple[str, ...]) -> dict[str, Any]:
362+
def get_missing_field_error(loc: tuple[Union[int, str], ...]) -> dict[str, Any]:
445363
error = ValidationError.from_exception_data(
446364
"Field required", [{"type": "missing", "loc": loc, "input": {}}]
447365
).errors(include_url=False)[0]
@@ -499,11 +417,6 @@ def get_model_name_map(unique_models: TypeModelSet) -> dict[TypeModelOrEnum, str
499417
return {v: k for k, v in name_model_map.items()}
500418

501419

502-
def get_compat_model_name_map(fields: list[ModelField]) -> ModelNameMap:
503-
flat_models = get_flat_models_from_fields(fields, known_models=set())
504-
return get_model_name_map(flat_models)
505-
506-
507420
def get_flat_models_from_model(
508421
model: type["BaseModel"], known_models: Union[TypeModelSet, None] = None
509422
) -> TypeModelSet:

fastapi/dependencies/utils.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
ModelField,
2222
RequiredParam,
2323
Undefined,
24-
_regenerate_error_with_loc,
2524
copy_field_info,
2625
create_body_model,
2726
evaluate_forwardref,
@@ -718,12 +717,7 @@ def _validate_value_with_model_field(
718717
return None, [get_missing_field_error(loc=loc)]
719718
else:
720719
return deepcopy(field.default), []
721-
v_, errors_ = field.validate(value, values, loc=loc)
722-
if isinstance(errors_, list):
723-
new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
724-
return None, new_errors
725-
else:
726-
return v_, []
720+
return field.validate(value, values, loc=loc)
727721

728722

729723
def _is_json_field(field: ModelField) -> bool:

fastapi/openapi/utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,14 @@
99
from fastapi._compat import (
1010
ModelField,
1111
Undefined,
12-
get_compat_model_name_map,
1312
get_definitions,
1413
get_schema_from_model_field,
1514
lenient_issubclass,
1615
)
16+
from fastapi._compat.v2 import (
17+
get_flat_models_from_fields,
18+
get_model_name_map,
19+
)
1720
from fastapi.datastructures import DefaultPlaceholder
1821
from fastapi.dependencies.models import Dependant
1922
from fastapi.dependencies.utils import (
@@ -512,7 +515,8 @@ def get_openapi(
512515
webhook_paths: dict[str, dict[str, Any]] = {}
513516
operation_ids: set[str] = set()
514517
all_fields = get_fields_from_routes(list(routes or []) + list(webhooks or []))
515-
model_name_map = get_compat_model_name_map(all_fields)
518+
flat_models = get_flat_models_from_fields(all_fields, known_models=set())
519+
model_name_map = get_model_name_map(flat_models)
516520
field_mapping, definitions = get_definitions(
517521
fields=all_fields,
518522
model_name_map=model_name_map,

fastapi/routing.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -277,15 +277,12 @@ async def serialize_response(
277277
endpoint_ctx: Optional[EndpointContext] = None,
278278
) -> Any:
279279
if field:
280-
errors = []
281280
if is_coroutine:
282-
value, errors_ = field.validate(response_content, {}, loc=("response",))
281+
value, errors = field.validate(response_content, {}, loc=("response",))
283282
else:
284-
value, errors_ = await run_in_threadpool(
283+
value, errors = await run_in_threadpool(
285284
field.validate, response_content, {}, loc=("response",)
286285
)
287-
if isinstance(errors_, list):
288-
errors.extend(errors_)
289286
if errors:
290287
ctx = endpoint_ctx or EndpointContext()
291288
raise ResponseValidationError(

0 commit comments

Comments
 (0)