Skip to content

Commit d344b6e

Browse files
committed
Further improved the implementation and removed protocol check caching
1 parent 26fd117 commit d344b6e

3 files changed

Lines changed: 153 additions & 129 deletions

File tree

docs/versionhistory.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ This library adheres to
1010
(`#465 <https://github.com/agronholm/typeguard/pull/465>`_)
1111
- Fixed basic support for intersection protocols
1212
(`#490 <https://github.com/agronholm/typeguard/pull/490>`_; PR by @antonagestam)
13+
- Fixed protocol checks running against the class of an instance and not the instance
14+
itself (this produced wrong results for non-method member checks)
1315

1416
**4.3.0** (2024-05-27)
1517

src/typeguard/_checkers.py

Lines changed: 68 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
Union,
3434
)
3535
from unittest.mock import Mock
36-
from weakref import WeakKeyDictionary
3736

3837
import typing_extensions
3938

@@ -87,10 +86,6 @@
8786
if sys.version_info >= (3, 9):
8887
generic_alias_types += (types.GenericAlias,)
8988

90-
protocol_check_cache: WeakKeyDictionary[
91-
type[Any], dict[type[Any], tuple[Any, ...] | None]
92-
] = WeakKeyDictionary()
93-
9489
# Sentinel
9590
_missing = object()
9691

@@ -644,6 +639,33 @@ def check_signature_compatible(
644639
) -> None:
645640
subject_sig = inspect.signature(subject_callable)
646641
protocol_sig = inspect.signature(getattr(protocol, attrname))
642+
protocol_type: typing.Literal["instance", "class", "static"] = "instance"
643+
subject_type: typing.Literal["instance", "class", "static"] = "instance"
644+
645+
# Check if the protocol-side method is a class method or static method
646+
if attrname in protocol.__dict__:
647+
descriptor = protocol.__dict__[attrname]
648+
if isinstance(descriptor, staticmethod):
649+
protocol_type = "static"
650+
elif isinstance(descriptor, classmethod):
651+
protocol_type = "class"
652+
653+
# Check if the subject-side method is a class method or static method
654+
if inspect.ismethod(subject_callable) and inspect.isclass(
655+
subject_callable.__self__
656+
):
657+
subject_type = "class"
658+
elif not hasattr(subject_callable, "__self__"):
659+
subject_type = "static"
660+
661+
if protocol_type == "instance" and subject_type != "instance":
662+
raise TypeCheckError(
663+
f"should be an instance method but it's a {subject_type} method"
664+
)
665+
elif protocol_type != "instance" and subject_type == "instance":
666+
raise TypeCheckError(
667+
f"should be a {protocol_type} method but it's an instance method"
668+
)
647669

