|
10 | 10 | Generic, |
11 | 11 | NoReturn, |
12 | 12 | TypeVar, |
| 13 | + TYPE_CHECKING, |
13 | 14 | ) |
14 | 15 |
|
15 | 16 | import attr |
|
38 | 39 | SelfT = TypeVar("SelfT") |
39 | 40 |
|
40 | 41 |
|
41 | | -@generic_function |
42 | | -def open_memory_channel( |
| 42 | +def _open_memory_channel( |
43 | 43 | max_buffer_size: int, |
44 | 44 | ) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]: |
45 | 45 | """Open a channel for passing objects between tasks within a process. |
@@ -99,11 +99,31 @@ def open_memory_channel( |
99 | 99 | raise ValueError("max_buffer_size must be >= 0") |
100 | 100 | state: MemoryChannelState[T] = MemoryChannelState(max_buffer_size) |
101 | 101 | return ( |
102 | | - MemorySendChannel._create(state), |
103 | | - MemoryReceiveChannel._create(state), |
| 102 | + MemorySendChannel[T]._create(state), |
| 103 | + MemoryReceiveChannel[T]._create(state), |
104 | 104 | ) |
105 | 105 |
|
106 | 106 |
|
| 107 | +# This workaround requires python3.9+, once older python versions are not supported |
| 108 | +# or there's a better way of achieving type-checking on a generic factory function, |
| 109 | +# it could replace the normal function header |
| 110 | +if TYPE_CHECKING: |
| 111 | + # written as a class so you can say open_memory_channel[int](5) |
| 112 | + class open_memory_channel(tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]): |
| 113 | + def __new__( # type: ignore[misc] # "must return a subtype" |
| 114 | + cls, max_buffer_size: int |
| 115 | + ) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]: |
| 116 | + return _open_memory_channel(max_buffer_size) |
| 117 | + |
| 118 | + def __init__(self, max_buffer_size: int): |
| 119 | + ... |
| 120 | + |
| 121 | +else: |
| 122 | + # apply the generic_function decorator to make open_memory_channel indexable |
| 123 | + # so it's valid to say e.g. ``open_memory_channel[bytes](5)`` at runtime |
| 124 | + open_memory_channel = generic_function(_open_memory_channel) |
| 125 | + |
| 126 | + |
107 | 127 | @attr.s(frozen=True, slots=True) |
108 | 128 | class MemoryChannelStats: |
109 | 129 | current_buffer_used: int = attr.ib() |
@@ -138,9 +158,7 @@ def statistics(self) -> MemoryChannelStats: |
138 | 158 |
|
139 | 159 |
|
140 | 160 | @attr.s(eq=False, repr=False) |
141 | | -class MemorySendChannel( |
142 | | - SendChannel[SendType], Generic[SendType], metaclass=NoPublicConstructor |
143 | | -): |
| 161 | +class MemorySendChannel(SendChannel[SendType], metaclass=NoPublicConstructor): |
144 | 162 | _state: MemoryChannelState[SendType] = attr.ib() |
145 | 163 | _closed: bool = attr.ib(default=False) |
146 | 164 | # This is just the tasks waiting on *this* object. As compared to |
@@ -284,9 +302,7 @@ async def aclose(self) -> None: |
284 | 302 |
|
285 | 303 |
|
286 | 304 | @attr.s(eq=False, repr=False) |
287 | | -class MemoryReceiveChannel( |
288 | | - ReceiveChannel[ReceiveType], Generic[ReceiveType], metaclass=NoPublicConstructor |
289 | | -): |
| 305 | +class MemoryReceiveChannel(ReceiveChannel[ReceiveType], metaclass=NoPublicConstructor): |
290 | 306 | _state: MemoryChannelState[ReceiveType] = attr.ib() |
291 | 307 | _closed: bool = attr.ib(default=False) |
292 | 308 | _tasks: set[trio._core._run.Task] = attr.ib(factory=set) |
|
0 commit comments