Skip to content

Commit d5fd09f

Browse files
committed
adress comments, removing redundant Generic, fixing open_memory_channel return type, and adding trio-typings class workaround to get type-checking on open_memory_channel by hiding it behind a TYPE_CHECKING guard
1 parent c653319 commit d5fd09f

File tree

2 files changed

+28
-11
lines changed

2 files changed

+28
-11
lines changed

trio/_channel.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
Generic,
1111
NoReturn,
1212
TypeVar,
13+
TYPE_CHECKING,
1314
)
1415

1516
import attr
@@ -38,8 +39,7 @@
3839
SelfT = TypeVar("SelfT")
3940

4041

41-
@generic_function
42-
def open_memory_channel(
42+
def _open_memory_channel(
4343
max_buffer_size: int,
4444
) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]:
4545
"""Open a channel for passing objects between tasks within a process.
@@ -99,11 +99,31 @@ def open_memory_channel(
9999
raise ValueError("max_buffer_size must be >= 0")
100100
state: MemoryChannelState[T] = MemoryChannelState(max_buffer_size)
101101
return (
102-
MemorySendChannel._create(state),
103-
MemoryReceiveChannel._create(state),
102+
MemorySendChannel[T]._create(state),
103+
MemoryReceiveChannel[T]._create(state),
104104
)
105105

106106

107+
# This workaround requires python3.9+, once older python versions are not supported
108+
# or there's a better way of achieving type-checking on a generic factory function,
109+
# it could replace the normal function header
110+
if TYPE_CHECKING:
111+
# written as a class so you can say open_memory_channel[int](5)
112+
class open_memory_channel(tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]):
113+
def __new__( # type: ignore[misc] # "must return a subtype"
114+
cls, max_buffer_size: int
115+
) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]:
116+
return _open_memory_channel(max_buffer_size)
117+
118+
def __init__(self, max_buffer_size: int):
119+
...
120+
121+
else:
122+
# apply the generic_function decorator to make open_memory_channel indexable
123+
# so it's valid to say e.g. ``open_memory_channel[bytes](5)`` at runtime
124+
open_memory_channel = generic_function(_open_memory_channel)
125+
126+
107127
@attr.s(frozen=True, slots=True)
108128
class MemoryChannelStats:
109129
current_buffer_used: int = attr.ib()
@@ -138,9 +158,7 @@ def statistics(self) -> MemoryChannelStats:
138158

139159

140160
@attr.s(eq=False, repr=False)
141-
class MemorySendChannel(
142-
SendChannel[SendType], Generic[SendType], metaclass=NoPublicConstructor
143-
):
161+
class MemorySendChannel(SendChannel[SendType], metaclass=NoPublicConstructor):
144162
_state: MemoryChannelState[SendType] = attr.ib()
145163
_closed: bool = attr.ib(default=False)
146164
# This is just the tasks waiting on *this* object. As compared to
@@ -284,9 +302,7 @@ async def aclose(self) -> None:
284302

285303

286304
@attr.s(eq=False, repr=False)
287-
class MemoryReceiveChannel(
288-
ReceiveChannel[ReceiveType], Generic[ReceiveType], metaclass=NoPublicConstructor
289-
):
305+
class MemoryReceiveChannel(ReceiveChannel[ReceiveType], metaclass=NoPublicConstructor):
290306
_state: MemoryChannelState[ReceiveType] = attr.ib()
291307
_closed: bool = attr.ib(default=False)
292308
_tasks: set[trio._core._run.Task] = attr.ib(factory=set)

trio/_core/_traps.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ class Abort(enum.Enum):
6464
class WaitTaskRescheduled:
6565
abort_func = attr.ib()
6666

67-
RaiseCancelT = Callable[[], NoReturn] # TypeAlias
67+
68+
RaiseCancelT = Callable[[], NoReturn] # TypeAlias
6869

6970
# Can this function be retyped to return something other than Any?
7071
async def wait_task_rescheduled(abort_func: Callable[[RaiseCancelT], Abort]) -> Any:

0 commit comments

Comments
 (0)