Skip to content

Commit b72794d

Browse files
authored
Added proper Protocol method signature checking (#496)
It's not good enough to pretend we can use `check_callable()` to check method signature compatibility. Fixes #465.
1 parent afad2c7 commit b72794d

5 files changed

Lines changed: 375 additions & 178 deletions

File tree

docs/features.rst

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,8 @@ As of version 4.3.0, Typeguard can check instances and classes against Protocols
6565
regardless of whether they were annotated with
6666
:func:`@runtime_checkable <typing.runtime_checkable>`.
6767

68-
There are several limitations on the checks performed, however:
69-
70-
* For non-callable members, only presence is checked for; no type compatibility checks
71-
are performed
72-
* For methods, only the number of positional arguments are checked against, so any added
73-
keyword-only arguments without defaults don't currently trip the checker
74-
* Likewise, argument types are not checked for compatibility
68+
The only current limitation is that argument annotations are not checked for
69+
compatibility, however this should be covered by static type checkers pretty well.
7570

7671
Special considerations for ``if TYPE_CHECKING:``
7772
------------------------------------------------

docs/versionhistory.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,12 @@ This library adheres to
66

77
**UNRELEASED**
88

9+
- Added proper checking for method signatures in protocol checks
10+
(`#465 <https://github.com/agronholm/typeguard/pull/465>`_)
911
- Fixed basic support for intersection protocols
1012
(`#490 <https://github.com/agronholm/typeguard/pull/490>`_; PR by @antonagestam)
13+
- Fixed protocol checks running against the class of an instance and not the instance
14+
itself (this produced wrong results for non-method member checks)
1115

1216
**4.3.0** (2024-05-27)
1317

src/typeguard/_checkers.py

Lines changed: 173 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from enum import Enum
1010
from inspect import Parameter, isclass, isfunction
1111
from io import BufferedIOBase, IOBase, RawIOBase, TextIOBase
12+
from itertools import zip_longest
1213
from textwrap import indent
1314
from typing import (
1415
IO,
@@ -32,7 +33,6 @@
3233
Union,
3334
)
3435
from unittest.mock import Mock
35-
from weakref import WeakKeyDictionary
3636

3737
import typing_extensions
3838

@@ -86,10 +86,6 @@
8686
if sys.version_info >= (3, 9):
8787
generic_alias_types += (types.GenericAlias,)
8888

89-
protocol_check_cache: WeakKeyDictionary[
90-
type[Any], dict[type[Any], TypeCheckError | None]
91-
] = WeakKeyDictionary()
92-
9389
# Sentinel
9490
_missing = object()
9591

@@ -638,96 +634,196 @@ def check_io(
638634
raise TypeCheckError("is not an I/O object")
639635

640636

641-
def check_protocol(
642-
value: Any,
643-
origin_type: Any,
644-
args: tuple[Any, ...],
645-
memo: TypeCheckMemo,
637+
def check_signature_compatible(
638+
subject_callable: Callable[..., Any], protocol: type, attrname: str
646639
) -> None:
647-
subject: type[Any] = value if isclass(value) else type(value)
640+
subject_sig = inspect.signature(subject_callable)
641+
protocol_sig = inspect.signature(getattr(protocol, attrname))
642+
protocol_type: typing.Literal["instance", "class", "static"] = "instance"
643+
subject_type: typing.Literal["instance", "class", "static"] = "instance"
644+
645+
# Check if the protocol-side method is a class method or static method
646+
if attrname in protocol.__dict__:
647+
descriptor = protocol.__dict__[attrname]
648+
if isinstance(descriptor, staticmethod):
649+
protocol_type = "static"
650+
elif isinstance(descriptor, classmethod):
651+
protocol_type = "class"
652+
653+
# Check if the subject-side method is a class method or static method
654+
if inspect.ismethod(subject_callable) and inspect.isclass(
655+
subject_callable.__self__
656+
):
657+
subject_type = "class"
658+
elif not hasattr(subject_callable, "__self__"):
659+
subject_type = "static"
648660

649-
if subject in protocol_check_cache:
650-
result_map = protocol_check_cache[subject]
651-
if origin_type in result_map:
652-
if exc := result_map[origin_type]:
653-
raise exc
654-
else:
655-
return
661+
if protocol_type == "instance" and subject_type != "instance":
662+
raise TypeCheckError(
663+
f"should be an instance method but it's a {subject_type} method"
664+
)
665+
elif protocol_type != "instance" and subject_type == "instance":
666+
raise TypeCheckError(
667+
f"should be a {protocol_type} method but it's an instance method"
668+
)
656669

657-
expected_methods: dict[str, tuple[Any, Any]] = {}
658-
expected_noncallable_members: dict[str, Any] = {}
659-
origin_annotations = typing.get_type_hints(origin_type)
670+
expected_varargs = any(
671+
param
672+
for param in protocol_sig.parameters.values()
673+
if param.kind is Parameter.VAR_POSITIONAL
674+
)
675+
has_varargs = any(
676+
param
677+
for param in subject_sig.parameters.values()
678+
if param.kind is Parameter.VAR_POSITIONAL
679+
)
680+
if expected_varargs and not has_varargs:
681+
raise TypeCheckError("should accept variable positional arguments but doesn't")
682+
683+
protocol_has_varkwargs = any(
684+
param
685+
for param in protocol_sig.parameters.values()
686+
if param.kind is Parameter.VAR_KEYWORD
687+
)
688+
subject_has_varkwargs = any(
689+
param
690+
for param in subject_sig.parameters.values()
691+
if param.kind is Parameter.VAR_KEYWORD
692+
)
693+
if protocol_has_varkwargs and not subject_has_varkwargs:
694+
raise TypeCheckError("should accept variable keyword arguments but doesn't")
695+
696+
# Check that the callable has at least the expect amount of positional-only
697+
# arguments (and no extra positional-only arguments without default values)
698+
if not has_varargs:
699+
protocol_args = [
700+
param
701+
for param in protocol_sig.parameters.values()
702+
if param.kind
703+
in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
704+
]
705+
subject_args = [
706+
param
707+
for param in subject_sig.parameters.values()
708+
if param.kind
709+
in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
710+
]
711+
712+
# Remove the "self" parameter from the protocol arguments to match
713+
if protocol_type == "instance":
714+
protocol_args.pop(0)
715+
716+
for protocol_arg, subject_arg in zip_longest(protocol_args, subject_args):
717+
if protocol_arg is None:
718+
if subject_arg.default is Parameter.empty:
719+
raise TypeCheckError("has too many mandatory positional arguments")
720+
721+
break
722+
723+
if subject_arg is None:
724+
raise TypeCheckError("has too few positional arguments")
725+
726+
if (
727+
protocol_arg.kind is Parameter.POSITIONAL_OR_KEYWORD
728+
and subject_arg.kind is Parameter.POSITIONAL_ONLY
729+
):
730+
raise TypeCheckError(
731+
f"has an argument ({subject_arg.name}) that should not be "
732+
f"positional-only"
733+
)
734+
735+
if (
736+
protocol_arg.kind is Parameter.POSITIONAL_OR_KEYWORD
737+
and protocol_arg.name != subject_arg.name
738+
):
739+
raise TypeCheckError(
740+
f"has a positional argument ({subject_arg.name}) that should be "
741+
f"named {protocol_arg.name!r} at this position"
742+
)
660743

661-
for attrname in typing_extensions.get_protocol_members(origin_type):
662-
member = getattr(origin_type, attrname, None)
663-
664-
if callable(member):
665-
signature = inspect.signature(member)
666-
argtypes = [
667-
(p.annotation if p.annotation is not Parameter.empty else Any)
668-
for p in signature.parameters.values()
669-
if p.kind is not Parameter.KEYWORD_ONLY
670-
] or Ellipsis
671-
return_annotation = (
672-
signature.return_annotation
673-
if signature.return_annotation is not Parameter.empty
674-
else Any
744+
protocol_kwonlyargs = {
745+
param.name: param
746+
for param in protocol_sig.parameters.values()
747+
if param.kind is Parameter.KEYWORD_ONLY
748+
}
749+
subject_kwonlyargs = {
750+
param.name: param
751+
for param in subject_sig.parameters.values()
752+
if param.kind is Parameter.KEYWORD_ONLY
753+
}
754+
if not subject_has_varkwargs:
755+
# Check that the signature has at least the required keyword-only arguments, and
756+
# no extra mandatory keyword-only arguments
757+
if missing_kwonlyargs := [
758+
param.name
759+
for param in protocol_kwonlyargs.values()
760+
if param.name not in subject_kwonlyargs
761+
]:
762+
raise TypeCheckError(
763+
"is missing keyword-only arguments: " + ", ".join(missing_kwonlyargs)
675764
)
676-
expected_methods[attrname] = argtypes, return_annotation
677-
else:
678-
try:
679-
expected_noncallable_members[attrname] = origin_annotations[attrname]
680-
except KeyError:
681-
expected_noncallable_members[attrname] = member
682765

683-
subject_annotations = typing.get_type_hints(subject)
766+
if not protocol_has_varkwargs:
767+
if extra_kwonlyargs := [
768+
param.name
769+
for param in subject_kwonlyargs.values()
770+
if param.default is Parameter.empty
771+
and param.name not in protocol_kwonlyargs
772+
]:
773+
raise TypeCheckError(
774+
"has mandatory keyword-only arguments not present in the protocol: "
775+
+ ", ".join(extra_kwonlyargs)
776+
)
684777

685-
# Check that all required methods are present and their signatures are compatible
686-
result_map = protocol_check_cache.setdefault(subject, {})
687-
try:
688-
for attrname, callable_args in expected_methods.items():
778+
779+
def check_protocol(
780+
value: Any,
781+
origin_type: Any,
782+
args: tuple[Any, ...],
783+
memo: TypeCheckMemo,
784+
) -> None:
785+
origin_annotations = typing.get_type_hints(origin_type)
786+
for attrname in sorted(typing_extensions.get_protocol_members(origin_type)):
787+
if (annotation := origin_annotations.get(attrname)) is not None:
689788
try:
690-
method = getattr(subject, attrname)
789+
subject_member = getattr(value, attrname)
691790
except AttributeError:
692-
if attrname in subject_annotations:
693-
raise TypeCheckError(
694-
f"is not compatible with the {origin_type.__qualname__} protocol "
695-
f"because its {attrname!r} attribute is not a method"
696-
) from None
697-
else:
698-
raise TypeCheckError(
699-
f"is not compatible with the {origin_type.__qualname__} protocol "
700-
f"because it has no method named {attrname!r}"
701-
) from None
702-
703-
if not callable(method):
704791
raise TypeCheckError(
705-
f"is not compatible with the {origin_type.__qualname__} protocol "
706-
f"because its {attrname!r} attribute is not a callable"
707-
)
792+
f"is not compatible with the {origin_type.__qualname__} "
793+
f"protocol because it has no attribute named {attrname!r}"
794+
) from None
708795

709-
# TODO: raise exception on added keyword-only arguments without defaults
710796
try:
711-
check_callable(method, Callable, callable_args, memo)
797+
check_type_internal(subject_member, annotation, memo)
712798
except TypeCheckError as exc:
713799
raise TypeCheckError(
714-
f"is not compatible with the {origin_type.__qualname__} protocol "
715-
f"because its {attrname!r} method {exc}"
800+
f"is not compatible with the {origin_type.__qualname__} "
801+
f"protocol because its {attrname!r} attribute {exc}"
802+
) from None
803+
elif callable(getattr(origin_type, attrname)):
804+
try:
805+
subject_member = getattr(value, attrname)
806+
except AttributeError:
807+
raise TypeCheckError(
808+
f"is not compatible with the {origin_type.__qualname__} "
809+
f"protocol because it has no method named {attrname!r}"
716810
) from None
717811

718-
# Check that all required non-callable members are present
719-
for attrname in expected_noncallable_members:
720-
# TODO: implement assignability checks for non-callable members
721-
if attrname not in subject_annotations and not hasattr(subject, attrname):
812+
if not callable(subject_member):
722813
raise TypeCheckError(
723-
f"is not compatible with the {origin_type.__qualname__} protocol "
724-
f"because it has no attribute named {attrname!r}"
814+
f"is not compatible with the {origin_type.__qualname__} "
815+
f"protocol because its {attrname!r} attribute is not a callable"
725816
)
726-
except TypeCheckError as exc:
727-
result_map[origin_type] = exc
728-
raise
729-
else:
730-
result_map[origin_type] = None
817+
818+
# TODO: implement assignability checks for parameter and return value
819+
# annotations
820+
try:
821+
check_signature_compatible(subject_member, origin_type, attrname)
822+
except TypeCheckError as exc:
823+
raise TypeCheckError(
824+
f"is not compatible with the {origin_type.__qualname__} "
825+
f"protocol because its {attrname!r} method {exc}"
826+
) from None
731827

732828

733829
def check_byteslike(

tests/__init__.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,8 @@
66
List,
77
NamedTuple,
88
NewType,
9-
Protocol,
109
TypeVar,
1110
Union,
12-
runtime_checkable,
1311
)
1412

1513
T_Foo = TypeVar("T_Foo")
@@ -44,16 +42,3 @@ class Parent:
4442
class Child(Parent):
4543
def method(self, a: int) -> None:
4644
pass
47-
48-
49-
class StaticProtocol(Protocol):
50-
member: int
51-
52-
def meth(self, x: str) -> None: ...
53-
54-
55-
@runtime_checkable
56-
class RuntimeProtocol(Protocol):
57-
member: int
58-
59-
def meth(self, x: str) -> None: ...

0 commit comments

Comments
 (0)