2222 HOST = "127.0.0.1"
2323
2424
25- def handle_exception (loop , context ):
26- msg = context .get ("exception" , context ["message" ])
27- print (msg )
28-
29-
30- # Let's make sure that UCX gets time to cancel
31- # progress tasks before closing the event loop.
32- @pytest .fixture ()
33- def event_loop (scope = "function" ):
34- loop = asyncio .new_event_loop ()
35- loop .set_exception_handler (handle_exception )
36- ucp .reset ()
37- yield loop
38- ucp .reset ()
39- loop .run_until_complete (asyncio .sleep (0 ))
40- loop .close ()
41-
42-
43- def test_registered ():
25+ def test_registered (ucx_loop ):
4426 assert "ucx" in backends
4527 backend = get_backend ("ucx" )
4628 assert isinstance (backend , ucx .UCXBackend )
@@ -62,7 +44,7 @@ async def handle_comm(comm):
6244
6345
6446@gen_test ()
65- async def test_ping_pong ():
47+ async def test_ping_pong (ucx_loop ):
6648 com , serv_com = await get_comm_pair ()
6749 msg = {"op" : "ping" }
6850 await com .write (msg )
@@ -80,7 +62,7 @@ async def test_ping_pong():
8062
8163
8264@gen_test ()
83- async def test_comm_objs ():
65+ async def test_comm_objs (ucx_loop ):
8466 comm , serv_comm = await get_comm_pair ()
8567
8668 scheme , loc = parse_address (comm .peer_address )
@@ -93,7 +75,7 @@ async def test_comm_objs():
9375
9476
9577@gen_test ()
96- async def test_ucx_specific ():
78+ async def test_ucx_specific (ucx_loop ):
9779 """
9880 Test concrete UCX API.
9981 """
@@ -147,7 +129,7 @@ async def client_communicate(key, delay=0):
147129
148130
149131@gen_test ()
150- async def test_ping_pong_data ():
132+ async def test_ping_pong_data (ucx_loop ):
151133 np = pytest .importorskip ("numpy" )
152134
153135 data = np .ones ((10 , 10 ))
@@ -170,7 +152,7 @@ async def test_ping_pong_data():
170152
171153
172154@gen_test ()
173- async def test_ucx_deserialize ():
155+ async def test_ucx_deserialize (ucx_loop ):
174156 # Note we see this error on some systems with this test:
175157 # `socket.gaierror: [Errno -5] No address associated with hostname`
176158 # This may be due to a system configuration issue.
@@ -196,7 +178,7 @@ async def test_ucx_deserialize():
196178 ],
197179)
198180@gen_test ()
199- async def test_ping_pong_cudf (g ):
181+ async def test_ping_pong_cudf (ucx_loop , g ):
200182 # if this test appears after cupy an import error arises
201183 # *** ImportError: /usr/lib/x86_64-linux-gnu/libstdc++.so.6: version `CXXABI_1.3.11'
202184 # not found (required by python3.7/site-packages/pyarrow/../../../libarrow.so.12)
@@ -221,7 +203,7 @@ async def test_ping_pong_cudf(g):
221203
222204@pytest .mark .parametrize ("shape" , [(100 ,), (10 , 10 ), (4947 ,)])
223205@gen_test ()
224- async def test_ping_pong_cupy (shape ):
206+ async def test_ping_pong_cupy (ucx_loop , shape ):
225207 cupy = pytest .importorskip ("cupy" )
226208 com , serv_com = await get_comm_pair ()
227209
@@ -240,7 +222,7 @@ async def test_ping_pong_cupy(shape):
240222@pytest .mark .slow
241223@pytest .mark .parametrize ("n" , [int (1e9 ), int (2.5e9 )])
242224@gen_test ()
243- async def test_large_cupy (n , cleanup ):
225+ async def test_large_cupy (ucx_loop , n , cleanup ):
244226 cupy = pytest .importorskip ("cupy" )
245227 com , serv_com = await get_comm_pair ()
246228
@@ -257,7 +239,7 @@ async def test_large_cupy(n, cleanup):
257239
258240
259241@gen_test ()
260- async def test_ping_pong_numba ():
242+ async def test_ping_pong_numba (ucx_loop ):
261243 np = pytest .importorskip ("numpy" )
262244 numba = pytest .importorskip ("numba" )
263245 import numba .cuda
@@ -276,7 +258,7 @@ async def test_ping_pong_numba():
276258
277259@pytest .mark .parametrize ("processes" , [True , False ])
278260@gen_test ()
279- async def test_ucx_localcluster (processes , cleanup ):
261+ async def test_ucx_localcluster (ucx_loop , processes , cleanup ):
280262 async with LocalCluster (
281263 protocol = "ucx" ,
282264 host = HOST ,
@@ -297,7 +279,9 @@ async def test_ucx_localcluster(processes, cleanup):
297279
298280@pytest .mark .slow
299281@gen_test (timeout = 60 )
300- async def test_stress ():
282+ async def test_stress (
283+ ucx_loop ,
284+ ):
301285 da = pytest .importorskip ("dask.array" )
302286
303287 chunksize = "10 MB"
@@ -322,15 +306,19 @@ async def test_stress():
322306
323307
324308@gen_test ()
325- async def test_simple ():
309+ async def test_simple (
310+ ucx_loop ,
311+ ):
326312 async with LocalCluster (protocol = "ucx" , asynchronous = True ) as cluster :
327313 async with Client (cluster , asynchronous = True ) as client :
328314 assert cluster .scheduler_address .startswith ("ucx://" )
329315 assert await client .submit (lambda x : x + 1 , 10 ) == 11
330316
331317
332318@gen_test ()
333- async def test_cuda_context ():
319+ async def test_cuda_context (
320+ ucx_loop ,
321+ ):
334322 with dask .config .set ({"distributed.comm.ucx.create-cuda-context" : True }):
335323 async with LocalCluster (
336324 protocol = "ucx" , n_workers = 1 , asynchronous = True
@@ -344,7 +332,9 @@ async def test_cuda_context():
344332
345333
346334@gen_test ()
347- async def test_transpose ():
335+ async def test_transpose (
336+ ucx_loop ,
337+ ):
348338 da = pytest .importorskip ("dask.array" )
349339
350340 async with LocalCluster (protocol = "ucx" , asynchronous = True ) as cluster :
@@ -358,7 +348,7 @@ async def test_transpose():
358348
359349@pytest .mark .parametrize ("port" , [0 , 1234 ])
360350@gen_test ()
361- async def test_ucx_protocol (cleanup , port ):
351+ async def test_ucx_protocol (ucx_loop , cleanup , port ):
362352 async with Scheduler (protocol = "ucx" , port = port , dashboard_address = ":0" ) as s :
363353 assert s .address .startswith ("ucx://" )
364354
@@ -367,10 +357,9 @@ async def test_ucx_protocol(cleanup, port):
367357 not hasattr (ucp .exceptions , "UCXUnreachable" ),
368358 reason = "Requires UCX-Py support for UCXUnreachable exception" ,
369359)
370- def test_ucx_unreachable ():
371- if ucp .get_ucx_version () > (1 , 12 , 0 ):
372- with pytest .raises (OSError , match = "Timed out trying to connect to" ):
373- Client ("ucx://255.255.255.255:12345" , timeout = 1 )
374- else :
375- with pytest .raises (ucp .exceptions .UCXError , match = "Destination is unreachable" ):
376- Client ("ucx://255.255.255.255:12345" , timeout = 1 )
360+ @gen_test ()
361+ async def test_ucx_unreachable (
362+ ucx_loop ,
363+ ):
364+ with pytest .raises (OSError , match = "Timed out trying to connect to" ):
365+ await Client ("ucx://255.255.255.255:12345" , timeout = 1 , asynchronous = True )
0 commit comments