File tree Expand file tree Collapse file tree 1 file changed +14
-6
lines changed
Expand file tree Collapse file tree 1 file changed +14
-6
lines changed Original file line number Diff line number Diff line change @@ -125,11 +125,19 @@ def init_once():
125125
126126 ucp .init (options = ucx_config , env_takes_precedence = True )
127127
128+ pool_size_str = dask .config .get ("distributed.rmm.pool-size" )
129+
128130 # Find the function, `cuda_array()`, to use when allocating new CUDA arrays
129131 try :
130132 import rmm
131133
132134 device_array = lambda n : rmm .DeviceBuffer (size = n )
135+
136+ if pool_size_str is not None :
137+ pool_size = parse_bytes (pool_size_str )
138+ rmm .reinitialize (
139+ pool_allocator = True , managed_memory = False , initial_pool_size = pool_size
140+ )
133141 except ImportError :
134142 try :
135143 import numba .cuda
@@ -140,19 +148,19 @@ def numba_device_array(n):
140148 return a
141149
142150 device_array = numba_device_array
151+
143152 except ImportError :
144153
145154 def device_array (n ):
146155 raise RuntimeError (
147156 "In order to send/recv CUDA arrays, Numba or RMM is required"
148157 )
149158
150- pool_size_str = dask .config .get ("distributed.rmm.pool-size" )
151- if pool_size_str is not None :
152- pool_size = parse_bytes (pool_size_str )
153- rmm .reinitialize (
154- pool_allocator = True , managed_memory = False , initial_pool_size = pool_size
155- )
159+ if pool_size_str is not None :
160+ warnings .warn (
161+ "Initial RMM pool size defined, but RMM is not available. "
162+ "Please consider installing RMM or removing the pool size option."
163+ )
156164
157165
158166def _close_comm (ref ):
You can’t perform that action at this time.
0 commit comments