Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 62 additions & 1 deletion pydantic/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,33 @@ class PydanticJsonSchemaWarning(UserWarning):
CoreModeRef = tuple[CoreRef, JsonSchemaMode]
JsonSchemaKeyT = TypeVar('JsonSchemaKeyT', bound=Hashable)

# ##### Regex for Decimal JSON Schema Generation #####

_DECIMAL_JSON_VALIDATION_MAX_DIGIT_LOOKAHEAD_PATTERN = (
r'(?=(?:\d(\.)?){{1,{max_digits}}}' # Positive lookahead for max_digits and optional decimal place
r'(?(1)0*)$)' # yes-pattern allowing trailing zeroes if the decimal place exists
)

_DECIMAL_JSON_SERIALIZATION_MAX_DIGIT_LOOKAHEAD_PATTERN = (
r'(?=(?:\d\.?){{1,{max_digits}}}$)' # Positive lookahead for max_digits and optional decimal place
)

_DECIMAL_JSON_VALIDATION_PATTERN = (
r'^-?' # Minus sign (optional)
r'0*' # Allow leading zeroes
r'{max_digit_lookahead}' # Substitution for max digit lookahead if required
r'\d{{1,{integer_places}}}' # One or more integer digits
r'(?:\.\d{{0,{decimal_places}}}0*)?' # Optional non-capturing group: decimal digits
r'$'
)

_DECIMAL_JSON_SERIALIZATION_PATTERN = (
r'^-?' # Minus sign (optional)
r'{max_digit_lookahead}' # Substitution for max digit lookahead if required
r'(?:0|[1-9]\d{{0,{integer_places}}})' # Non-capturing group: Single zero OR non-zero digit and integer digits
r'(?:\.\d{{0,{decimal_places}}})?$' # Optional non-capturing group: decimal digits
)


@dataclasses.dataclass(**_internal_dataclass.slots_true)
class _DefinitionsRemapping:
Expand Down Expand Up @@ -674,7 +701,41 @@ def decimal_schema(self, schema: core_schema.DecimalSchema) -> JsonSchemaValue:
Returns:
The generated JSON schema.
"""
json_schema = self.str_schema(core_schema.str_schema())
max_digits = schema.get('max_digits')
decimal_places = schema.get('decimal_places')
str_schema = core_schema.str_schema()
# Only add a pattern if either max_digits or decimal_places is set
if max_digits is not None or decimal_places is not None:
max_digit_lookahead = ''
if self.mode == 'validation':
# Only set a max digit lookahead if max_digits is set
if max_digits is not None:
max_digit_lookahead = _DECIMAL_JSON_VALIDATION_MAX_DIGIT_LOOKAHEAD_PATTERN.format(
max_digits=max_digits
)
integer_places = '' if max_digits is None or decimal_places is None else max_digits - decimal_places
decimal_regex_pattern = _DECIMAL_JSON_VALIDATION_PATTERN.format(
max_digit_lookahead=max_digit_lookahead,
integer_places=integer_places,
decimal_places='' if decimal_places is None else decimal_places,
)
str_schema['pattern'] = re.compile(decimal_regex_pattern).pattern
elif self.mode == 'serialization':
if max_digits is not None:
max_digit_lookahead = _DECIMAL_JSON_SERIALIZATION_MAX_DIGIT_LOOKAHEAD_PATTERN.format(
max_digits=max_digits
)
# For the serialization pattern we match the first integer digit separate from the rest, to account for
# this we want our integer_places argument to be one less than max_digits - decimal_places
integer_places = '' if max_digits is None or decimal_places is None else max_digits - decimal_places - 1
decimal_regex_pattern = _DECIMAL_JSON_SERIALIZATION_PATTERN.format(
max_digit_lookahead=max_digit_lookahead,
integer_places=integer_places,
decimal_places='' if decimal_places is None else decimal_places,
)
str_schema['pattern'] = re.compile(decimal_regex_pattern).pattern
json_schema = self.str_schema(str_schema)

if self.mode == 'validation':
multiple_of = schema.get('multiple_of')
le = schema.get('le')
Expand Down
48 changes: 48 additions & 0 deletions tests/test_json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2012,6 +2012,37 @@ class A(BaseModel):
({'ge': 2}, Decimal, {'anyOf': [{'type': 'number', 'minimum': 2}, {'type': 'string'}]}),
({'le': 5}, Decimal, {'anyOf': [{'type': 'number', 'maximum': 5}, {'type': 'string'}]}),
({'multiple_of': 5}, Decimal, {'anyOf': [{'type': 'number', 'multipleOf': 5}, {'type': 'string'}]}),
(
{'max_digits': 4},
Decimal,
{
'anyOf': [
{'type': 'number'},
{
'pattern': '^-?0*(?=(?:\\d(\\.)?){1,4}(?(1)0*)$)\\d{1,}(?:\\.\\d{0,}0*)?$',
'type': 'string',
},
]
},
),
(
{'decimal_places': 2},
Decimal,
{'anyOf': [{'type': 'number'}, {'pattern': '^-?0*\\d{1,}(?:\\.\\d{0,2}0*)?$', 'type': 'string'}]},
),
(
{'max_digits': 4, 'decimal_places': 2},
Decimal,
{
'anyOf': [
{'type': 'number'},
{
'pattern': '^-?0*(?=(?:\\d(\\.)?){1,4}(?(1)0*)$)\\d{1,2}(?:\\.\\d{0,2}0*)?$',
'type': 'string',
},
]
},
),
],
)
def test_constraints_schema_validation(kwargs, type_, expected_extra):
Expand Down Expand Up @@ -2055,6 +2086,23 @@ class Foo(BaseModel):
({'ge': 2}, Decimal, {'type': 'string'}),
({'le': 5}, Decimal, {'type': 'string'}),
({'multiple_of': 5}, Decimal, {'type': 'string'}),
(
{'max_digits': 4},
Decimal,
{
'pattern': '^-?(?=(?:\\d\\.?){1,4}$)(?:0|[1-9]\\d{0,})(?:\\.\\d{0,})?$',
'type': 'string',
},
),
({'decimal_places': 2}, Decimal, {'pattern': '^-?(?:0|[1-9]\\d{0,})(?:\\.\\d{0,2})?$', 'type': 'string'}),
(
{'max_digits': 4, 'decimal_places': 2},
Decimal,
{
'pattern': '^-?(?=(?:\\d\\.?){1,4}$)(?:0|[1-9]\\d{0,1})(?:\\.\\d{0,2})?$',
'type': 'string',
},
),
],
)
def test_constraints_schema_serialization(kwargs, type_, expected_extra):
Expand Down