Skip to content

Commit 42cdc09

Browse files
committed
Use get_protocol_members in protocol checking
This changes `check_protocol()` to make use of `get_protocol_members` from typing-extensions. This allows removing an existing hard-coded exclusion list for attributes existing on Protocol, but also handles the cases `__orig_bases__` and `__weakref__` that was breaking when checking intersecting protocols (a subclass of two or more protocols). This has the effect of turning some false positives into true negatives, but it also leaves some false negatives. To make that clear, xfail test cases are added for the resulting false negatives.
1 parent c72b675 commit 42cdc09

3 files changed

Lines changed: 96 additions & 14 deletions

File tree

docs/versionhistory.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@ Version history
44
This library adheres to
55
`Semantic Versioning 2.0 <https://semver.org/#semantic-versioning-200>`_.
66

7+
**UNRELEASED**
8+
9+
- Fixed basic support for intersection protocols
10+
(`#490 <https://github.com/agronholm/typeguard/pull/490>`_; PR by @antonagestam)
11+
712
**4.3.0** (2024-05-27)
813

914
- Added support for checking against static protocols

src/typeguard/_checkers.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -654,19 +654,13 @@ def check_protocol(
654654
else:
655655
return
656656

657-
# Collect a set of methods and non-method attributes present in the protocol
658-
ignored_attrs = set(dir(typing.Protocol)) | {
659-
"__annotations__",
660-
"__non_callable_proto_members__",
661-
}
662657
expected_methods: dict[str, tuple[Any, Any]] = {}
663658
expected_noncallable_members: dict[str, Any] = {}
664-
for attrname in dir(origin_type):
665-
# Skip attributes present in typing.Protocol
666-
if attrname in ignored_attrs:
667-
continue
659+
origin_annotations = typing.get_type_hints(origin_type)
660+
661+
for attrname in typing_extensions.get_protocol_members(origin_type):
662+
member = getattr(origin_type, attrname, None)
668663

669-
member = getattr(origin_type, attrname)
670664
if callable(member):
671665
signature = inspect.signature(member)
672666
argtypes = [
@@ -681,10 +675,10 @@ def check_protocol(
681675
)
682676
expected_methods[attrname] = argtypes, return_annotation
683677
else:
684-
expected_noncallable_members[attrname] = member
685-
686-
for attrname, annotation in typing.get_type_hints(origin_type).items():
687-
expected_noncallable_members[attrname] = annotation
678+
try:
679+
expected_noncallable_members[attrname] = origin_annotations[attrname]
680+
except KeyError:
681+
expected_noncallable_members[attrname] = member
688682

689683
subject_annotations = typing.get_type_hints(subject)
690684

tests/test_checkers.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,17 @@
1616
Dict,
1717
ForwardRef,
1818
FrozenSet,
19+
Iterable,
1920
Iterator,
2021
List,
2122
Literal,
2223
Mapping,
2324
MutableMapping,
2425
Optional,
26+
Protocol,
2527
Sequence,
2628
Set,
29+
Sized,
2730
TextIO,
2831
Tuple,
2932
Type,
@@ -995,6 +998,86 @@ def test_text_real_file(self, tmp_path: Path):
995998
check_type(f, TextIO)
996999

9971000

1001+
class TestIntersectingProtocol:
1002+
SIT = TypeVar("SIT", covariant=True)
1003+
1004+
class SizedIterable(
1005+
Sized,
1006+
Iterable[SIT],
1007+
Protocol[SIT],
1008+
): ...
1009+
1010+
@pytest.mark.parametrize(
1011+
"subject, predicate_type",
1012+
(
1013+
pytest.param(
1014+
(),
1015+
SizedIterable,
1016+
id="empty_tuple_unspecialized",
1017+
),
1018+
pytest.param(
1019+
range(2),
1020+
SizedIterable,
1021+
id="range",
1022+
),
1023+
pytest.param(
1024+
(),
1025+
SizedIterable[int],
1026+
id="empty_tuple_int_specialized",
1027+
),
1028+
pytest.param(
1029+
(1, 2, 3),
1030+
SizedIterable[int],
1031+
id="tuple_int_specialized",
1032+
),
1033+
pytest.param(
1034+
("1", "2", "3"),
1035+
SizedIterable[str],
1036+
id="tuple_str_specialized",
1037+
),
1038+
),
1039+
)
1040+
def test_valid_member_passes(self, subject: object, predicate_type: type) -> None:
1041+
for _ in range(2): # Makes sure that the cache is also exercised
1042+
check_type(subject, predicate_type)
1043+
1044+
xfail_nested_protocol_checks = pytest.mark.xfail(
1045+
reason="false negative due to missing support for nested protocol checks",
1046+
)
1047+
1048+
@pytest.mark.parametrize(
1049+
"subject, predicate_type",
1050+
(
1051+
pytest.param(
1052+
(1 for _ in ()),
1053+
SizedIterable,
1054+
id="generator",
1055+
),
1056+
pytest.param(
1057+
range(2),
1058+
SizedIterable[str],
1059+
marks=xfail_nested_protocol_checks,
1060+
id="range_str_specialized",
1061+
),
1062+
pytest.param(
1063+
(1, 2, 3),
1064+
SizedIterable[str],
1065+
marks=xfail_nested_protocol_checks,
1066+
id="int_tuple_str_specialized",
1067+
),
1068+
pytest.param(
1069+
("1", "2", "3"),
1070+
SizedIterable[int],
1071+
marks=xfail_nested_protocol_checks,
1072+
id="str_tuple_int_specialized",
1073+
),
1074+
),
1075+
)
1076+
def test_raises_for_non_member(self, subject: object, predicate_type: type) -> None:
1077+
with pytest.raises(TypeCheckError):
1078+
check_type(subject, predicate_type)
1079+
1080+
9981081
@pytest.mark.parametrize(
9991082
"instantiate, annotation",
10001083
[

0 commit comments

Comments
 (0)