+from __future__ import annotations
+
 from collections import deque, OrderedDict
+from collections.abc import Callable
 from math import inf

+from types import TracebackType
+from typing import (
+    Any,
+    Generic,
+    NoReturn,
+    TypeVar,
+    TYPE_CHECKING,
+    Tuple,  # only needed for typechecking on <3.9
+)
+
 import attr
 from outcome import Error, Value

 from .abc import SendChannel, ReceiveChannel, Channel
 from ._util import generic_function, NoPublicConstructor

 import trio
-from ._core import enable_ki_protection
+from ._core import enable_ki_protection, Task, Abort, RaiseCancelT
+
+# A regular invariant generic type
+T = TypeVar("T")
+
+# The type of object produced by a ReceiveChannel (covariant because
+# ReceiveChannel[Derived] can be passed to someone expecting
+# ReceiveChannel[Base])
+ReceiveType = TypeVar("ReceiveType", covariant=True)
+
+# The type of object accepted by a SendChannel (contravariant because
+# SendChannel[Base] can be passed to someone expecting
+# SendChannel[Derived])
+SendType = TypeVar("SendType", contravariant=True)

+# Temporary TypeVar needed until mypy release supports Self as a type
+SelfT = TypeVar("SelfT")
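
For readers less familiar with variance, here is a minimal sketch of what the two declarations above permit at a call site; `Animal`, `Cat`, `drain`, and `feed` are hypothetical names used only for illustration, not part of trio:

    import trio

    class Animal:
        ...

    class Cat(Animal):
        ...

    async def drain(chan: trio.abc.ReceiveChannel[Animal]) -> None:
        # Covariance: a ReceiveChannel[Cat] only ever yields Cats, which are
        # Animals, so a type checker accepts it for this parameter.
        async for animal in chan:
            print(animal)

    async def feed(chan: trio.abc.SendChannel[Cat]) -> None:
        # Contravariance: a SendChannel[Animal] accepts any Animal, so it can
        # certainly be handed the Cats sent here.
        await chan.send(Cat())

The invariant `T` is used where the same type flows both in and out, as in `_open_memory_channel` below, which hands back a matched send/receive pair.
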

-@generic_function
-def open_memory_channel(max_buffer_size):
+
+def _open_memory_channel(
+    max_buffer_size: int,
+) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]:
     """Open a channel for passing objects between tasks within a process.

     Memory channels are lightweight, cheap to allocate, and entirely
@@ -68,36 +98,57 @@ def open_memory_channel(max_buffer_size):
         raise TypeError("max_buffer_size must be an integer or math.inf")
     if max_buffer_size < 0:
         raise ValueError("max_buffer_size must be >= 0")
-    state = MemoryChannelState(max_buffer_size)
+    state: MemoryChannelState[T] = MemoryChannelState(max_buffer_size)
     return (
-        MemorySendChannel._create(state),
-        MemoryReceiveChannel._create(state),
+        MemorySendChannel[T]._create(state),
+        MemoryReceiveChannel[T]._create(state),
     )


+# This workaround requires python3.9+. Once older python versions are not supported,
+# or there's a better way of achieving type-checking on a generic factory function,
+# it could replace the normal function header.
+if TYPE_CHECKING:
+    # written as a class so you can say open_memory_channel[int](5)
+    # Need to use Tuple instead of tuple due to CI check running on 3.8
+    class open_memory_channel(Tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]):
+        def __new__(  # type: ignore[misc]  # "must return a subtype"
+            cls, max_buffer_size: int
+        ) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]:
+            return _open_memory_channel(max_buffer_size)
+
+        def __init__(self, max_buffer_size: int):
+            ...
+
+else:
+    # apply the generic_function decorator to make open_memory_channel indexable,
+    # so it's valid to say e.g. ``open_memory_channel[bytes](5)`` at runtime
+    open_memory_channel = generic_function(_open_memory_channel)
+
+
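
A brief usage sketch of what the two branches above combine to provide: a type checker sees the `Tuple` subclass and infers the element type from the subscript, while at runtime the subscripted call is routed through `generic_function` to `_open_memory_channel` (the buffer size and element type here are arbitrary):

    import trio

    async def main() -> None:
        # Inferred as tuple[MemorySendChannel[int], MemoryReceiveChannel[int]];
        # plain open_memory_channel(5) still works, just without the typing.
        send_channel, receive_channel = trio.open_memory_channel[int](5)
        await send_channel.send(3)
        value = await receive_channel.receive()
        print(value)

    trio.run(main)
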
 @attr.s(frozen=True, slots=True)
 class MemoryChannelStats:
-    current_buffer_used = attr.ib()
-    max_buffer_size = attr.ib()
-    open_send_channels = attr.ib()
-    open_receive_channels = attr.ib()
-    tasks_waiting_send = attr.ib()
-    tasks_waiting_receive = attr.ib()
+    current_buffer_used: int = attr.ib()
+    max_buffer_size: int = attr.ib()
+    open_send_channels: int = attr.ib()
+    open_receive_channels: int = attr.ib()
+    tasks_waiting_send: int = attr.ib()
+    tasks_waiting_receive: int = attr.ib()


 @attr.s(slots=True)
-class MemoryChannelState:
-    max_buffer_size = attr.ib()
-    data = attr.ib(factory=deque)
+class MemoryChannelState(Generic[T]):
+    max_buffer_size: int = attr.ib()
+    data: deque[T] = attr.ib(factory=deque)
     # Counts of open endpoints using this state
-    open_send_channels = attr.ib(default=0)
-    open_receive_channels = attr.ib(default=0)
+    open_send_channels: int = attr.ib(default=0)
+    open_receive_channels: int = attr.ib(default=0)
     # {task: value}
-    send_tasks = attr.ib(factory=OrderedDict)
+    send_tasks: OrderedDict[Task, T] = attr.ib(factory=OrderedDict)
     # {task: None}
-    receive_tasks = attr.ib(factory=OrderedDict)
+    receive_tasks: OrderedDict[Task, None] = attr.ib(factory=OrderedDict)

-    def statistics(self):
+    def statistics(self) -> MemoryChannelStats:
         return MemoryChannelStats(
             current_buffer_used=len(self.data),
             max_buffer_size=self.max_buffer_size,
@@ -109,28 +160,28 @@ def statistics(self):


 @attr.s(eq=False, repr=False)
-class MemorySendChannel(SendChannel, metaclass=NoPublicConstructor):
-    _state = attr.ib()
-    _closed = attr.ib(default=False)
+class MemorySendChannel(SendChannel[SendType], metaclass=NoPublicConstructor):
+    _state: MemoryChannelState[SendType] = attr.ib()
+    _closed: bool = attr.ib(default=False)
     # This is just the tasks waiting on *this* object. As compared to
     # self._state.send_tasks, which includes tasks from this object and
     # all clones.
-    _tasks = attr.ib(factory=set)
+    _tasks: set[Task] = attr.ib(factory=set)

-    def __attrs_post_init__(self):
+    def __attrs_post_init__(self) -> None:
         self._state.open_send_channels += 1

-    def __repr__(self):
+    def __repr__(self) -> str:
         return "<send channel at {:#x}, using buffer at {:#x}>".format(
             id(self), id(self._state)
         )

-    def statistics(self):
+    def statistics(self) -> MemoryChannelStats:
         # XX should we also report statistics specific to this object?
         return self._state.statistics()

     @enable_ki_protection
-    def send_nowait(self, value):
+    def send_nowait(self, value: SendType) -> None:
         """Like `~trio.abc.SendChannel.send`, but if the channel's buffer is
         full, raises `WouldBlock` instead of blocking.

@@ -150,7 +201,7 @@ def send_nowait(self, value):
             raise trio.WouldBlock

     @enable_ki_protection
-    async def send(self, value):
+    async def send(self, value: SendType) -> None:
         """See `SendChannel.send <trio.abc.SendChannel.send>`.

         Memory channels allow multiple tasks to call `send` at the same time.
@@ -170,15 +221,16 @@ async def send(self, value):
         self._state.send_tasks[task] = value
         task.custom_sleep_data = self

-        def abort_fn(_):
+        def abort_fn(_: RaiseCancelT) -> Abort:
             self._tasks.remove(task)
             del self._state.send_tasks[task]
             return trio.lowlevel.Abort.SUCCEEDED

         await trio.lowlevel.wait_task_rescheduled(abort_fn)

+    # Return type must be stringified or use a TypeVar
     @enable_ki_protection
-    def clone(self):
+    def clone(self) -> "MemorySendChannel[SendType]":
         """Clone this send channel object.

         This returns a new `MemorySendChannel` object, which acts as a
@@ -206,14 +258,19 @@ def clone(self):
             raise trio.ClosedResourceError
         return MemorySendChannel._create(self._state)

-    def __enter__(self):
+    def __enter__(self: SelfT) -> SelfT:
         return self

-    def __exit__(self, exc_type, exc_val, exc_tb):
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_val: BaseException | None,
+        exc_tb: TracebackType | None,
+    ) -> None:
         self.close()

     @enable_ki_protection
-    def close(self):
+    def close(self) -> None:
         """Close this send channel object synchronously.

         All channel objects have an asynchronous `~.AsyncResource.aclose` method.
@@ -241,30 +298,30 @@ def close(self):
             self._state.receive_tasks.clear()

     @enable_ki_protection
-    async def aclose(self):
+    async def aclose(self) -> None:
         self.close()
         await trio.lowlevel.checkpoint()


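
Since both endpoint classes implement `__enter__`/`__exit__`, the synchronous `close` can be tied to a `with` block. A small sketch of the resulting behavior, relying on trio's documented memory-channel close semantics (names and values are illustrative):

    import trio

    async def main() -> None:
        send_channel, receive_channel = trio.open_memory_channel[str](1)
        with send_channel:
            send_channel.send_nowait("hello")
        # Leaving the with-block called close(), so further sends are rejected...
        try:
            send_channel.send_nowait("too late")
        except trio.ClosedResourceError:
            pass
        # ...but already-buffered data stays readable, and once every send
        # handle is closed the receive side sees EndOfChannel instead of blocking.
        assert receive_channel.receive_nowait() == "hello"
        try:
            receive_channel.receive_nowait()
        except trio.EndOfChannel:
            pass

    trio.run(main)
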
 @attr.s(eq=False, repr=False)
-class MemoryReceiveChannel(ReceiveChannel, metaclass=NoPublicConstructor):
-    _state = attr.ib()
-    _closed = attr.ib(default=False)
-    _tasks = attr.ib(factory=set)
+class MemoryReceiveChannel(ReceiveChannel[ReceiveType], metaclass=NoPublicConstructor):
+    _state: MemoryChannelState[ReceiveType] = attr.ib()
+    _closed: bool = attr.ib(default=False)
+    _tasks: set[trio._core._run.Task] = attr.ib(factory=set)

-    def __attrs_post_init__(self):
+    def __attrs_post_init__(self) -> None:
         self._state.open_receive_channels += 1

-    def statistics(self):
+    def statistics(self) -> MemoryChannelStats:
         return self._state.statistics()

-    def __repr__(self):
+    def __repr__(self) -> str:
         return "<receive channel at {:#x}, using buffer at {:#x}>".format(
             id(self), id(self._state)
         )

     @enable_ki_protection
-    def receive_nowait(self):
+    def receive_nowait(self) -> ReceiveType:
         """Like `~trio.abc.ReceiveChannel.receive`, but if there's nothing
         ready to receive, raises `WouldBlock` instead of blocking.

@@ -284,7 +341,7 @@ def receive_nowait(self):
         raise trio.WouldBlock

     @enable_ki_protection
-    async def receive(self):
+    async def receive(self) -> ReceiveType:
         """See `ReceiveChannel.receive <trio.abc.ReceiveChannel.receive>`.

         Memory channels allow multiple tasks to call `receive` at the same
@@ -306,15 +363,17 @@ async def receive(self):
         self._state.receive_tasks[task] = None
         task.custom_sleep_data = self

-        def abort_fn(_):
+        def abort_fn(_: RaiseCancelT) -> Abort:
             self._tasks.remove(task)
             del self._state.receive_tasks[task]
             return trio.lowlevel.Abort.SUCCEEDED

-        return await trio.lowlevel.wait_task_rescheduled(abort_fn)
+        # Not strictly guaranteed to return ReceiveType, but will do so unless
+        # you intentionally reschedule with a bad value.
+        return await trio.lowlevel.wait_task_rescheduled(abort_fn)  # type: ignore[no-any-return]

     @enable_ki_protection
-    def clone(self):
+    def clone(self) -> "MemoryReceiveChannel[ReceiveType]":
         """Clone this receive channel object.

         This returns a new `MemoryReceiveChannel` object, which acts as a
@@ -345,14 +404,19 @@ def clone(self):
             raise trio.ClosedResourceError
         return MemoryReceiveChannel._create(self._state)

-    def __enter__(self):
+    def __enter__(self: SelfT) -> SelfT:
         return self

-    def __exit__(self, exc_type, exc_val, exc_tb):
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_val: BaseException | None,
+        exc_tb: TracebackType | None,
+    ) -> None:
         self.close()

     @enable_ki_protection
-    def close(self):
+    def close(self) -> None:
         """Close this receive channel object synchronously.

         All channel objects have an asynchronous `~.AsyncResource.aclose` method.
@@ -381,6 +445,6 @@ def close(self):
             self._state.data.clear()

     @enable_ki_protection
-    async def aclose(self):
+    async def aclose(self) -> None:
         self.close()
         await trio.lowlevel.checkpoint()
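
Finally, an end-to-end sketch of the fan-in pattern that the `clone` docstrings describe, written with the annotations this change enables; the producer names and message format are illustrative:

    import trio

    async def producer(name: str, send_channel: trio.MemorySendChannel[str]) -> None:
        # Each producer owns one clone and closes only that clone when done.
        async with send_channel:
            for i in range(3):
                await send_channel.send(f"{name}-{i}")

    async def main() -> None:
        send_channel, receive_channel = trio.open_memory_channel[str](0)
        async with trio.open_nursery() as nursery:
            async with send_channel:
                nursery.start_soon(producer, "a", send_channel.clone())
                nursery.start_soon(producer, "b", send_channel.clone())
            # The loop below ends once the original handle and both clones are
            # closed, because receive() then raises EndOfChannel.
            async with receive_channel:
                async for value in receive_channel:
                    print(value)

    trio.run(main)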