Skip to content

Commit 798e87a

Browse files
committed
Add generic typing support for MemorySendChannel and MemoryReceiveChannel
1 parent ca850f6 commit 798e87a

File tree

4 files changed

+123
-52
lines changed

4 files changed

+123
-52
lines changed

trio/_channel.py

Lines changed: 115 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,48 @@
1+
from __future__ import annotations
2+
13
from collections import deque, OrderedDict
4+
from collections.abc import Callable
25
from math import inf
36

7+
from types import TracebackType
8+
from typing import (
9+
Any,
10+
Generic,
11+
NoReturn,
12+
TypeVar,
13+
TYPE_CHECKING,
14+
Tuple, # only needed for typechecking on <3.9
15+
)
16+
417
import attr
518
from outcome import Error, Value
619

720
from .abc import SendChannel, ReceiveChannel, Channel
821
from ._util import generic_function, NoPublicConstructor
922

1023
import trio
11-
from ._core import enable_ki_protection
24+
from ._core import enable_ki_protection, Task, Abort, RaiseCancelT
25+
26+
# A regular invariant generic type
27+
T = TypeVar("T")
28+
29+
# The type of object produced by a ReceiveChannel (covariant because
30+
# ReceiveChannel[Derived] can be passed to someone expecting
31+
# ReceiveChannel[Base])
32+
ReceiveType = TypeVar("ReceiveType", covariant=True)
33+
34+
# The type of object accepted by a SendChannel (contravariant because
35+
# SendChannel[Base] can be passed to someone expecting
36+
# SendChannel[Derived])
37+
SendType = TypeVar("SendType", contravariant=True)
1238

39+
# Temporary TypeVar needed until mypy release supports Self as a type
40+
SelfT = TypeVar("SelfT")
1341

14-
@generic_function
15-
def open_memory_channel(max_buffer_size):
42+
43+
def _open_memory_channel(
44+
max_buffer_size: int,
45+
) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]:
1646
"""Open a channel for passing objects between tasks within a process.
1747
1848
Memory channels are lightweight, cheap to allocate, and entirely
@@ -68,36 +98,57 @@ def open_memory_channel(max_buffer_size):
6898
raise TypeError("max_buffer_size must be an integer or math.inf")
6999
if max_buffer_size < 0:
70100
raise ValueError("max_buffer_size must be >= 0")
71-
state = MemoryChannelState(max_buffer_size)
101+
state: MemoryChannelState[T] = MemoryChannelState(max_buffer_size)
72102
return (
73-
MemorySendChannel._create(state),
74-
MemoryReceiveChannel._create(state),
103+
MemorySendChannel[T]._create(state),
104+
MemoryReceiveChannel[T]._create(state),
75105
)
76106

77107

108+
# This workaround requires python3.9+, once older python versions are not supported
109+
# or there's a better way of achieving type-checking on a generic factory function,
110+
# it could replace the normal function header
111+
if TYPE_CHECKING:
112+
# written as a class so you can say open_memory_channel[int](5)
113+
# Need to use Tuple instead of tuple due to CI check running on 3.8
114+
class open_memory_channel(Tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]):
115+
def __new__( # type: ignore[misc] # "must return a subtype"
116+
cls, max_buffer_size: int
117+
) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]:
118+
return _open_memory_channel(max_buffer_size)
119+
120+
def __init__(self, max_buffer_size: int):
121+
...
122+
123+
else:
124+
# apply the generic_function decorator to make open_memory_channel indexable
125+
# so it's valid to say e.g. ``open_memory_channel[bytes](5)`` at runtime
126+
open_memory_channel = generic_function(_open_memory_channel)
127+
128+
78129
@attr.s(frozen=True, slots=True)
79130
class MemoryChannelStats:
80-
current_buffer_used = attr.ib()
81-
max_buffer_size = attr.ib()
82-
open_send_channels = attr.ib()
83-
open_receive_channels = attr.ib()
84-
tasks_waiting_send = attr.ib()
85-
tasks_waiting_receive = attr.ib()
131+
current_buffer_used: int = attr.ib()
132+
max_buffer_size: int = attr.ib()
133+
open_send_channels: int = attr.ib()
134+
open_receive_channels: int = attr.ib()
135+
tasks_waiting_send: int = attr.ib()
136+
tasks_waiting_receive: int = attr.ib()
86137

87138

88139
@attr.s(slots=True)
89-
class MemoryChannelState:
90-
max_buffer_size = attr.ib()
91-
data = attr.ib(factory=deque)
140+
class MemoryChannelState(Generic[T]):
141+
max_buffer_size: int = attr.ib()
142+
data: deque[T] = attr.ib(factory=deque)
92143
# Counts of open endpoints using this state
93-
open_send_channels = attr.ib(default=0)
94-
open_receive_channels = attr.ib(default=0)
144+
open_send_channels: int = attr.ib(default=0)
145+
open_receive_channels: int = attr.ib(default=0)
95146
# {task: value}
96-
send_tasks = attr.ib(factory=OrderedDict)
147+
send_tasks: OrderedDict[Task, T] = attr.ib(factory=OrderedDict)
97148
# {task: None}
98-
receive_tasks = attr.ib(factory=OrderedDict)
149+
receive_tasks: OrderedDict[Task, None] = attr.ib(factory=OrderedDict)
99150

100-
def statistics(self):
151+
def statistics(self) -> MemoryChannelStats:
101152
return MemoryChannelStats(
102153
current_buffer_used=len(self.data),
103154
max_buffer_size=self.max_buffer_size,
@@ -109,28 +160,28 @@ def statistics(self):
109160

110161

111162
@attr.s(eq=False, repr=False)
112-
class MemorySendChannel(SendChannel, metaclass=NoPublicConstructor):
113-
_state = attr.ib()
114-
_closed = attr.ib(default=False)
163+
class MemorySendChannel(SendChannel[SendType], metaclass=NoPublicConstructor):
164+
_state: MemoryChannelState[SendType] = attr.ib()
165+
_closed: bool = attr.ib(default=False)
115166
# This is just the tasks waiting on *this* object. As compared to
116167
# self._state.send_tasks, which includes tasks from this object and
117168
# all clones.
118-
_tasks = attr.ib(factory=set)
169+
_tasks: set[Task] = attr.ib(factory=set)
119170

120-
def __attrs_post_init__(self):
171+
def __attrs_post_init__(self) -> None:
121172
self._state.open_send_channels += 1
122173

123-
def __repr__(self):
174+
def __repr__(self) -> str:
124175
return "<send channel at {:#x}, using buffer at {:#x}>".format(
125176
id(self), id(self._state)
126177
)
127178

128-
def statistics(self):
179+
def statistics(self) -> MemoryChannelStats:
129180
# XX should we also report statistics specific to this object?
130181
return self._state.statistics()
131182

132183
@enable_ki_protection
133-
def send_nowait(self, value):
184+
def send_nowait(self, value: SendType) -> None:
134185
"""Like `~trio.abc.SendChannel.send`, but if the channel's buffer is
135186
full, raises `WouldBlock` instead of blocking.
136187
@@ -150,7 +201,7 @@ def send_nowait(self, value):
150201
raise trio.WouldBlock
151202

