Skip to content

Commit 84ad04c

Browse files
committed
Fix UnionType handling as with Union for Optional values
1 parent b9515e8 commit 84ad04c

3 files changed

Lines changed: 70 additions & 13 deletions

File tree

tests/test_type_conversion.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
from enum import Enum
23
from pathlib import Path
34
from typing import List, Optional, Tuple
@@ -28,6 +29,28 @@ def opt(user: Optional[str] = None):
2829
assert "User: Camila" in result.output
2930

3031

32+
@pytest.mark.skipif(
33+
sys.version_info < (3, 10), reason="The | operator for types was new in 3.10"
34+
)
35+
def test_union_type_optional():
36+
app = typer.Typer()
37+
38+
@app.command()
39+
def opt(user: str | None = None):
40+
if user:
41+
print(f"User: {user}")
42+
else:
43+
print("No user")
44+
45+
result = runner.invoke(app)
46+
assert result.exit_code == 0
47+
assert "No user" in result.output
48+
49+
result = runner.invoke(app, ["--user", "Camila"])
50+
assert result.exit_code == 0
51+
assert "User: Camila" in result.output
52+
53+
3154
def test_no_type():
3255
app = typer.Typer()
3356

typer/_compat_utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,37 @@
1+
import sys
2+
from typing import Union
3+
14
import click
25

6+
if sys.version_info >= (3, 8):
7+
from typing import get_args as _get_args
8+
from typing import get_origin as _get_origin
9+
elif sys.version_info >= (3, 7):
10+
from typing_extensions import get_args as _get_args
11+
from typing_extensions import get_origin as _get_origin
12+
else:
13+
# These methods do not handle all the same details as the imported ones.
14+
# However on Python 3.6 they should be sufficient.
15+
# typer <= 0.7.0 used this implementation.
16+
17+
def _get_origin(arg):
18+
return getattr(arg, "__origin__", None)
19+
20+
def _get_args(arg):
21+
return getattr(arg, "__args__", None)
22+
23+
24+
# Assigning variables to mark them as exported with mypy
25+
get_origin = _get_origin
26+
get_args = _get_args
27+
28+
if sys.version_info >= (3, 10):
29+
from types import UnionType
30+
31+
UNION_TYPES = (UnionType, Union)
32+
else:
33+
UNION_TYPES = (Union,)
34+
335

436
def _get_click_major() -> int:
537
return int(click.__version__.split(".")[0])

typer/main.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import click
1515

16+
from ._compat_utils import UNION_TYPES, get_args, get_origin
1617
from .completion import get_completion_inspect_parameters
1718
from .core import MarkupMode, TyperArgument, TyperCommand, TyperGroup, TyperOption
1819
from .models import (
@@ -810,30 +811,31 @@ def get_click_param(
810811
is_tuple = False
811812
parameter_type: Any = None
812813
is_flag = None
813-
origin = getattr(main_type, "__origin__", None)
814+
origin = get_origin(main_type)
815+
814816
if origin is not None:
815-
# Handle Optional[SomeType]
816-
if origin is Union:
817+
# Handle SomeType | None and Optional[SomeType]
818+
if origin in UNION_TYPES:
817819
types = []
818-
for type_ in main_type.__args__:
820+
for type_ in get_args(main_type):
819821
if type_ is NoneType:
820822
continue
821823
types.append(type_)
822824
assert len(types) == 1, "Typer Currently doesn't support Union types"
823825
main_type = types[0]
824-
origin = getattr(main_type, "__origin__", None)
826+
origin = get_origin(main_type)
825827
# Handle Tuples and Lists
826828
if lenient_issubclass(origin, List):
827-
main_type = main_type.__args__[0]
828-
assert not getattr(
829-
main_type, "__origin__", None
829+
main_type = get_args(main_type)[0]
830+
assert not get_origin(
831+
main_type
830832
), "List types with complex sub-types are not currently supported"
831833
is_list = True
832834
elif lenient_issubclass(origin, Tuple): # type: ignore
833835
types = []
834-
for type_ in main_type.__args__:
835-
assert not getattr(
836-
type_, "__origin__", None
836+
for type_ in get_args(main_type):
837+
assert not get_origin(
838+
type_
837839
), "Tuple types with complex sub-types are not currently supported"
838840
types.append(
839841
get_click_type(annotation=type_, parameter_info=parameter_info)
@@ -848,7 +850,7 @@ def get_click_param(
848850
if is_list:
849851
convertor = generate_list_convertor(convertor)
850852
if is_tuple:
851-
convertor = generate_tuple_convertor(main_type.__args__)
853+
convertor = generate_tuple_convertor(get_args(main_type))
852854
if isinstance(parameter_info, OptionInfo):
853855
if main_type is bool and not (parameter_info.is_flag is False):
854856
is_flag = True
@@ -1002,7 +1004,7 @@ def get_param_completion(
10021004
incomplete_name = None
10031005
unassigned_params = [param for param in parameters.values()]
10041006
for param_sig in unassigned_params[:]:
1005-
origin = getattr(param_sig.annotation, "__origin__", None)
1007+
origin = get_origin(param_sig.annotation)
10061008
if lenient_issubclass(param_sig.annotation, click.Context):
10071009
ctx_name = param_sig.name
10081010
unassigned_params.remove(param_sig)

0 commit comments

Comments
 (0)