Skip to content

Commit ce9804c

Browse files
authored
CLI coerce numeric types. (#769)
1 parent e460f0b commit ce9804c

2 files changed

Lines changed: 29 additions & 5 deletions

File tree

pydantic_settings/sources/providers/cli.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
)
3838

3939
import typing_extensions
40-
from pydantic import AliasChoices, AliasPath, BaseModel, Field, PrivateAttr
40+
from pydantic import AliasChoices, AliasPath, BaseModel, Field, PrivateAttr, TypeAdapter, ValidationError
4141
from pydantic._internal._repr import Representation
4242
from pydantic._internal._utils import is_model_class
4343
from pydantic.dataclasses import is_pydantic_dataclass
@@ -619,6 +619,11 @@ def _merged_list_to_str(self, merged_list: list[str], field_name: str) -> str:
619619
decode_list: list[str] = []
620620
is_use_decode: bool | None = None
621621
cli_arg_map = self._parser_map.get(field_name, {})
622+
try:
623+
list_adapter: Any = TypeAdapter(next(iter(cli_arg_map.values())).field_info.annotation)
624+
is_num_type_str = type(list_adapter.validate_python(['1'])[0]) is str
625+
except (StopIteration, ValidationError):
626+
is_num_type_str = None
622627
for index, item in enumerate(merged_list):
623628
cli_arg = cli_arg_map.get(index)
624629
is_decode = cli_arg is None or not cli_arg.is_no_decode
@@ -628,6 +633,12 @@ def _merged_list_to_str(self, merged_list: list[str], field_name: str) -> str:
628633
raise SettingsError('Mixing Decode and NoDecode across different AliasPath fields is not allowed')
629634
if is_use_decode:
630635
item = item.replace('\\', '\\\\')
636+
try:
637+
unquoted_item = item[1:-1] if item.startswith('"') and item.endswith('"') else item
638+
float(unquoted_item)
639+
item = f'"{unquoted_item}"' if is_num_type_str else unquoted_item
640+
except ValueError:
641+
pass
631642
elif item.startswith('"') and item.endswith('"'):
632643
item = item[1:-1]
633644
decode_list.append(item)

tests/test_source_cli.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -745,15 +745,28 @@ def check_answer(cfg, prefix, expected):
745745
assert cfg.model_dump() == expected
746746

747747
args: list[str] = []
748-
args = [f'--{prefix}num_list', arg_spaces('[1,2]')]
749-
args += [f'--{prefix}num_list', arg_spaces('3,4')]
750-
args += [f'--{prefix}num_list', '5', f'--{prefix}num_list', '6']
748+
args = [f'--{prefix}str_list', arg_spaces('["1","2"]')]
749+
args += [f'--{prefix}num_list', arg_spaces('["1","2"]')]
750+
args += [f'--{prefix}str_list', arg_spaces('"3","4"')]
751+
args += [f'--{prefix}num_list', arg_spaces('"3","4"')]
752+
args += [f'--{prefix}str_list', '"5"', f'--{prefix}str_list', '"6"']
753+
args += [f'--{prefix}num_list', '"5"', f'--{prefix}num_list', '"6"']
751754
cfg = CliApp.run(Cfg, cli_args=args)
752755
expected = {
753756
'num_list': [1, 2, 3, 4, 5, 6],
754757
'obj_list': None,
755758
'union_list': None,
756-
'str_list': None,
759+
'str_list': ['1', '2', '3', '4', '5', '6'],
760+
}
761+
check_answer(cfg, prefix, expected)
762+
763+
args = [arg.replace('"', '') for arg in args]
764+
cfg = CliApp.run(Cfg, cli_args=args)
765+
expected = {
766+
'num_list': [1, 2, 3, 4, 5, 6],
767+
'obj_list': None,
768+
'union_list': None,
769+
'str_list': ['1', '2', '3', '4', '5', '6'],
757770
}
758771
check_answer(cfg, prefix, expected)
759772

0 commit comments

Comments
 (0)