Skip to content

Commit a395f09

Browse files
feat: allow data function in file download to be a coroutine
1 parent 3647cf4 commit a395f09

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

solara/components/file_download.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
import asyncio
12
from pathlib import Path
2-
from typing import BinaryIO, Callable, List, Optional, Union, cast
3+
from typing import Any, BinaryIO, Callable, Coroutine, List, Optional, Union, cast
34

45
import ipyvuetify as vy
56
import ipywidgets as widgets
@@ -19,7 +20,13 @@ class FileDownloadWidget(vy.VuetifyTemplate):
1920

2021
@solara.component
2122
def FileDownload(
22-
data: Union[str, bytes, BinaryIO, Callable[[], Union[str, bytes, BinaryIO]]],
23+
data: Union[
24+
str,
25+
bytes,
26+
BinaryIO,
27+
Callable[[], Union[str, bytes, BinaryIO]],
28+
Callable[[], Coroutine[Any, Any, Union[str, bytes, BinaryIO]]],
29+
],
2330
filename: Optional[str] = None,
2431
label: Optional[str] = None,
2532
icon_name: Optional[str] = "mdi-cloud-download-outline",
@@ -130,7 +137,7 @@ def get_data():
130137
131138
## Arguments
132139
133-
* `data`: The data to download. Can be a string, bytes, or a file like object, or a function that returns one of these.
140+
* `data`: The data to download. Can be a string, bytes, or a file like object, or a function (or coroutine function) that returns one of these.
134141
* `filename`: The name of the file the user will see as default when downloading (default name is "solara-download.dat").
135142
If a file object is provided, the filename will be extracted from the file object if possible.
136143
* `label`: The label of the button. If not provided, the label will be "Download: {filename}".
@@ -162,7 +169,11 @@ def reset():
162169
def get_data() -> Optional[bytes]:
163170
if request_download:
164171
if callable(data):
165-
data_non_lazy = data()
172+
# if it is a coroutine, we start a new event loop and run it
173+
if asyncio.iscoroutinefunction(data):
174+
data_non_lazy = asyncio.run(data())
175+
else:
176+
data_non_lazy = data()
166177
else:
167178
data_non_lazy = data
168179
if hasattr(data_non_lazy, "read"):

0 commit comments

Comments
 (0)