Skip to content

Commit e8f5ef5

Browse files
committed
Merge branch 'main' into WSMR/update_who_has
2 parents cb286c3 + 5feb171 commit e8f5ef5

22 files changed

+555
-380
lines changed

distributed/cli/dask_scheduler.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import warnings
1010

1111
import click
12-
from tornado.ioloop import IOLoop
1312

1413
from distributed import Scheduler
1514
from distributed._signals import wait_for_signals
@@ -186,11 +185,9 @@ def del_pid_file():
186185
resource.setrlimit(resource.RLIMIT_NOFILE, (limit, hard))
187186

188187
async def run():
189-
loop = IOLoop.current()
190188
logger.info("-" * 47)
191189

192190
scheduler = Scheduler(
193-
loop=loop,
194191
security=sec,
195192
host=host,
196193
port=port,

distributed/cli/tests/test_dask_scheduler.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
assert_can_connect_from_everywhere_4_6,
2727
assert_can_connect_locally_4,
2828
popen,
29+
wait_for_log_line,
2930
)
3031

3132

@@ -66,12 +67,8 @@ def test_dashboard(loop):
6667
pytest.importorskip("bokeh")
6768

6869
with popen(["dask-scheduler"], flush_output=False) as proc:
69-
for line in proc.stdout:
70-
if b"dashboard at" in line:
71-
dashboard_port = int(line.decode().split(":")[-1].strip())
72-
break
73-
else:
74-
assert False # pragma: nocover
70+
line = wait_for_log_line(b"dashboard at", proc.stdout)
71+
dashboard_port = int(line.decode().split(":")[-1].strip())
7572

7673
with Client(f"127.0.0.1:{Scheduler.default_port}", loop=loop):
7774
pass
@@ -223,13 +220,9 @@ def test_dashboard_port_zero(loop):
223220
["dask-scheduler", "--dashboard-address", ":0"],
224221
flush_output=False,
225222
) as proc:
226-
for line in proc.stdout:
227-
if b"dashboard at" in line:
228-
dashboard_port = int(line.decode().split(":")[-1].strip())
229-
assert dashboard_port != 0
230-
break
231-
else:
232-
assert False # pragma: nocover
223+
line = wait_for_log_line(b"dashboard at", proc.stdout)
224+
dashboard_port = int(line.decode().split(":")[-1].strip())
225+
assert dashboard_port != 0
233226

234227

