Skip to content

Commit 42d399f

Browse files
committed
Merge branch 'main' into WSMR/refactor_handlers
2 parents 249e819 + 7d280fd commit 42d399f

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

distributed/comm/ucx.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,19 @@ def init_once():
125125

126126
ucp.init(options=ucx_config, env_takes_precedence=True)
127127

128+
pool_size_str = dask.config.get("distributed.rmm.pool-size")
129+
128130
# Find the function, `cuda_array()`, to use when allocating new CUDA arrays
129131
try:
130132
import rmm
131133

132134
device_array = lambda n: rmm.DeviceBuffer(size=n)
135+
136+
if pool_size_str is not None:
137+
pool_size = parse_bytes(pool_size_str)
138+
rmm.reinitialize(
139+
pool_allocator=True, managed_memory=False, initial_pool_size=pool_size
140+
)
133141
except ImportError:
134142
try:
135143
import numba.cuda
@@ -140,19 +148,19 @@ def numba_device_array(n):
140148
return a
141149

142150
device_array = numba_device_array
151+
143152
except ImportError:
144153

145154
def device_array(n):
146155
raise RuntimeError(
147156
"In order to send/recv CUDA arrays, Numba or RMM is required"
148157
)
149158

150-
pool_size_str = dask.config.get("distributed.rmm.pool-size")
151-
if pool_size_str is not None:
152-
pool_size = parse_bytes(pool_size_str)
153-
rmm.reinitialize(
154-
pool_allocator=True, managed_memory=False, initial_pool_size=pool_size
155-
)
159+
if pool_size_str is not None:
160+
warnings.warn(
161+
"Initial RMM pool size defined, but RMM is not available. "
162+
"Please consider installing RMM or removing the pool size option."
163+
)
156164

157165

158166
def _close_comm(ref):

0 commit comments

Comments
 (0)