|
9 | 9 | from enum import Enum |
10 | 10 | from inspect import Parameter, isclass, isfunction |
11 | 11 | from io import BufferedIOBase, IOBase, RawIOBase, TextIOBase |
| 12 | +from itertools import zip_longest |
12 | 13 | from textwrap import indent |
13 | 14 | from typing import ( |
14 | 15 | IO, |
|
32 | 33 | Union, |
33 | 34 | ) |
34 | 35 | from unittest.mock import Mock |
35 | | -from weakref import WeakKeyDictionary |
36 | 36 |
|
37 | 37 | import typing_extensions |
38 | 38 |
|
|
86 | 86 | if sys.version_info >= (3, 9): |
87 | 87 | generic_alias_types += (types.GenericAlias,) |
88 | 88 |
|
89 | | -protocol_check_cache: WeakKeyDictionary[ |
90 | | - type[Any], dict[type[Any], TypeCheckError | None] |
91 | | -] = WeakKeyDictionary() |
92 | | - |
93 | 89 | # Sentinel |
94 | 90 | _missing = object() |
95 | 91 |
|
@@ -638,96 +634,196 @@ def check_io( |
638 | 634 | raise TypeCheckError("is not an I/O object") |
639 | 635 |
|
640 | 636 |
|
641 | | -def check_protocol( |
642 | | - value: Any, |
643 | | - origin_type: Any, |
644 | | - args: tuple[Any, ...], |
645 | | - memo: TypeCheckMemo, |
| 637 | +def check_signature_compatible( |
| 638 | + subject_callable: Callable[..., Any], protocol: type, attrname: str |
646 | 639 | ) -> None: |
647 | | - subject: type[Any] = value if isclass(value) else type(value) |
| 640 | + subject_sig = inspect.signature(subject_callable) |
| 641 | + protocol_sig = inspect.signature(getattr(protocol, attrname)) |
| 642 | + protocol_type: typing.Literal["instance", "class", "static"] = "instance" |
| 643 | + subject_type: typing.Literal["instance", "class", "static"] = "instance" |
| 644 | + |
| 645 | + # Check if the protocol-side method is a class method or static method |
| 646 | + if attrname in protocol.__dict__: |
| 647 | + descriptor = protocol.__dict__[attrname] |
| 648 | + if isinstance(descriptor, staticmethod): |
| 649 | + protocol_type = "static" |
| 650 | + elif isinstance(descriptor, classmethod): |
| 651 | + protocol_type = "class" |
| 652 | + |
| 653 | + # Check if the subject-side method is a class method or static method |
| 654 | + if inspect.ismethod(subject_callable) and inspect.isclass( |
| 655 | + subject_callable.__self__ |
| 656 | + ): |
| 657 | + subject_type = "class" |
| 658 | + elif not hasattr(subject_callable, "__self__"): |
| 659 | + subject_type = "static" |
648 | 660 |
|
649 | | - if subject in protocol_check_cache: |
650 | | - result_map = protocol_check_cache[subject] |
651 | | - if origin_type in result_map: |
652 | | - if exc := result_map[origin_type]: |
653 | | - raise exc |
654 | | - else: |
655 | | - return |
| 661 | + if protocol_type == "instance" and subject_type != "instance": |
| 662 | + raise TypeCheckError( |
| 663 | + f"should be an instance method but it's a {subject_type} method" |
| 664 | + ) |
| 665 | + elif protocol_type != "instance" and subject_type == "instance": |
| 666 | + raise TypeCheckError( |
| 667 | + f"should be a {protocol_type} method but it's an instance method" |
| 668 | + ) |
656 | 669 |
|
657 | | - expected_methods: dict[str, tuple[Any, Any]] = {} |
658 | | - expected_noncallable_members: dict[str, Any] = {} |
659 | | - origin_annotations = typing.get_type_hints(origin_type) |
| 670 | + expected_varargs = any( |
| 671 | + param |
| 672 | + for param in protocol_sig.parameters.values() |
| 673 | + if param.kind is Parameter.VAR_POSITIONAL |
| 674 | + ) |
| 675 | + has_varargs = any( |
| 676 | + param |
| 677 | + for param in subject_sig.parameters.values() |
| 678 | + if param.kind is Parameter.VAR_POSITIONAL |
| 679 | + ) |
| 680 | + if expected_varargs and not has_varargs: |
| 681 | + raise TypeCheckError("should accept variable positional arguments but doesn't") |
| 682 | + |
| 683 | + protocol_has_varkwargs = any( |
| 684 | + param |
| 685 | + for param in protocol_sig.parameters.values() |
| 686 | + if param.kind is Parameter.VAR_KEYWORD |
| 687 | + ) |
| 688 | + subject_has_varkwargs = any( |
| 689 | + param |
| 690 | + for param in subject_sig.parameters.values() |
| 691 | + if param.kind is Parameter.VAR_KEYWORD |
| 692 | + ) |
| 693 | + if protocol_has_varkwargs and not subject_has_varkwargs: |
| 694 | + raise TypeCheckError("should accept variable keyword arguments but doesn't") |
| 695 | + |
| 696 | + # Check that the callable has at least the expect amount of positional-only |
| 697 | + # arguments (and no extra positional-only arguments without default values) |
| 698 | + if not has_varargs: |
| 699 | + protocol_args = [ |
| 700 | + param |
| 701 | + for param in protocol_sig.parameters.values() |
| 702 | + if param.kind |
| 703 | + in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD) |
| 704 | + ] |
| 705 | + subject_args = [ |
| 706 | + param |
| 707 | + for param in subject_sig.parameters.values() |
| 708 | + if param.kind |
| 709 | + in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD) |
| 710 | + ] |
| 711 | + |
| 712 | + # Remove the "self" parameter from the protocol arguments to match |
| 713 | + if protocol_type == "instance": |
| 714 | + protocol_args.pop(0) |
| 715 | + |
| 716 | + for protocol_arg, subject_arg in zip_longest(protocol_args, subject_args): |
| 717 | + if protocol_arg is None: |
| 718 | + if subject_arg.default is Parameter.empty: |
| 719 | + raise TypeCheckError("has too many mandatory positional arguments") |
| 720 | + |
| 721 | + break |
| 722 | + |
| 723 | + if subject_arg is None: |
| 724 | + raise TypeCheckError("has too few positional arguments") |
| 725 | + |
| 726 | + if ( |
| 727 | + protocol_arg.kind is Parameter.POSITIONAL_OR_KEYWORD |
| 728 | + and subject_arg.kind is Parameter.POSITIONAL_ONLY |
| 729 | + ): |
| 730 | + raise TypeCheckError( |
| 731 | + f"has an argument ({subject_arg.name}) that should not be " |
| 732 | + f"positional-only" |
| 733 | + ) |
| 734 | + |
| 735 | + if ( |
| 736 | + protocol_arg.kind is Parameter.POSITIONAL_OR_KEYWORD |
| 737 | + and protocol_arg.name != subject_arg.name |
| 738 | + ): |
| 739 | + raise TypeCheckError( |
| 740 | + f"has a positional argument ({subject_arg.name}) that should be " |
| 741 | + f"named {protocol_arg.name!r} at this position" |
| 742 | + ) |
660 | 743 |
|
661 | | - for attrname in typing_extensions.get_protocol_members(origin_type): |
662 | | - member = getattr(origin_type, attrname, None) |
663 | | - |
664 | | - if callable(member): |
665 | | - signature = inspect.signature(member) |
666 | | - argtypes = [ |
667 | | - (p.annotation if p.annotation is not Parameter.empty else Any) |
668 | | - for p in signature.parameters.values() |
669 | | - if p.kind is not Parameter.KEYWORD_ONLY |
670 | | - ] or Ellipsis |
671 | | - return_annotation = ( |
672 | | - signature.return_annotation |
673 | | - if signature.return_annotation is not Parameter.empty |
674 | | - else Any |
| 744 | + protocol_kwonlyargs = { |
| 745 | + param.name: param |
| 746 | + for param in protocol_sig.parameters.values() |
| 747 | + if param.kind is Parameter.KEYWORD_ONLY |
| 748 | + } |
| 749 | + subject_kwonlyargs = { |
| 750 | + param.name: param |
| 751 | + for param in subject_sig.parameters.values() |
| 752 | + if param.kind is Parameter.KEYWORD_ONLY |
| 753 | + } |
| 754 | + if not subject_has_varkwargs: |
| 755 | + # Check that the signature has at least the required keyword-only arguments, and |
| 756 | + # no extra mandatory keyword-only arguments |
| 757 | + if missing_kwonlyargs := [ |
| 758 | + param.name |
| 759 | + for param in protocol_kwonlyargs.values() |
| 760 | + if param.name not in subject_kwonlyargs |
| 761 | + ]: |
| 762 | + raise TypeCheckError( |
| 763 | + "is missing keyword-only arguments: " + ", ".join(missing_kwonlyargs) |
675 | 764 | ) |
676 | | - expected_methods[attrname] = argtypes, return_annotation |
677 | | - else: |
678 | | - try: |
679 | | - expected_noncallable_members[attrname] = origin_annotations[attrname] |
680 | | - except KeyError: |
681 | | - expected_noncallable_members[attrname] = member |
682 | 765 |
|
683 | | - subject_annotations = typing.get_type_hints(subject) |
| 766 | + if not protocol_has_varkwargs: |
| 767 | + if extra_kwonlyargs := [ |
| 768 | + param.name |
| 769 | + for param in subject_kwonlyargs.values() |
| 770 | + if param.default is Parameter.empty |
| 771 | + and param.name not in protocol_kwonlyargs |
| 772 | + ]: |
| 773 | + raise TypeCheckError( |
| 774 | + "has mandatory keyword-only arguments not present in the protocol: " |
| 775 | + + ", ".join(extra_kwonlyargs) |
| 776 | + ) |
684 | 777 |
|
685 | | - # Check that all required methods are present and their signatures are compatible |
686 | | - result_map = protocol_check_cache.setdefault(subject, {}) |
687 | | - try: |
688 | | - for attrname, callable_args in expected_methods.items(): |
| 778 | + |
| 779 | +def check_protocol( |
| 780 | + value: Any, |
| 781 | + origin_type: Any, |
| 782 | + args: tuple[Any, ...], |
| 783 | + memo: TypeCheckMemo, |
| 784 | +) -> None: |
| 785 | + origin_annotations = typing.get_type_hints(origin_type) |
| 786 | + for attrname in sorted(typing_extensions.get_protocol_members(origin_type)): |
| 787 | + if (annotation := origin_annotations.get(attrname)) is not None: |
689 | 788 | try: |
690 | | - method = getattr(subject, attrname) |
| 789 | + subject_member = getattr(value, attrname) |
691 | 790 | except AttributeError: |
692 | | - if attrname in subject_annotations: |
693 | | - raise TypeCheckError( |
694 | | - f"is not compatible with the {origin_type.__qualname__} protocol " |
695 | | - f"because its {attrname!r} attribute is not a method" |
696 | | - ) from None |
697 | | - else: |
698 | | - raise TypeCheckError( |
699 | | - f"is not compatible with the {origin_type.__qualname__} protocol " |
700 | | - f"because it has no method named {attrname!r}" |
701 | | - ) from None |
702 | | - |
703 | | - if not callable(method): |
704 | 791 | raise TypeCheckError( |
705 | | - f"is not compatible with the {origin_type.__qualname__} protocol " |
706 | | - f"because its {attrname!r} attribute is not a callable" |
707 | | - ) |
| 792 | + f"is not compatible with the {origin_type.__qualname__} " |
| 793 | + f"protocol because it has no attribute named {attrname!r}" |
| 794 | + ) from None |
708 | 795 |
|
709 | | - # TODO: raise exception on added keyword-only arguments without defaults |
710 | 796 | try: |
711 | | - check_callable(method, Callable, callable_args, memo) |
| 797 | + check_type_internal(subject_member, annotation, memo) |
712 | 798 | except TypeCheckError as exc: |
713 | 799 | raise TypeCheckError( |
714 | | - f"is not compatible with the {origin_type.__qualname__} protocol " |
715 | | - f"because its {attrname!r} method {exc}" |
| 800 | + f"is not compatible with the {origin_type.__qualname__} " |
| 801 | + f"protocol because its {attrname!r} attribute {exc}" |
| 802 | + ) from None |
| 803 | + elif callable(getattr(origin_type, attrname)): |
| 804 | + try: |
| 805 | + subject_member = getattr(value, attrname) |
| 806 | + except AttributeError: |
| 807 | + raise TypeCheckError( |
| 808 | + f"is not compatible with the {origin_type.__qualname__} " |
| 809 | + f"protocol because it has no method named {attrname!r}" |
716 | 810 | ) from None |
717 | 811 |
|
718 | | - # Check that all required non-callable members are present |
719 | | - for attrname in expected_noncallable_members: |
720 | | - # TODO: implement assignability checks for non-callable members |
721 | | - if attrname not in subject_annotations and not hasattr(subject, attrname): |
| 812 | + if not callable(subject_member): |
722 | 813 | raise TypeCheckError( |
723 | | - f"is not compatible with the {origin_type.__qualname__} protocol " |
724 | | - f"because it has no attribute named {attrname!r}" |
| 814 | + f"is not compatible with the {origin_type.__qualname__} " |
| 815 | + f"protocol because its {attrname!r} attribute is not a callable" |
725 | 816 | ) |
726 | | - except TypeCheckError as exc: |
727 | | - result_map[origin_type] = exc |
728 | | - raise |
729 | | - else: |
730 | | - result_map[origin_type] = None |
| 817 | + |
| 818 | + # TODO: implement assignability checks for parameter and return value |
| 819 | + # annotations |
| 820 | + try: |
| 821 | + check_signature_compatible(subject_member, origin_type, attrname) |
| 822 | + except TypeCheckError as exc: |
| 823 | + raise TypeCheckError( |
| 824 | + f"is not compatible with the {origin_type.__qualname__} " |
| 825 | + f"protocol because its {attrname!r} method {exc}" |
| 826 | + ) from None |
731 | 827 |
|
732 | 828 |
|
733 | 829 | def check_byteslike( |
|
0 commit comments