152203
@enable_ki_protection
153-
async def send(self, value):
204+
async def send(self, value: SendType) -> None:
154205
"""See `SendChannel.send <trio.abc.SendChannel.send>`.
155206
156207
Memory channels allow multiple tasks to call `send` at the same time.
@@ -170,15 +221,16 @@ async def send(self, value):
170221
self._state.send_tasks[task] = value
171222
task.custom_sleep_data = self
172223

173-
def abort_fn(_):
224+
def abort_fn(_: RaiseCancelT) -> Abort:
174225
self._tasks.remove(task)
175226
del self._state.send_tasks[task]
176227
return trio.lowlevel.Abort.SUCCEEDED
177228

178229
await trio.lowlevel.wait_task_rescheduled(abort_fn)
179230

231+
# Return type must be stringified or use a TypeVar
180232
@enable_ki_protection
181-
def clone(self):
233+
def clone(self) -> "MemorySendChannel[SendType]":
182234
"""Clone this send channel object.
183235
184236
This returns a new `MemorySendChannel` object, which acts as a
@@ -206,14 +258,19 @@ def clone(self):
206258
raise trio.ClosedResourceError
207259
return MemorySendChannel._create(self._state)
208260

209-
def __enter__(self):
261+
def __enter__(self: SelfT) -> SelfT:
210262
return self
211263

212-
def __exit__(self, exc_type, exc_val, exc_tb):
264+
def __exit__(
265+
self,
266+
exc_type: type[BaseException] | None,
267+
exc_val: BaseException | None,
268+
exc_tb: TracebackType | None,
269+
) -> None:
213270
self.close()
214271