235228
PRELOAD_TEXT = """

distributed/cli/tests/test_dask_ssh.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from distributed import Client
55
from distributed.cli.dask_ssh import main
66
from distributed.compatibility import MACOS, WINDOWS
7-
from distributed.utils_test import popen
7+
from distributed.utils_test import popen, wait_for_log_line
88

99
pytest.importorskip("paramiko")
1010
pytestmark = [
@@ -30,16 +30,12 @@ def test_ssh_cli_nprocs_renamed_to_nworkers(loop):
3030
# This interrupt is necessary for the cluster to place output into the stdout
3131
# and stderr pipes
3232
proc.send_signal(2)
33-
assert any(
34-
b"renamed to --nworkers" in proc.stdout.readline() for _ in range(15)
35-
)
33+
wait_for_log_line(b"renamed to --nworkers", proc.stdout, max_lines=15)
3634

3735

3836
def test_ssh_cli_nworkers_with_nprocs_is_an_error():
3937
with popen(
4038
["dask-ssh", "localhost", "--nprocs=2", "--nworkers=2"],
4139
flush_output=False,
4240
) as proc:
43-
assert any(
44-
b"Both --nprocs and --nworkers" in proc.stdout.readline() for _ in range(15)
45-
)
41+
wait_for_log_line(b"Both --nprocs and --nworkers", proc.stdout, max_lines=15)

distributed/cli/tests/test_dask_worker.py

Lines changed: 11 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from distributed.deploy.utils import nprocesses_nthreads
2020
from distributed.metrics import time
2121
from distributed.utils import open_port
22-
from distributed.utils_test import gen_cluster, popen, requires_ipv6
22+
from distributed.utils_test import gen_cluster, popen, requires_ipv6, wait_for_log_line
2323

2424

2525
@pytest.mark.parametrize(
@@ -246,9 +246,7 @@ async def test_nanny_worker_port_range_too_many_workers_raises(s):
246246
],
247247
flush_output=False,
248248
) as worker:
249-
assert any(
250-
b"Not enough ports in range" in worker.stdout.readline() for _ in range(100)
251-
)
249+
wait_for_log_line(b"Not enough ports in range", worker.stdout, max_lines=100)
252250

253251

254252
@pytest.mark.slow
@@ -282,26 +280,14 @@ async def test_reconnect_deprecated(c, s):
282280
["dask-worker", s.address, "--reconnect"],
283281
flush_output=False,
284282
) as worker:
285-
for _ in range(10):
286-
line = worker.stdout.readline()
287-
print(line)
288-
if b"`--reconnect` option has been removed" in line:
289-
break
290-
else:
291-
raise AssertionError("Message not printed, see stdout")
283+
wait_for_log_line(b"`--reconnect` option has been removed", worker.stdout)
292284
assert worker.wait() == 1
293285

294286
with popen(
295287
["dask-worker", s.address, "--no-reconnect"],
296288
flush_output=False,
297289
) as worker:
298-
for _ in range(10):
299-
line = worker.stdout.readline()
300-
print(line)
301-
if b"flag is deprecated, and will be removed" in line:
302-
break
303-
else:
304-
raise AssertionError("Message not printed, see stdout")
290+
wait_for_log_line(b"flag is deprecated, and will be removed", worker.stdout)
305291
await c.wait_for_workers(1)
306292
await c.shutdown()
307293

@@ -377,9 +363,7 @@ async def test_nworkers_requires_nanny(s):
377363
["dask-worker", s.address, "--nworkers=2", "--no-nanny"],
378364
flush_output=False,
379365
) as worker:
380-
assert any(
381-
b"Failed to launch worker" in worker.stdout.readline() for _ in range(15)
382-
)
366+
wait_for_log_line(b"Failed to launch worker", worker.stdout, max_lines=15)
383367

384368

385369
@pytest.mark.slow
@@ -419,9 +403,7 @@ async def test_worker_cli_nprocs_renamed_to_nworkers(c, s):
419403
flush_output=False,
420404
) as worker:
421405
await c.wait_for_workers(2)
422-
assert any(
423-
b"renamed to --nworkers" in worker.stdout.readline() for _ in range(15)
424-
)
406+
wait_for_log_line(b"renamed to --nworkers", worker.stdout, max_lines=15)
425407

426408

427409
@gen_cluster(nthreads=[])
@@ -430,10 +412,7 @@ async def test_worker_cli_nworkers_with_nprocs_is_an_error(s):
430412
["dask-worker", s.address, "--nprocs=2", "--nworkers=2"],
431413
flush_output=False,
432414
) as worker:
433-
assert any(
434-
b"Both --nprocs and --nworkers" in worker.stdout.readline()
435-
for _ in range(15)
436-
)
415+
wait_for_log_line(b"Both --nprocs and --nworkers", worker.stdout, max_lines=15)
437416

438417

439418
@pytest.mark.slow
@@ -733,12 +712,10 @@ def test_error_during_startup(monkeypatch, nanny):
733712
) as scheduler:
734713
start = time()
735714
# Wait for the scheduler to be up
736-
while line := scheduler.stdout.readline():
737-
if b"Scheduler at" in line:
738-
break
739-
# Ensure this is not killed by pytest-timeout
740-
if time() - start > 5:
741-
raise TimeoutError("Scheduler failed to start in time.")
715+
wait_for_log_line(b"Scheduler at", scheduler.stdout)
716+
# Ensure this is not killed by pytest-timeout
717+
if time() - start > 5:
718+
raise TimeoutError("Scheduler failed to start in time.")
742719

743720
with popen(
744721
[

distributed/comm/tests/test_comms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,7 @@ async def client_communicate(key, delay=0):
563563

564564
@pytest.mark.gpu
565565
@gen_test()
566-
async def test_ucx_client_server():
566+
async def test_ucx_client_server(ucx_loop):
567567
pytest.importorskip("distributed.comm.ucx")
568568
ucp = pytest.importorskip("ucp")
569569

distributed/comm/tests/test_ucx.py

Lines changed: 30 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -22,25 +22,7 @@
2222
HOST = "127.0.0.1"
2323

2424

25-
def handle_exception(loop, context):
26-
msg = context.get("exception", context["message"])
27-
print(msg)
28-
29-
30-
# Let's make sure that UCX gets time to cancel
31-
# progress tasks before closing the event loop.
32-
@pytest.fixture()
33-
def event_loop(scope="function"):
34-
loop = asyncio.new_event_loop()
35-
loop.set_exception_handler(handle_exception)
36-
ucp.reset()
37-
yield loop
38-
ucp.reset()
39-
loop.run_until_complete(asyncio.sleep(0))
40-
loop.close()
41-
42-
43-
def test_registered():
25+
def test_registered(ucx_loop):
4426
assert "ucx" in backends
4527
backend = get_backend("ucx")
4628
assert isinstance(backend, ucx.UCXBackend)
@@ -62,7 +44,7 @@ async def handle_comm(comm):
6244

6345

6446
@gen_test()
65-
async def test_ping_pong():
47+
async def test_ping_pong(ucx_loop):
6648
com, serv_com = await get_comm_pair()
6749
msg = {"op": "ping"}
6850
await com.write(msg)
@@ -80,7 +62,7 @@ async def test_ping_pong():
8062

8163

8264
@gen_test()
83-
async def test_comm_objs():
65+
async def test_comm_objs(ucx_loop):
8466
comm, serv_comm = await get_comm_pair()
8567

8668
scheme, loc = parse_address(comm.peer_address)
@@ -93,7 +75,7 @@ async def test_comm_objs():
9375

9476

9577
@gen_test()
96-
async def test_ucx_specific():
78+
async def test_ucx_specific(ucx_loop):
9779
"""
9880
Test concrete UCX API.
9981
"""
@@ -147,7 +129,7 @@ async def client_communicate(key, delay=0):
147129

148130

149131
@gen_test()
150-
async def test_ping_pong_data():
132+
async def test_ping_pong_data(ucx_loop):
151133
np = pytest.importorskip("numpy")
152134

153135
data = np.ones((10, 10))
@@ -170,7 +152,7 @@ async def test_ping_pong_data():
170152

171153

172154
@gen_test()
173-
async def test_ucx_deserialize():
155+
async def test_ucx_deserialize(ucx_loop):
174156
# Note we see this error on some systems with this test:
175157
# `socket.gaierror: [Errno -5] No address associated with hostname`
176158
# This may be due to a system configuration issue.
@@ -196,7 +178,7 @@ async def test_ucx_deserialize():
196178
],
197179
)
198180
@gen_test()
199-
async def test_ping_pong_cudf(g):
181+
async def test_ping_pong_cudf(ucx_loop, g):
200182
# if this test appears after cupy an import error arises
201183
# *** ImportError: /usr/lib/x86_64-linux-gnu/libstdc++.so.6: version `CXXABI_1.3.11'
202184
# not found (required by python3.7/site-packages/pyarrow/../../../libarrow.so.12)
@@ -221,7 +203,7 @@ async def test_ping_pong_cudf(g):
221203

222204
@pytest.mark.parametrize("shape", [(100,), (10, 10), (4947,)])
223205
@gen_test()
224-
async def test_ping_pong_cupy(shape):
206+
async def test_ping_pong_cupy(ucx_loop, shape):
225207
cupy = pytest.importorskip("cupy")
226208
com, serv_com = await get_comm_pair()
227209

@@ -240,7 +222,7 @@ async def test_ping_pong_cupy(shape):
240222
@pytest.mark.slow
241223
@pytest.mark.parametrize("n", [int(1e9), int(2.5e9)])
242224
@gen_test()
243-
async def test_large_cupy(n, cleanup):
225+
async def test_large_cupy(ucx_loop, n, cleanup):
244226
cupy = pytest.importorskip("cupy")
245227
com, serv_com = await get_comm_pair()
246228

@@ -257,7 +239,7 @@ async def test_large_cupy(n, cleanup):
257239

258240

259241
@gen_test()
260-
async def test_ping_pong_numba():
242+
async def test_ping_pong_numba(ucx_loop):
261243
np = pytest.importorskip("numpy")
262244
numba = pytest.importorskip("numba")
263245
import numba.cuda
@@ -276,7 +258,7 @@ async def test_ping_pong_numba():
276258

277259
@pytest.mark.parametrize("processes", [True, False])
278260
@gen_test()
279-
async def test_ucx_localcluster(processes, cleanup):
261+
async def test_ucx_localcluster(ucx_loop, processes, cleanup):
280262
async with LocalCluster(
281263
protocol="ucx",
282264
host=HOST,
@@ -297,7 +279,9 @@ async def test_ucx_localcluster(processes, cleanup):
297279

298280
@pytest.mark.slow
299281
@gen_test(timeout=60)
300-
async def test_stress():
282+
async def test_stress(
283+
ucx_loop,
284+
):
301285
da = pytest.importorskip("dask.array")
302286

303287
chunksize = "10 MB"
@@ -322,15 +306,19 @@ async def test_stress():
322306

323307

324308
@gen_test()
325-
async def test_simple():
309+
async def test_simple(
310+
ucx_loop,
311+
):
326312
async with LocalCluster(protocol="ucx", asynchronous=True) as cluster:
327313
async with Client(cluster, asynchronous=True) as client:
328314
assert cluster.scheduler_address.startswith("ucx://")
329315
assert await client.submit(lambda x: x + 1, 10) == 11
330316

331317

332318
@gen_test()
333-
async def test_cuda_context():
319+
async def test_cuda_context(
320+
ucx_loop,
321+
):
334322
with dask.config.set({"distributed.comm.ucx.create-cuda-context": True}):
335323
async with LocalCluster(
336324
protocol="ucx", n_workers=1, asynchronous=True
@@ -344,7 +332,9 @@ async def test_cuda_context():
344332

345333

346334
@gen_test()
347-
async def test_transpose():
335+
async def test_transpose(
336+
ucx_loop,
337+
):
348338
da = pytest.importorskip("dask.array")
349339

350340
async with LocalCluster(protocol="ucx", asynchronous=True) as cluster:
@@ -358,7 +348,7 @@ async def test_transpose():
358348

359349
@pytest.mark.parametrize("port", [0, 1234])
360350
@gen_test()
361-
async def test_ucx_protocol(cleanup, port):
351+
async def test_ucx_protocol(ucx_loop, cleanup, port):
362352
async with Scheduler(protocol="ucx", port=port, dashboard_address=":0") as s:
363353
assert s.address.startswith("ucx://")
364354

@@ -367,10 +357,9 @@ async def test_ucx_protocol(cleanup, port):
367357
not hasattr(ucp.exceptions, "UCXUnreachable"),
368358
reason="Requires UCX-Py support for UCXUnreachable exception",
369359
)
370-
def test_ucx_unreachable():
371-
if ucp.get_ucx_version() > (1, 12, 0):
372-
with pytest.raises(OSError, match="Timed out trying to connect to"):
373-
Client("ucx://255.255.255.255:12345", timeout=1)
374-
else:
375-
with pytest.raises(ucp.exceptions.UCXError, match="Destination is unreachable"):
376-
Client("ucx://255.255.255.255:12345", timeout=1)
360+
@gen_test()
361+
async def test_ucx_unreachable(
362+
ucx_loop,
363+
):
364+
with pytest.raises(OSError, match="Timed out trying to connect to"):
365+
await Client("ucx://255.255.255.255:12345", timeout=1, asynchronous=True)

0 commit comments

Comments (0)