|
5 | 5 |
|
6 | 6 | from contextvars import ContextVar
|
7 | 7 | from unittest import mock
|
8 |
| -from test.test_asyncio import utils as test_utils |
9 | 8 |
|
10 | 9 |
|
11 | 10 | def tearDownModule():
|
12 | 11 | asyncio.set_event_loop_policy(None)
|
13 | 12 |
|
14 | 13 |
|
15 |
| -class ToThreadTests(test_utils.TestCase): |
16 |
| - def setUp(self): |
17 |
| - super().setUp() |
18 |
| - self.loop = asyncio.new_event_loop() |
19 |
| - asyncio.set_event_loop(self.loop) |
20 |
| - |
21 |
| - def tearDown(self): |
22 |
| - self.loop.run_until_complete( |
23 |
| - self.loop.shutdown_default_executor()) |
24 |
| - self.loop.close() |
25 |
| - asyncio.set_event_loop(None) |
26 |
| - self.loop = None |
27 |
| - super().tearDown() |
28 |
| - |
29 |
| - def test_to_thread(self): |
30 |
| - async def main(): |
31 |
| - return await asyncio.to_thread(sum, [40, 2]) |
32 |
| - |
33 |
| - result = self.loop.run_until_complete(main()) |
| 14 | +class ToThreadTests(unittest.IsolatedAsyncioTestCase): |
| 15 | + async def test_to_thread(self): |
| 16 | + result = await asyncio.to_thread(sum, [40, 2]) |
34 | 17 | self.assertEqual(result, 42)
|
35 | 18 |
|
36 |
| - def test_to_thread_exception(self): |
| 19 | + async def test_to_thread_exception(self): |
37 | 20 | def raise_runtime():
|
38 | 21 | raise RuntimeError("test")
|
39 | 22 |
|
40 |
| - async def main(): |
41 |
| - await asyncio.to_thread(raise_runtime) |
42 |
| - |
43 | 23 | with self.assertRaisesRegex(RuntimeError, "test"):
|
44 |
| - self.loop.run_until_complete(main()) |
| 24 | + await asyncio.to_thread(raise_runtime) |
45 | 25 |
|
46 |
| - def test_to_thread_once(self): |
| 26 | + async def test_to_thread_once(self): |
47 | 27 | func = mock.Mock()
|
48 | 28 |
|
49 |
| - async def main(): |
50 |
| - await asyncio.to_thread(func) |
51 |
| - |
52 |
| - self.loop.run_until_complete(main()) |
| 29 | + await asyncio.to_thread(func) |
53 | 30 | func.assert_called_once()
|
54 | 31 |
|
55 |
| - def test_to_thread_concurrent(self): |
| 32 | + async def test_to_thread_concurrent(self): |
56 | 33 | func = mock.Mock()
|
57 | 34 |
|
58 |
| - async def main(): |
59 |
| - futs = [] |
60 |
| - for _ in range(10): |
61 |
| - fut = asyncio.to_thread(func) |
62 |
| - futs.append(fut) |
63 |
| - await asyncio.gather(*futs) |
| 35 | + futs = [] |
| 36 | + for _ in range(10): |
| 37 | + fut = asyncio.to_thread(func) |
| 38 | + futs.append(fut) |
| 39 | + await asyncio.gather(*futs) |
64 | 40 |
|
65 |
| - self.loop.run_until_complete(main()) |
66 | 41 | self.assertEqual(func.call_count, 10)
|
67 | 42 |
|
68 |
| - def test_to_thread_args_kwargs(self): |
| 43 | + async def test_to_thread_args_kwargs(self): |
69 | 44 | # Unlike run_in_executor(), to_thread() should directly accept kwargs.
|
70 | 45 | func = mock.Mock()
|
71 | 46 |
|
72 |
| - async def main(): |
73 |
| - await asyncio.to_thread(func, 'test', something=True) |
| 47 | + await asyncio.to_thread(func, 'test', something=True) |
74 | 48 |
|
75 |
| - self.loop.run_until_complete(main()) |
76 | 49 | func.assert_called_once_with('test', something=True)
|
77 | 50 |
|
78 |
| - def test_to_thread_contextvars(self): |
| 51 | + async def test_to_thread_contextvars(self): |
79 | 52 | test_ctx = ContextVar('test_ctx')
|
80 | 53 |
|
81 | 54 | def get_ctx():
|
82 | 55 | return test_ctx.get()
|
83 | 56 |
|
84 |
| - async def main(): |
85 |
| - test_ctx.set('parrot') |
86 |
| - return await asyncio.to_thread(get_ctx) |
| 57 | + test_ctx.set('parrot') |
| 58 | + result = await asyncio.to_thread(get_ctx) |
87 | 59 |
|
88 |
| - result = self.loop.run_until_complete(main()) |
89 | 60 | self.assertEqual(result, 'parrot')
|
90 | 61 |
|
91 | 62 |
|
|
0 commit comments