Skip to content

Commit eb1b587

Browse files
authored
Merge pull request #150 from proxystore/issue-134
Add `async_resolve` option to `ProxyTransformer` and make `RunConfig.env_vars` optional
2 parents d82c8da + f6ed700 commit eb1b587

File tree

5 files changed

+104
-14
lines changed

5 files changed

+104
-14
lines changed

taps/run/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ class RunConfig(BaseModel):
5959
'"{executor}" for formatting).'
6060
),
6161
)
62-
env_vars: Dict[str, str] = Field( # noqa: UP006
63-
default_factory=dict,
62+
env_vars: Optional[Dict[str, str]] = Field( # noqa: UP006,UP007
63+
None,
6464
description='Environment variables to set during benchmarking.',
6565
)
6666

taps/run/main.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ def run(config: Config, run_dir: pathlib.Path) -> None:
8080
config.write_toml(CONFIG_FILENAME)
8181
logger.debug(f'Wrote config to {CONFIG_FILENAME}')
8282

83-
with update_environment(config.run.env_vars):
83+
env_vars = config.run.env_vars if config.run.env_vars is not None else {}
84+
with update_environment(env_vars):
8485
with Timer() as app_init_timer:
8586
app = config.app.get_app()
8687
logger.log(

taps/run/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,8 @@ def update_environment(
187187
os.environ.update(variables)
188188
if len(variables) > 0:
189189
logger.debug(
190-
f'Updated {len(variables)} environment variable(s) '
191-
f'({", ".join(variables.keys())})',
190+
f'Updated {len(variables)} environment variable(s): '
191+
f'{", ".join(variables.keys())}',
192192
)
193193

194194
try:
@@ -200,6 +200,6 @@ def update_environment(
200200
os.environ.update(previous)
201201
if len(previous) > 0:
202202
logger.debug(
203-
f'Restored {len(previous)} environment variable(s) '
204-
f'({",".join(previous.keys())})',
203+
f'Restored {len(previous)} environment variable(s): '
204+
f'{",".join(previous.keys())}',
205205
)

taps/transformer/_proxy.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,23 @@
11
from __future__ import annotations
22

3+
import sys
34
from typing import Any
45
from typing import Literal
56
from typing import TypeVar
67

8+
if sys.version_info >= (3, 11): # pragma: >=3.11 cover
9+
from typing import Self
10+
else: # pragma: <3.11 cover
11+
from typing_extensions import Self
12+
713
from proxystore.proxy import extract
814
from proxystore.proxy import Proxy
915
from proxystore.store import get_store
1016
from proxystore.store import Store
1117
from proxystore.store.config import ConnectorConfig
18+
from proxystore.store.utils import resolve_async
1219
from pydantic import Field
20+
from pydantic import model_validator
1321

1422
from taps.plugins import register
1523
from taps.transformer._protocol import TransformerConfig
@@ -29,17 +37,34 @@ class ProxyTransformerConfig(TransformerConfig):
2937
description='Connector configuration.',
3038
)
3139
cache_size: int = Field(16, description='cache size')
40+
async_resolve: bool = Field(
41+
False,
42+
description=(
43+
'Asynchronously resolve proxies. Not compatible with '
44+
'extract_target=True.'
45+
),
46+
)
3247
extract_target: bool = Field(
3348
False,
3449
description=(
35-
'Extract the target from the proxy when resolving the identifier.'
50+
'Extract the target from the proxy when resolving the identifier. '
51+
'Not compatible with async_resolve=True.'
3652
),
3753
)
3854
populate_target: bool = Field(
3955
True,
4056
description='Populate target objects of newly created proxies.',
4157
)
4258

59+
@model_validator(mode='after')
60+
def _validate_mutex_options(self) -> Self:
61+
if self.async_resolve and self.extract_target:
62+
raise ValueError(
63+
'Options async_resolve and extract_target cannot be '
64+
'enabled at the same time.',
65+
)
66+
return self
67+
4368
def get_transformer(self) -> ProxyTransformer:
4469
"""Create a transformer from the configuration."""
4570
connector = self.connector.get_connector()
@@ -51,6 +76,7 @@ def get_transformer(self) -> ProxyTransformer:
5176
populate_target=self.populate_target,
5277
register=True,
5378
),
79+
async_resolve=self.async_resolve,
5480
extract_target=self.extract_target,
5581
)
5682

@@ -62,29 +88,43 @@ class ProxyTransformer:
6288
6389
Args:
6490
store: Store instance to use for proxying objects.
91+
async_resolve: Begin asynchronously resolving proxies when the
92+
transformer resolves a proxy (which is otherwise a no-op unless
93+
`extract_target=True`). Not compatible with `extract_target=True`.
6594
extract_target: When `True`, resolving an identifier (i.e., a proxy)
6695
will return the target object. Otherwise, the proxy is returned
67-
since a proxy can act as the target object.
96+
since a proxy can act as the target object. Not compatible
97+
with `async_resolve=True`.
6898
"""
6999

70100
def __init__(
71101
self,
72102
store: Store[Any],
73103
*,
104+
async_resolve: bool = False,
74105
extract_target: bool = False,
75106
) -> None:
107+
if async_resolve and extract_target:
108+
raise ValueError(
109+
'Options async_resolve and extract_target cannot be '
110+
'enabled at the same time.',
111+
)
112+
76113
self.store = store
114+
self.async_resolve = async_resolve
77115
self.extract_target = extract_target
78116

79117
def __repr__(self) -> str:
80118
ctype = type(self).__name__
81119
store = f'store={self.store}'
120+
async_ = f'async_resolve={self.async_resolve}'
82121
extract = f'extract_target={self.extract_target}'
83-
return f'{ctype}({store}, {extract})'
122+
return f'{ctype}({store}, {async_}, {extract})'
84123

85124
def __getstate__(self) -> dict[str, Any]:
86125
return {
87126
'config': self.store.config(),
127+
'async_resolve': self.async_resolve,
88128
'extract_target': self.extract_target,
89129
}
90130

@@ -94,6 +134,7 @@ def __setstate__(self, state: dict[str, Any]) -> None:
94134
self.store = store
95135
else:
96136
self.store = Store.from_config(state['config'])
137+
self.async_resolve = state['async_resolve']
97138
self.extract_target = state['extract_target']
98139

99140
def close(self) -> None:
@@ -125,4 +166,8 @@ def resolve(self, identifier: Proxy[T]) -> T | Proxy[T]:
125166
The resolved object or a proxy of the resolved object depending \
126167
on the setting of `extract_target`.
127168
"""
128-
return extract(identifier) if self.extract_target else identifier
169+
if self.extract_target:
170+
return extract(identifier)
171+
if self.async_resolve:
172+
resolve_async(identifier)
173+
return identifier

tests/transformer/proxy_test.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from proxystore.store import Store
1111
from proxystore.store import unregister_store
1212
from proxystore.store.config import ConnectorConfig
13+
from pydantic import ValidationError
1314

1415
from taps.transformer import ProxyTransformer
1516
from taps.transformer import ProxyTransformerConfig
@@ -26,14 +27,40 @@ def test_file_config(tmp_path: pathlib.Path) -> None:
2627
transformer.close()
2728

2829

29-
@pytest.mark.parametrize('extract', (True, False))
30-
def test_proxy_transformer(extract: bool) -> None:
30+
def test_config_validation_error(tmp_path: pathlib.Path) -> None:
31+
with pytest.raises(
32+
ValidationError,
33+
match='Options async_resolve and extract_target cannot be enabled',
34+
):
35+
ProxyTransformerConfig(
36+
connector=ConnectorConfig(
37+
kind='file',
38+
options={'store_dir': str(tmp_path)},
39+
),
40+
async_resolve=True,
41+
extract_target=True,
42+
)
43+
44+
45+
@pytest.mark.parametrize(
46+
('extract', 'async_'),
47+
(
48+
(False, False),
49+
(True, False),
50+
(False, True),
51+
),
52+
)
53+
def test_proxy_transformer(extract: bool, async_: bool) -> None:
3154
with Store(
3255
'test-proxy-transformer',
3356
LocalConnector(),
3457
register=True,
3558
) as store:
36-
transformer = ProxyTransformer(store, extract_target=extract)
59+
transformer = ProxyTransformer(
60+
store,
61+
async_resolve=async_,
62+
extract_target=extract,
63+
)
3764
assert isinstance(repr(transformer), str)
3865

3966
obj = [1, 2, 3]
@@ -46,6 +73,23 @@ def test_proxy_transformer(extract: bool) -> None:
4673
transformer.close()
4774

4875

76+
def test_proxy_transformer_value_error() -> None:
77+
with Store(
78+
'test-proxy-transformer-value-error',
79+
LocalConnector(),
80+
register=True,
81+
) as store:
82+
with pytest.raises(
83+
ValueError,
84+
match='Options async_resolve and extract_target cannot be enabled',
85+
):
86+
ProxyTransformer(
87+
store,
88+
async_resolve=True,
89+
extract_target=True,
90+
)
91+
92+
4993
def test_proxy_transformer_pickling() -> None:
5094
name = 'test-proxy-transformer-pickle'
5195
with Store(name, LocalConnector(), register=True) as store:

0 commit comments

Comments
 (0)