|
5 | 5 |
|
6 | 6 | from contextvars import ContextVar |
7 | 7 | from unittest import mock |
8 | | -from test.test_asyncio import utils as test_utils |
9 | 8 |
|
10 | 9 |
|
11 | 10 | def tearDownModule(): |
12 | 11 | asyncio.set_event_loop_policy(None) |
13 | 12 |
|
14 | 13 |
|
15 | | -class ToThreadTests(test_utils.TestCase): |
16 | | - def setUp(self): |
17 | | - super().setUp() |
18 | | - self.loop = asyncio.new_event_loop() |
19 | | - asyncio.set_event_loop(self.loop) |
20 | | - |
21 | | - def tearDown(self): |
22 | | - self.loop.run_until_complete( |
23 | | - self.loop.shutdown_default_executor()) |
24 | | - self.loop.close() |
25 | | - asyncio.set_event_loop(None) |
26 | | - self.loop = None |
27 | | - super().tearDown() |
28 | | - |
29 | | - def test_to_thread(self): |
30 | | - async def main(): |
31 | | - return await asyncio.to_thread(sum, [40, 2]) |
32 | | - |
33 | | - result = self.loop.run_until_complete(main()) |
| 14 | +class ToThreadTests(unittest.IsolatedAsyncioTestCase): |
| 15 | + async def test_to_thread(self): |
| 16 | + result = await asyncio.to_thread(sum, [40, 2]) |
34 | 17 | self.assertEqual(result, 42) |
35 | 18 |
|
36 | | - def test_to_thread_exception(self): |
| 19 | + async def test_to_thread_exception(self): |
37 | 20 | def raise_runtime(): |
38 | 21 | raise RuntimeError("test") |
39 | 22 |
|
40 | | - async def main(): |
41 | | - await asyncio.to_thread(raise_runtime) |
42 | | - |
43 | 23 | with self.assertRaisesRegex(RuntimeError, "test"): |
44 | | - self.loop.run_until_complete(main()) |
| 24 | + await asyncio.to_thread(raise_runtime) |
45 | 25 |
|
46 | | - def test_to_thread_once(self): |
| 26 | + async def test_to_thread_once(self): |
47 | 27 | func = mock.Mock() |
48 | 28 |
|
49 | | - async def main(): |
50 | | - await asyncio.to_thread(func) |
51 | | - |
52 | | - self.loop.run_until_complete(main()) |
| 29 | + await asyncio.to_thread(func) |
53 | 30 | func.assert_called_once() |
54 | 31 |
|
55 | | - def test_to_thread_concurrent(self): |
| 32 | + async def test_to_thread_concurrent(self): |
56 | 33 | func = mock.Mock() |
57 | 34 |
|
58 | | - async def main(): |
59 | | - futs = [] |
60 | | - for _ in range(10): |
61 | | - fut = asyncio.to_thread(func) |
62 | | - futs.append(fut) |
63 | | - await asyncio.gather(*futs) |
| 35 | + futs = [] |
| 36 | + for _ in range(10): |
| 37 | + fut = asyncio.to_thread(func) |
| 38 | + futs.append(fut) |
| 39 | + await asyncio.gather(*futs) |
64 | 40 |
|
65 | | - self.loop.run_until_complete(main()) |
66 | 41 | self.assertEqual(func.call_count, 10) |
67 | 42 |
|
68 | | - def test_to_thread_args_kwargs(self): |
| 43 | + async def test_to_thread_args_kwargs(self): |
69 | 44 | # Unlike run_in_executor(), to_thread() should directly accept kwargs. |
70 | 45 | func = mock.Mock() |
71 | 46 |
|
72 | | - async def main(): |
73 | | - await asyncio.to_thread(func, 'test', something=True) |
| 47 | + await asyncio.to_thread(func, 'test', something=True) |
74 | 48 |
|
75 | | - self.loop.run_until_complete(main()) |
76 | 49 | func.assert_called_once_with('test', something=True) |
77 | 50 |
|
78 | | - def test_to_thread_contextvars(self): |
| 51 | + async def test_to_thread_contextvars(self): |
79 | 52 | test_ctx = ContextVar('test_ctx') |
80 | 53 |
|
81 | 54 | def get_ctx(): |
82 | 55 | return test_ctx.get() |
83 | 56 |
|
84 | | - async def main(): |
85 | | - test_ctx.set('parrot') |
86 | | - return await asyncio.to_thread(get_ctx) |
| 57 | + test_ctx.set('parrot') |
| 58 | + result = await asyncio.to_thread(get_ctx) |
87 | 59 |
|
88 | | - result = self.loop.run_until_complete(main()) |
89 | 60 | self.assertEqual(result, 'parrot') |
90 | 61 |
|
91 | 62 |
|
|
0 commit comments