215272
@enable_ki_protection
216-
def close(self):
273+
def close(self) -> None:
217274
"""Close this send channel object synchronously.
218275
219276
All channel objects have an asynchronous `~.AsyncResource.aclose` method.
@@ -241,30 +298,30 @@ def close(self):
241298
self._state.receive_tasks.clear()
242299

243300
@enable_ki_protection
244-
async def aclose(self):
301+
async def aclose(self) -> None:
245302
self.close()
246303
await trio.lowlevel.checkpoint()
247304

248305

249306
@attr.s(eq=False, repr=False)
250-
class MemoryReceiveChannel(ReceiveChannel, metaclass=NoPublicConstructor):
251-
_state = attr.ib()
252-
_closed = attr.ib(default=False)
253-
_tasks = attr.ib(factory=set)
307+
class MemoryReceiveChannel(ReceiveChannel[ReceiveType], metaclass=NoPublicConstructor):
308+
_state: MemoryChannelState[ReceiveType] = attr.ib()
309+
_closed: bool = attr.ib(default=False)
310+
_tasks: set[trio._core._run.Task] = attr.ib(factory=set)
254311

255-
def __attrs_post_init__(self):
312+
def __attrs_post_init__(self) -> None:
256313
self._state.open_receive_channels += 1
257314

258-
def statistics(self):
315+
def statistics(self) -> MemoryChannelStats:
259316
return self._state.statistics()
260317

261-
def __repr__(self):
318+
def __repr__(self) -> str:
262319
return "<receive channel at {:#x}, using buffer at {:#x}>".format(
263320
id(self), id(self._state)
264321
)
265322

266323
@enable_ki_protection
267-
def receive_nowait(self):
324+
def receive_nowait(self) -> ReceiveType:
268325
"""Like `~trio.abc.ReceiveChannel.receive`, but if there's nothing
269326
ready to receive, raises `WouldBlock` instead of blocking.
270327
@@ -284,7 +341,7 @@ def receive_nowait(self):
284341
raise trio.WouldBlock
285342

286343
@enable_ki_protection
287-
async def receive(self):
344+
async def receive(self) -> ReceiveType:
288345
"""See `ReceiveChannel.receive <trio.abc.ReceiveChannel.receive>`.
289346
290347
Memory channels allow multiple tasks to call `receive` at the same
@@ -306,15 +363,17 @@ async def receive(self):
306363
self._state.receive_tasks[task] = None
307364
task.custom_sleep_data = self
308365

309-
def abort_fn(_):
366+
def abort_fn(_: RaiseCancelT) -> Abort:
310367
self._tasks.remove(task)
311368
del self._state.receive_tasks[task]
312369
return trio.lowlevel.Abort.SUCCEEDED
313370

314-
return await trio.lowlevel.wait_task_rescheduled(abort_fn)
371+
# Not strictly guaranteed to return ReceiveType, but will do so unless
372+
# you intentionally reschedule with a bad value.
373+
return await trio.lowlevel.wait_task_rescheduled(abort_fn) # type: ignore[no-any-return]
315374

316375
@enable_ki_protection
317-
def clone(self):
376+
def clone(self) -> "MemoryReceiveChannel[ReceiveType]":
318377
"""Clone this receive channel object.
319378
320379
This returns a new `MemoryReceiveChannel` object, which acts as a
@@ -345,14 +404,19 @@ def clone(self):
345404
raise trio.ClosedResourceError
346405
return MemoryReceiveChannel._create(self._state)
347406

348-
def __enter__(self):
407+
def __enter__(self: SelfT) -> SelfT:
349408
return self
350409

351-
def __exit__(self, exc_type, exc_val, exc_tb):
410+
def __exit__(
411+
self,
412+
exc_type: type[BaseException] | None,
413+
exc_val: BaseException | None,
414+
exc_tb: TracebackType | None,
415+
) -> None:
352416
self.close()
353417

354418
@enable_ki_protection
355-
def close(self):
419+
def close(self) -> None:
356420
"""Close this receive channel object synchronously.
357421
358422
All channel objects have an asynchronous `~.AsyncResource.aclose` method.
@@ -381,6 +445,6 @@ def close(self):
381445
self._state.data.clear()
382446

383447
@enable_ki_protection
384-
async def aclose(self):
448+
async def aclose(self) -> None:
385449
self.close()
386450
await trio.lowlevel.checkpoint()

trio/_core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
from ._traps import (
5656
cancel_shielded_checkpoint,
5757
Abort,
58+
RaiseCancelT,
5859
wait_task_rescheduled,
5960
temporarily_detach_coroutine_object,
6061
permanently_detach_coroutine_object,

trio/_core/_traps.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from . import _run
1010

11+
from typing import Callable, NoReturn, Any
1112

1213
# Helper for the bottommost 'yield'. You can't use 'yield' inside an async
1314
# function, but you can inside a generator, and if you decorate your generator
@@ -64,7 +65,11 @@ class WaitTaskRescheduled:
6465
abort_func = attr.ib()
6566

6667

67-
async def wait_task_rescheduled(abort_func):
68+
RaiseCancelT = Callable[[], NoReturn] # TypeAlias
69+
70+
# Should always return the type a Task "expects", unless you willfully reschedule it
71+
# with a bad value.
72+
async def wait_task_rescheduled(abort_func: Callable[[RaiseCancelT], Abort]) -> Any:
6873
"""Put the current task to sleep, with cancellation support.
6974
7075
This is the lowest-level API for blocking in Trio. Every time a

trio/lowlevel.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from ._core import (
1717
cancel_shielded_checkpoint,
1818
Abort,
19+
RaiseCancelT,
1920
wait_task_rescheduled,
2021
enable_ki_protection,
2122
disable_ki_protection,

0 commit comments

Comments
 (0)