Skip to content

Commit 3a97835

Browse files
feat: allows tasks to have their own kernel context
Adding this so we can use solara tasks and reactive variables in fastapi endpoint, which is only async.
1 parent 7c239aa commit 3a97835

File tree

5 files changed

+219
-4
lines changed

5 files changed

+219
-4
lines changed

solara/server/kernel_context.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import time
1818
import typing
1919
from pathlib import Path
20-
from typing import Any, Callable, Dict, List, Optional, cast
20+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
2121

2222
import ipywidgets as widgets
2323
import reacton
@@ -39,6 +39,11 @@ class Local(threading.local):
3939

4040

4141
local = Local()
42+
# same idea, but for `async with ...`
43+
if typing.TYPE_CHECKING:
44+
async_stack = contextvars.ContextVar[Union[Tuple[Union[None, "VirtualKernelContext"], ...], None]](name="async_stack", default=None)
45+
else:
46+
async_stack = contextvars.ContextVar("async_stack", default=None)
4247

4348

4449
class PageStatus(enum.Enum):
@@ -100,6 +105,23 @@ def display(self, *args):
100105
def on_close(self, f: Callable[[], None]):
101106
self._on_close_callbacks.append(f)
102107

108+
async def __aenter__(self):
109+
stack = async_stack.get()
110+
if stack is None:
111+
stack = ()
112+
key = get_current_thread_key()
113+
async_stack.set(stack + (current_context.get(key, None),))
114+
new_key = get_current_thread_key()
115+
current_context[new_key] = self
116+
117+
async def __aexit__(self, *args):
118+
key = get_current_thread_key()
119+
assert local.kernel_context_stack is not None
120+
stack = async_stack.get()
121+
assert stack is not None
122+
current_context[key] = stack[-1]
123+
async_stack.set(stack[:-1])
124+
103125
def __enter__(self):
104126
if local.kernel_context_stack is None:
105127
local.kernel_context_stack = []
@@ -364,6 +386,7 @@ def create_dummy_context():
364386

365387

366388
def get_current_thread_key() -> str:
389+
# consider renaming this to get_current_context_key
367390
if not solara.server.settings.kernel.threaded:
368391
if async_context_id is not None:
369392
try:
@@ -375,6 +398,13 @@ def get_current_thread_key() -> str:
375398
else:
376399
thread = threading.current_thread()
377400
key = get_thread_key(thread)
401+
# this signals we are using `async with context`, which means we are interested in task-local context
402+
stack = async_stack.get()
403+
if stack is not None and len(stack) > 0:
404+
current_task = asyncio.current_task()
405+
if current_task is not None:
406+
task_key = current_task.get_name()
407+
key = f"{key}-task:{task_key}"
378408
return key
379409

380410

solara/settings.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ class MainSettings(BaseSettings):
5757
allow_reactive_boolean: bool = True
5858
# TODO: also change default_container in solara/components/__init__.py
5959
default_container: Optional[str] = "Column"
60+
allow_global_context: bool = True
6061

6162
class Config:
6263
env_prefix = "solara_"

solara/tasks.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import contextvars
12
import sys
23
import abc
34
import asyncio
@@ -267,7 +268,16 @@ def cancel():
267268

268269
if self.run_in_thread:
269270
thread_event_loop = asyncio.new_event_loop()
270-
self.current_task = current_task = thread_event_loop.create_task(self._async_run(call_event_loop, future, args, kwargs))
271+
272+
def create_task():
273+
# remove the stack, since this thread starts with a fresh stack
274+
import solara.server.kernel_context
275+
276+
solara.server.kernel_context.async_stack.set(None)
277+
return thread_event_loop.create_task(self._async_run(call_event_loop, future, args, kwargs))
278+
279+
new_context = contextvars.copy_context()
280+
self.current_task = current_task = new_context.run(create_task)
271281

272282
def runs_in_thread():
273283
try:
@@ -298,7 +308,7 @@ def runs_in_thread():
298308
raise
299309