648670
expected_varargs = any(
649671
param
@@ -687,12 +709,9 @@ def check_signature_compatible(
687709
in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
688710
]
689711

690-
# Remove the "self" parameter from methods
691-
if inspect.ismethod(subject_callable) or inspect.ismethoddescriptor(
692-
subject_callable
693-
):
712+
# Remove the "self" parameter from the protocol arguments to match
713+
if protocol_type == "instance":
694714
protocol_args.pop(0)
695-
subject_args.pop(0)
696715

697716
for protocol_arg, subject_arg in zip_longest(protocol_args, subject_args):
698717
if protocol_arg is None:
@@ -763,65 +782,48 @@ def check_protocol(
763782
args: tuple[Any, ...],
764783
memo: TypeCheckMemo,
765784
) -> None:
766-
subject: type[Any] = value if isclass(value) else type(value)
785+
origin_annotations = typing.get_type_hints(origin_type)
786+
for attrname in sorted(typing_extensions.get_protocol_members(origin_type)):
787+
if (annotation := origin_annotations.get(attrname)) is not None:
788+
try:
789+
subject_member = getattr(value, attrname)
790+
except AttributeError:
791+
raise TypeCheckError(
792+
f"is not compatible with the {origin_type.__qualname__} "
793+
f"protocol because it has no attribute named {attrname!r}"
794+
) from None
767795

768-
if subject in protocol_check_cache:
769-
result_map = protocol_check_cache[subject]
770-
if origin_type in result_map:
771-
if exc_args := result_map[origin_type]:
772-
raise TypeCheckError(*exc_args)
773-
else:
774-
return
796+
try:
797+
check_type_internal(subject_member, annotation, memo)
798+
except TypeCheckError as exc:
799+
raise TypeCheckError(
800+
f"is not compatible with the {origin_type.__qualname__} "
801+
f"protocol because its {attrname!r} attribute {exc}"
802+
) from None
803+
elif callable(getattr(origin_type, attrname)):
804+
try:
805+
subject_member = getattr(value, attrname)
806+
except AttributeError:
807+
raise TypeCheckError(
808+
f"is not compatible with the {origin_type.__qualname__} "
809+
f"protocol because it has no method named {attrname!r}"
810+
) from None
775811

776-
origin_annotations = typing.get_type_hints(origin_type)
777-
result_map = protocol_check_cache.setdefault(subject, {})
778-
try:
779-
for attrname in sorted(typing_extensions.get_protocol_members(origin_type)):
780-
if (annotation := origin_annotations.get(attrname)) is not None:
781-
try:
782-
subject_member = getattr(subject, attrname)
783-
except AttributeError:
784-
raise TypeCheckError(
785-
f"is not compatible with the {origin_type.__qualname__} "
786-
f"protocol because it has no attribute named {attrname!r}"
787-
) from None
812+
if not callable(subject_member):
813+
raise TypeCheckError(
814+
f"is not compatible with the {origin_type.__qualname__} "
815+
f"protocol because its {attrname!r} attribute is not a callable"
816+
)
788817

789-
try:
790-
check_type_internal(subject_member, annotation, memo)
791-
except TypeCheckError as exc:
792-
raise TypeCheckError(
793-
f"is not compatible with the {origin_type.__qualname__} "
794-
f"protocol because its {attrname!r} attribute {exc}"
795-
) from None
796-
elif callable(getattr(origin_type, attrname)):
797-
try:
798-
subject_member = getattr(subject, attrname)
799-
except AttributeError:
800-
raise TypeCheckError(
801-
f"is not compatible with the {origin_type.__qualname__} "
802-
f"protocol because it has no method named {attrname!r}"
803-
) from None
804-
805-
if not callable(subject_member):
806-
raise TypeCheckError(
807-
f"is not compatible with the {origin_type.__qualname__} "
808-
f"protocol because its {attrname!r} attribute is not a callable"
809-
)
810-
811-
# TODO: implement assignability checks for parameter and return value
812-
# annotations
813-
try:
814-
check_signature_compatible(subject_member, origin_type, attrname)
815-
except TypeCheckError as exc:
816-
raise TypeCheckError(
817-
f"is not compatible with the {origin_type.__qualname__} "
818-
f"protocol because its {attrname!r} method {exc}"
819-
) from None
820-
except TypeCheckError as exc:
821-
result_map[origin_type] = exc.args
822-
raise
823-
else:
824-
result_map[origin_type] = None
818+
# TODO: implement assignability checks for parameter and return value
819+
# annotations
820+
try:
821+
check_signature_compatible(subject_member, origin_type, attrname)
822+
except TypeCheckError as exc:
823+
raise TypeCheckError(
824+
f"is not compatible with the {origin_type.__qualname__} "
825+
f"protocol because its {attrname!r} method {exc}"
826+
) from None
825827

826828

827829
def check_byteslike(

tests/test_checkers.py

Lines changed: 83 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1137,20 +1137,23 @@ def my_static_method(cls, x: int, y: str) -> None:
11371137
def my_class_method(x: int, y: str) -> None:
11381138
pass
11391139

1140-
for _ in range(2): # Makes sure that the cache is also exercised
1141-
check_type(Foo(), MyProtocol)
1140+
check_type(Foo(), MyProtocol)
11421141

1143-
def test_missing_member(self) -> None:
1142+
@pytest.mark.parametrize("has_member", [True, False])
1143+
def test_member_checks(self, has_member: bool) -> None:
11441144
class MyProtocol(Protocol):
11451145
member: int
11461146

11471147
class Foo:
1148-
pass
1148+
def __init__(self, member: int):
1149+
if member:
1150+
self.member = member
11491151

1150-
obj = Foo()
1151-
for _ in range(2): # Makes sure that the cache is also exercised
1152-
pytest.raises(TypeCheckError, check_type, obj, MyProtocol).match(
1153-
f"^{qualified_name(obj)} is not compatible with the "
1152+
if has_member:
1153+
check_type(Foo(1), MyProtocol)
1154+
else:
1155+
pytest.raises(TypeCheckError, check_type, Foo(0), MyProtocol).match(
1156+
f"^{qualified_name(Foo)} is not compatible with the "
11541157
f"{MyProtocol.__qualname__} protocol because it has no attribute named "
11551158
f"'member'"
11561159
)
@@ -1163,12 +1166,11 @@ def meth(self) -> None:
11631166
class Foo:
11641167
pass
11651168

1166-
for _ in range(2): # Makes sure that the cache is also exercised
1167-
pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match(
1168-
f"^{qualified_name(Foo)} is not compatible with the "
1169-
f"{MyProtocol.__qualname__} protocol because it has no method named "
1170-
f"'meth'"
1171-
)
1169+
pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match(
1170+
f"^{qualified_name(Foo)} is not compatible with the "
1171+
f"{MyProtocol.__qualname__} protocol because it has no method named "
1172+
f"'meth'"
1173+
)
11721174

11731175
def test_too_many_posargs(self) -> None:
11741176
class MyProtocol(Protocol):
@@ -1179,13 +1181,11 @@ class Foo:
11791181
def meth(self, x: str) -> None:
11801182
pass
11811183

1182-
obj = Foo()
1183-
for _ in range(2): # Makes sure that the cache is also exercised
1184-
pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match(
1185-
f"^{qualified_name(obj)} is not compatible with the "
1186-
f"{MyProtocol.__qualname__} protocol because its 'meth' method has too "
1187-
f"many mandatory positional arguments"
1188-
)
1184+
pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match(
1185+
f"^{qualified_name(Foo)} is not compatible with the "
1186+
f"{MyProtocol.__qualname__} protocol because its 'meth' method has too "
1187+
f"many mandatory positional arguments"
1188+
)
11891189

11901190
def test_wrong_posarg_name(self) -> None:
11911191
class MyProtocol(Protocol):
@@ -1196,13 +1196,11 @@ class Foo:
11961196
def meth(self, y: str) -> None:
11971197
pass
11981198

1199-
obj = Foo()
1200-
for _ in range(2): # Makes sure that the cache is also exercised
1201-
pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match(
1202-
rf"^{qualified_name(obj)} is not compatible with the "
1203-
rf"{MyProtocol.__qualname__} protocol because its 'meth' method has a "
1204-
rf"positional argument \(y\) that should be named 'x' at this position"
1205-
)
1199+
pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match(
1200+
rf"^{qualified_name(Foo)} is not compatible with the "
1201+
rf"{MyProtocol.__qualname__} protocol because its 'meth' method has a "
1202+
rf"positional argument \(y\) that should be named 'x' at this position"
1203+
)
12061204

12071205
def test_too_few_posargs(self) -> None:
12081206
class MyProtocol(Protocol):
@@ -1213,13 +1211,11 @@ class Foo:
12131211
def meth(self) -> None:
12141212
pass
12151213

1216-
obj = Foo()
1217-
for _ in range(2): # Makes sure that the cache is also exercised
1218-
pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match(
1219-
f"^{qualified_name(obj)} is not compatible with the "
1220-
f"{MyProtocol.__qualname__} protocol because its 'meth' method has too "
1221-
f"few positional arguments"
1222-
)
1214+
pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match(
1215+
f"^{qualified_name(Foo)} is not compatible with the "
1216+
f"{MyProtocol.__qualname__} protocol because its 'meth' method has too "
1217+
f"few positional arguments"
1218+
)
12231219

12241220
def test_no_varargs(self) -> None:
12251221
class MyProtocol(Protocol):
@@ -1230,13 +1226,11 @@ class Foo:
12301226
def meth(self) -> None:
12311227
pass
12321228

1233-
obj = Foo()
1234-
for _ in range(2): # Makes sure that the cache is also exercised
1235-
pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match(
1236-
f"^{qualified_name(obj)} is not compatible with the "
1237-
f"{MyProtocol.__qualname__} protocol because its 'meth' method should "
1238-
f"accept variable positional arguments but doesn't"
1239-
)
1229+
pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match(
1230+
f"^{qualified_name(Foo)} is not compatible with the "
1231+
f"{MyProtocol.__qualname__} protocol because its 'meth' method should "
1232+
f"accept variable positional arguments but doesn't"
1233+
)
12401234

12411235
def test_no_kwargs(self) -> None:
12421236
class MyProtocol(Protocol):
@@ -1247,13 +1241,11 @@ class Foo:
12471241
def meth(self) -> None:
12481242
pass
12491243

1250-
obj = Foo()
1251-
for _ in range(2): # Makes sure that the cache is also exercised
1252-
pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match(
1253-
f"^{qualified_name(obj)} is not compatible with the "
1254-
f"{MyProtocol.__qualname__} protocol because its 'meth' method should "
1255-
f"accept variable keyword arguments but doesn't"
1256-
)
1244+
pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match(
1245+
f"^{qualified_name(Foo)} is not compatible with the "
1246+
f"{MyProtocol.__qualname__} protocol because its 'meth' method should "
1247+
f"accept variable keyword arguments but doesn't"
1248+
)
12571249

12581250
def test_missing_kwarg(self) -> None:
12591251
class MyProtocol(Protocol):
@@ -1264,13 +1256,11 @@ class Foo:
12641256
def meth(self) -> None:
12651257
pass
12661258

1267-
obj = Foo()
1268-
for _ in range(2): # Makes sure that the cache is also exercised
1269-
pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match(
1270-
f"^{qualified_name(obj)} is not compatible with the "
1271-
f"{MyProtocol.__qualname__} protocol because its 'meth' method is "
1272-
f"missing keyword-only arguments: x"
1273-
)
1259+
pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match(
1260+
f"^{qualified_name(Foo)} is not compatible with the "
1261+
f"{MyProtocol.__qualname__} protocol because its 'meth' method is "
1262+
f"missing keyword-only arguments: x"
1263+
)
12741264

12751265
def test_extra_kwarg(self) -> None:
12761266
class MyProtocol(Protocol):
@@ -1281,13 +1271,43 @@ class Foo:
12811271
def meth(self, *, x: str) -> None:
12821272
pass
12831273

1284-
obj = Foo()
1285-
for _ in range(2): # Makes sure that the cache is also exercised
1286-
pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match(
1287-
f"^{qualified_name(obj)} is not compatible with the "
1288-
f"{MyProtocol.__qualname__} protocol because its 'meth' method has "
1289-
f"mandatory keyword-only arguments not present in the protocol: x"
1290-
)
1274+
pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match(
1275+
f"^{qualified_name(Foo)} is not compatible with the "
1276+
f"{MyProtocol.__qualname__} protocol because its 'meth' method has "
1277+
f"mandatory keyword-only arguments not present in the protocol: x"
1278+
)
1279+
1280+
def test_instance_staticmethod_mismatch(self) -> None:
1281+
class MyProtocol(Protocol):
1282+
@staticmethod
1283+
def meth() -> None:
1284+
pass
1285+
1286+
class Foo:
1287+
def meth(self) -> None:
1288+
pass
1289+
1290+
pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match(
1291+
f"^{qualified_name(Foo)} is not compatible with the "
1292+
f"{MyProtocol.__qualname__} protocol because its 'meth' method should "
1293+
f"be a static method but it's an instance method"
1294+
)
1295+
1296+
def test_instance_classmethod_mismatch(self) -> None:
1297+
class MyProtocol(Protocol):
1298+
@classmethod
1299+
def meth(cls) -> None:
1300+
pass
1301+
1302+
class Foo:
1303+
def meth(self) -> None:
1304+
pass
1305+
1306+
pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match(
1307+
f"^{qualified_name(Foo)} is not compatible with the "
1308+
f"{MyProtocol.__qualname__} protocol because its 'meth' method should "
1309+
f"be a class method but it's an instance method"
1310+
)
12911311

12921312

12931313
class TestRecursiveType:

0 commit comments

Comments
 (0)