1919
2020from typing import Callable , Optional , Type , Union
2121
22+ from mypy .nodes import ARG_POS , Decorator , MemberExpr
2223from mypy .plugin import FunctionContext , MethodContext , MethodSigContext , Plugin
2324from mypy .typeops import bind_self
2425from mypy .types import AnyType , CallableType , Instance
@@ -46,15 +47,16 @@ class _AdjustArguments(object):
4647 """
4748
4849 def __call__ (self , ctx : FunctionContext ) -> MypyType :
50+ defn = ctx .arg_types [0 ][0 ]
4951 is_defined_by_class = (
50- isinstance (ctx . arg_types [ 0 ][ 0 ] , CallableType ) and
51- not ctx . arg_types [ 0 ][ 0 ] .arg_types and
52- isinstance (ctx . arg_types [ 0 ][ 0 ] .ret_type , Instance )
52+ isinstance (defn , CallableType ) and
53+ not defn .arg_types and
54+ isinstance (defn .ret_type , Instance )
5355 )
5456
5557 if is_defined_by_class :
5658 return self ._adjust_protocol_arguments (ctx )
57- elif isinstance (ctx . arg_types [ 0 ][ 0 ] , CallableType ):
59+ elif isinstance (defn , CallableType ):
5860 return self ._adjust_function_arguments (ctx )
5961 return ctx .default_return_type
6062
@@ -144,12 +146,87 @@ class _AdjustInstanceSignature(object):
144146 """
145147
146148 def __call__ (self , ctx : MethodContext ) -> MypyType :
149+ if not isinstance (ctx .type , Instance ):
150+ return ctx .default_return_type
151+ if not isinstance (ctx .default_return_type , CallableType ):
152+ return ctx .default_return_type
153+
147154 instance_type = self ._adjust_typeclass_callable (ctx )
148155 self ._adjust_typeclass_type (ctx , instance_type )
149156 if isinstance (instance_type , Instance ):
150157 self ._add_supports_metadata (ctx , instance_type )
151158 return ctx .default_return_type
152159
160+ @classmethod
161+ def from_function_decorator (cls , ctx : FunctionContext ) -> MypyType :
162+ """
163+ It is used when ``.instance`` is used without params as a decorator.
164+
165+ Like:
166+
167+ .. code:: python
168+
169+ @some.instance
170+ def _some_str(instance: str) -> str:
171+ ...
172+
173+ """
174+ is_decorator = (
175+ isinstance (ctx .context , Decorator ) and
176+ len (ctx .context .decorators ) == 1 and
177+ isinstance (ctx .context .decorators [0 ], MemberExpr ) and
178+ ctx .context .decorators [0 ].name == 'instance'
179+ )
180+ if not is_decorator :
181+ return ctx .default_return_type
182+
183+ passed_function = ctx .arg_types [0 ][0 ]
184+ assert isinstance (passed_function , CallableType )
185+
186+ if not passed_function .arg_types :
187+ return ctx .default_return_type
188+
189+ annotation_type = passed_function .arg_types [0 ]
190+ if isinstance (annotation_type , Instance ):
191+ if annotation_type .type and annotation_type .type .is_protocol :
192+ ctx .api .fail (
193+ 'Protocols must be passed with `is_protocol=True`' ,
194+ ctx .context ,
195+ )
196+ return ctx .default_return_type
197+ else :
198+ ctx .api .fail (
199+ 'Only simple instance types are allowed, got: {0}' .format (
200+ annotation_type ,
201+ ),
202+ ctx .context ,
203+ )
204+ return ctx .default_return_type
205+
206+ ret_type = CallableType (
207+ arg_types = [passed_function ],
208+ arg_kinds = [ARG_POS ],
209+ arg_names = [None ],
210+ ret_type = AnyType (TypeOfAny .implementation_artifact ),
211+ fallback = passed_function .fallback ,
212+ )
213+ instance_type = ctx .api .expr_checker .accept ( # type: ignore
214+ ctx .context .decorators [0 ].expr , # type: ignore
215+ )
216+
217+ # We need to change the `ctx` type from `Function` to `Method`:
218+ return cls ()(MethodContext (
219+ type = instance_type ,
220+ arg_types = ctx .arg_types ,
221+ arg_kinds = ctx .arg_kinds ,
222+ arg_names = ctx .arg_names ,
223+ args = ctx .args ,
224+ callee_arg_names = ctx .callee_arg_names ,
225+ default_return_type = ret_type ,
226+ context = ctx .context ,
227+ api = ctx .api ,
228+ ))
229+
153230 def _adjust_typeclass_callable (
154231 self ,
155232 ctx : MethodContext ,
@@ -302,6 +379,9 @@ def get_function_hook(
302379 """Here we adjust the typeclass constructor."""
303380 if fullname == 'classes._typeclass.typeclass' :
304381 return _AdjustArguments ()
382+ if fullname == 'instance of _TypeClass' :
383+ # `@some.instance` call without params:
384+ return _AdjustInstanceSignature .from_function_decorator
305385 return None
306386
307387 def get_method_hook (
@@ -310,6 +390,7 @@ def get_method_hook(
310390 ) -> Optional [Callable [[MethodContext ], MypyType ]]:
311391 """Here we adjust the typeclass with new allowed types."""
312392 if fullname == 'classes._typeclass._TypeClass.instance' :
393+ # `@some.instance` call with explicit params:
313394 return _AdjustInstanceSignature ()
314395 return None
315396
0 commit comments