300310
self._result.value = TaskResult[R](latest=self._last_value, _state=TaskState.STARTING)
301-
thread = threading.Thread(target=runs_in_thread, daemon=True)
311+
thread = threading.Thread(target=runs_in_thread, daemon=True, name=f"TaskAsyncio-{self.function.__name__}")
302312
thread.start()
303313
else:
304314
self.current_task = current_task = asyncio.create_task(self._async_run(call_event_loop, future, args, kwargs))
@@ -322,7 +332,6 @@ async def _async_run(self, call_event_loop: asyncio.AbstractEventLoop, future: a
322332

323333
task_for_this_call = _get_current_task()
324334
assert task_for_this_call is not None
325-
326335
if self.is_current():
327336
self._result.value = TaskResult[R](latest=self._last_value, _state=TaskState.STARTING)
328337

solara/toestand.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,10 @@ def clear(self):
313313

314314
def set(self, value: S):
315315
scope_dict, scope_id = self._get_dict()
316+
if not solara.settings.main.allow_global_context and scope_id == "global":
317+
raise RuntimeError(
318+
f"No kernel context found, and global context is not allowed for task, context key was {solara.server.kernel_context.get_current_thread_key()}"
319+
)
316320
old = self.get()
317321
if self.equals(old, value):
318322
return

tests/unit/async_test.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
import asyncio
2+
from typing import Tuple
3+
import solara
4+
from solara.server import kernel_context
5+
from solara.server import kernel
6+
from unittest import mock
7+
import solara.lab
8+
9+
reactive = solara.reactive(0)
10+
test_async_task_setting = solara.reactive(0)
11+
tasks = [] # always keep a reference to an asyncio.Task
12+
13+
14+
@solara.lab.task # (prefer_threaded=True)
15+
async def multiply_by(value: int):
16+
result = reactive.value * value
17+
18+
# we also test that a task created in a new thread
19+
async def set_in_task():
20+
test_async_task_setting.value = result
21+
22+
task = asyncio.create_task(set_in_task())
23+
tasks.append(task)
24+
await asyncio.sleep(0.1) # give the task a chance to run
25+
await task # wait for the task to finish
26+
return result
27+
28+
29+
@mock.patch("solara._using_solara_server", return_value=True)
30+
async def test_async_kernels_basic(_):
31+
assert _() is True
32+
kernel1 = kernel.Kernel()
33+
kernel2 = kernel.Kernel()
34+
context1 = kernel_context.VirtualKernelContext(id="toestand-1", kernel=kernel1, session_id="session-1")
35+
context2 = kernel_context.VirtualKernelContext(id="toestand-2", kernel=kernel2, session_id="session-2")
36+
37+
values = solara.Reactive[Tuple[int, ...]]((1,))
38+
39+
async def task1():
40+
async with context1:
41+
for i in range(99):
42+
values.value = values.value + (len(values.value),)
43+
await asyncio.sleep(0.01)
44+
45+
async def task2():
46+
async with context2:
47+
for i in range(99):
48+
values.value = values.value + (len(values.value),)
49+
await asyncio.sleep(0.01)
50+
51+
# await asyncio.gather(asyncio.create_task(task1(), name="test-task1"), asyncio.create_task(task2(), name="test-task2"))
52+
await asyncio.gather(task1(), task2())
53+
assert values.value == (1,)
54+
with context1:
55+
assert len(values.value) == 100
56+
assert values.value[-1] == 99
57+
with context2:
58+
assert len(values.value) == 100
59+
assert values.value[-1] == 99
60+
assert values.value == (1,)
61+
async with context1:
62+
assert len(values.value) == 100
63+
assert values.value[-1] == 99
64+
async with context2:
65+
assert len(values.value) == 100
66+
assert values.value[-1] == 99
67+
assert values.value == (1,)
68+
69+
70+
@mock.patch("solara._using_solara_server", return_value=True)
71+
async def test_async_kernels_complex(_):
72+
assert _() is True
73+
event1 = asyncio.Event() # after event, global is 1
74+
event2 = asyncio.Event() # after event, global is still 1
75+
event3 = asyncio.Event() # after event, global is 2
76+
event4 = asyncio.Event() # after event, global is 3
77+
event5 = asyncio.Event() # after event, global is 3
78+
kernel1 = kernel.Kernel()
79+
kernel2 = kernel.Kernel()
80+
context1 = kernel_context.VirtualKernelContext(id="toestand-1", kernel=kernel1, session_id="session-1")
81+
context2 = kernel_context.VirtualKernelContext(id="toestand-2", kernel=kernel2, session_id="session-2")
82+
83+
main_thread_key = kernel_context.get_current_thread_key()
84+
85+
async def task1():
86+
# global default scope
87+
reactive.value = 1
88+
event1.set()
89+
async with context1:
90+
# kernel scope
91+
reactive.value = 100
92+
assert reactive.value == 100
93+
await event3.wait()
94+
assert reactive.value == 100
95+
multiply_by(3) # result should be 300
96+
97+
async def task2():
98+
await event2.wait()
99+
# global default scope
100+
assert reactive.value == 1
101+
reactive.value = 2
102+
multiply_by(8) # result should be 16
103+
event3.set()
104+
assert kernel_context.get_current_thread_key() == main_thread_key
105+
async with context2:
106+
assert main_thread_key in kernel_context.get_current_thread_key() and len(kernel_context.get_current_thread_key()) > len(main_thread_key)
107+
# kernel scope
108+
assert reactive.value == 0 # still the default value
109+
reactive.value = 200
110+
event4.set()
111+
multiply_by(3) # result should be 600
112+
113+
assert kernel_context.get_current_thread_key() == main_thread_key
114+
await event5.wait()
115+
assert reactive.value == 3
116+
117+
async def test():
118+
await event1.wait()
119+
assert reactive.value == 1
120+
event2.set()
121+
await event2.wait()
122+
# still global default scope
123+
assert reactive.value == 1
124+
await event3.wait()
125+
assert reactive.value == 2
126+
await event4.wait()
127+
await multiply_by.current_future # type: ignore
128+
reactive.value = 3
129+
event5.set()
130+
131+
await asyncio.gather(task1(), task2(), test())
132+
133+
with context1:
134+
assert reactive.value == 100
135+
with context2:
136+
assert reactive.value == 200
137+
assert reactive.value == 3
138+
139+
# checking task results
140+
while not multiply_by.result.finished:
141+
await asyncio.sleep(0.1)
142+
assert multiply_by.result.value == 16
143+
assert test_async_task_setting.value == 16
144+
with context1:
145+
while not multiply_by.result.finished:
146+
await asyncio.sleep(0.1)
147+
assert multiply_by.result.value == 300
148+
assert test_async_task_setting.value == 300
149+
with context2:
150+
while not multiply_by.result.finished:
151+
await asyncio.sleep(0.1)
152+
assert multiply_by.result.value == 600
153+
assert test_async_task_setting.value == 600
154+
155+
156+
@mock.patch("solara._using_solara_server", return_value=True)
157+
async def test_async_kernels_task(_):
158+
assert _() is True
159+
kernel1 = kernel.Kernel()
160+
context1 = kernel_context.VirtualKernelContext(id="toestand-1", kernel=kernel1, session_id="session-1")
161+
162+
main_thread_key = kernel_context.get_current_thread_key()
163+
164+
async with context1:
165+
assert main_thread_key != kernel_context.get_current_thread_key()
166+
assert "task" in kernel_context.get_current_thread_key()
167+
reactive.value = 100
168+
multiply_by(3) # result should be 300
169+
assert reactive.value == 100
170+
await multiply_by.current_future # type: ignore
171+
assert test_async_task_setting.value == 300, "if this fails, the solara task was using the global context"

0 commit comments

Comments
 (0)