diff --git a/asyncio/__init__.py b/asyncio/__init__.py index 011466b3..30cfbdce 100644 --- a/asyncio/__init__.py +++ b/asyncio/__init__.py @@ -24,6 +24,7 @@ from .futures import * from .locks import * from .protocols import * +from .runners import * from .queues import * from .streams import * from .subprocess import * @@ -36,6 +37,7 @@ futures.__all__ + locks.__all__ + protocols.__all__ + + runners.__all__ + queues.__all__ + streams.__all__ + subprocess.__all__ + diff --git a/asyncio/runners.py b/asyncio/runners.py new file mode 100644 index 00000000..fb70ac57 --- /dev/null +++ b/asyncio/runners.py @@ -0,0 +1,148 @@ +"""asyncio.run() and asyncio.run_forever() functions.""" + +__all__ = ['run', 'run_forever'] + +import inspect +import threading + +from . import coroutines +from . import events + + +def _cleanup(loop): + try: + # `shutdown_asyncgens` was added in Python 3.6; not all + # event loops might support it. + shutdown_asyncgens = loop.shutdown_asyncgens + except AttributeError: + pass + else: + loop.run_until_complete(shutdown_asyncgens()) + finally: + events.set_event_loop(None) + loop.close() + + +def run(main, *, debug=False): + """Run a coroutine. + + This function runs the passed coroutine, taking care of + managing the asyncio event loop and finalizing asynchronous + generators. + + This function must be called from the main thread, and it + cannot be called when another asyncio event loop is running. + + If debug is True, the event loop will be run in debug mode. + + This function should be used as a main entry point for + asyncio programs, and should not be used to call asynchronous + APIs. + + Example: + + async def main(): + await asyncio.sleep(1) + print('hello') + + asyncio.run(main()) + """ + if events._get_running_loop() is not None: + raise RuntimeError( + "asyncio.run() cannot be called from a running event loop") + if not isinstance(threading.current_thread(), threading._MainThread): + raise RuntimeError( + "asyncio.run() must be called from the main thread") + if not coroutines.iscoroutine(main): + raise ValueError("a coroutine was expected, got {!r}".format(main)) + + loop = events.new_event_loop() + try: + events.set_event_loop(loop) + + if debug: + loop.set_debug(True) + + return loop.run_until_complete(main) + finally: + _cleanup(loop) + + +def run_forever(main, *, debug=False): + """Run asyncio loop. + + main must be an asynchronous generator with one yield, separating + program initialization from cleanup logic. + + If debug is True, the event loop will be run in debug mode. + + This function should be used as a main entry point for + asyncio programs, and should not be used to call asynchronous + APIs. + + Example: + + async def main(): + server = await asyncio.start_server(...) + try: + yield # <- Let event loop run forever. + except KeyboardInterrupt: + print('^C received; exiting.') + finally: + server.close() + await server.wait_closed() + + asyncio.run_forever(main()) + """ + if not hasattr(inspect, 'isasyncgen'): + raise NotImplementedError + + if events._get_running_loop() is not None: + raise RuntimeError( + "asyncio.run_forever() cannot be called from a running event loop") + if not isinstance(threading.current_thread(), threading._MainThread): + raise RuntimeError( + "asyncio.run_forever() must be called from the main thread") + if not inspect.isasyncgen(main): + raise ValueError( + "an asynchronous generator was expected, got {!r}".format(main)) + + one_yield_msg = ("asyncio.run_forever() supports only " + "asynchronous generators with one empty yield") + loop = events.new_event_loop() + try: + events.set_event_loop(loop) + if debug: + loop.set_debug(True) + + ret = None + try: + ret = loop.run_until_complete(main.asend(None)) + except StopAsyncIteration as ex: + return + if ret is not None: + raise RuntimeError(one_yield_msg) + + yielded_twice = False + try: + loop.run_forever() + except BaseException as ex: + try: + loop.run_until_complete(main.athrow(ex)) + except StopAsyncIteration as ex: + pass + else: + yielded_twice = True + else: + try: + loop.run_until_complete(main.asend(None)) + except StopAsyncIteration as ex: + pass + else: + yielded_twice = True + + if yielded_twice: + raise RuntimeError(one_yield_msg) + + finally: + _cleanup(loop) diff --git a/runtests.py b/runtests.py index c4074624..8fa2db93 100644 --- a/runtests.py +++ b/runtests.py @@ -112,6 +112,10 @@ def list_dir(prefix, dir): print("Skipping '{0}': need at least Python 3.5".format(modname), file=sys.stderr) continue + if modname == 'test_runner' and (sys.version_info < (3, 6)): + print("Skipping '{0}': need at least Python 3.6".format(modname), + file=sys.stderr) + continue try: loader = importlib.machinery.SourceFileLoader(modname, sourcefile) mods.append((loader.load_module(), sourcefile)) diff --git a/tests/test_runner.py b/tests/test_runner.py new file mode 100644 index 00000000..d439353f --- /dev/null +++ b/tests/test_runner.py @@ -0,0 +1,223 @@ +"""Tests asyncio.run() and asyncio.run_forever().""" + +import asyncio +import unittest +import sys + +from unittest import mock + + +class TestPolicy(asyncio.AbstractEventLoopPolicy): + + def __init__(self, loop_factory): + self.loop_factory = loop_factory + self.loop = None + + def get_event_loop(self): + # shouldn't ever be called by asyncio.run() + # or asyncio.run_forever() + raise RuntimeError + + def new_event_loop(self): + return self.loop_factory() + + def set_event_loop(self, loop): + if loop is not None: + # we want to check if the loop is closed + # in BaseTest.tearDown + self.loop = loop + + +class BaseTest(unittest.TestCase): + + def new_loop(self): + loop = asyncio.BaseEventLoop() + loop._process_events = mock.Mock() + loop._selector = mock.Mock() + loop._selector.select.return_value = () + loop.shutdown_ag_run = False + + async def shutdown_asyncgens(): + loop.shutdown_ag_run = True + loop.shutdown_asyncgens = shutdown_asyncgens + + return loop + + def setUp(self): + super().setUp() + + policy = TestPolicy(self.new_loop) + asyncio.set_event_loop_policy(policy) + + def tearDown(self): + policy = asyncio.get_event_loop_policy() + if policy.loop is not None: + self.assertTrue(policy.loop.is_closed()) + self.assertTrue(policy.loop.shutdown_ag_run) + + asyncio.set_event_loop_policy(None) + super().tearDown() + + +class RunTests(BaseTest): + + def test_asyncio_run_return(self): + async def main(): + await asyncio.sleep(0) + return 42 + + self.assertEqual(asyncio.run(main()), 42) + + def test_asyncio_run_raises(self): + async def main(): + await asyncio.sleep(0) + raise ValueError('spam') + + with self.assertRaisesRegex(ValueError, 'spam'): + asyncio.run(main()) + + def test_asyncio_run_only_coro(self): + for o in {1, lambda: None}: + with self.subTest(obj=o), \ + self.assertRaisesRegex(ValueError, + 'a coroutine was expected'): + asyncio.run(o) + + def test_asyncio_run_debug(self): + async def main(expected): + loop = asyncio.get_event_loop() + self.assertIs(loop.get_debug(), expected) + + asyncio.run(main(False)) + asyncio.run(main(True), debug=True) + + def test_asyncio_run_from_running_loop(self): + async def main(): + asyncio.run(main()) + + with self.assertRaisesRegex(RuntimeError, + 'cannot be called from a running'): + asyncio.run(main()) + + +class RunForeverTests(BaseTest): + + def stop_soon(self, *, exc=None): + loop = asyncio.get_event_loop() + + if exc: + def throw(): + raise exc + loop.call_later(0.01, throw) + else: + loop.call_later(0.01, loop.stop) + + def test_asyncio_run_forever_return(self): + async def main(): + if 0: + yield + return + + self.assertIsNone(asyncio.run_forever(main())) + + def test_asyncio_run_forever_non_none_yield(self): + async def main(): + yield 1 + + with self.assertRaisesRegex(RuntimeError, 'one empty yield'): + self.assertIsNone(asyncio.run_forever(main())) + + def test_asyncio_run_forever_try_finally(self): + DONE = 0 + + async def main(): + nonlocal DONE + self.stop_soon() + try: + yield + finally: + DONE += 1 + + asyncio.run_forever(main()) + self.assertEqual(DONE, 1) + + def test_asyncio_run_forever_raises_before_yield(self): + async def main(): + await asyncio.sleep(0) + raise ValueError('spam') + yield + + with self.assertRaisesRegex(ValueError, 'spam'): + asyncio.run_forever(main()) + + def test_asyncio_run_forever_raises_after_yield(self): + async def main(): + self.stop_soon() + yield + raise ValueError('spam') + + with self.assertRaisesRegex(ValueError, 'spam'): + asyncio.run_forever(main()) + + def test_asyncio_run_forever_two_yields(self): + async def main(): + self.stop_soon() + yield + yield + raise ValueError('spam') + + with self.assertRaisesRegex(RuntimeError, 'one empty yield'): + asyncio.run_forever(main()) + + def test_asyncio_run_forever_only_ag(self): + async def coro(): + pass + + for o in {1, lambda: None, coro()}: + with self.subTest(obj=o), \ + self.assertRaisesRegex(ValueError, + 'an asynchronous.*was expected'): + asyncio.run_forever(o) + + def test_asyncio_run_forever_debug(self): + async def main(expected): + loop = asyncio.get_event_loop() + self.assertIs(loop.get_debug(), expected) + if 0: + yield + + asyncio.run_forever(main(False)) + asyncio.run_forever(main(True), debug=True) + + def test_asyncio_run_forever_from_running_loop(self): + async def main(): + asyncio.run_forever(main()) + if 0: + yield + + with self.assertRaisesRegex(RuntimeError, + 'cannot be called from a running'): + asyncio.run_forever(main()) + + def test_asyncio_run_forever_base_exception(self): + vi = sys.version_info + if vi[:2] != (3, 6) or vi.releaselevel == 'beta' and vi.serial < 4: + # See http://bugs.python.org/issue28721 for details. + raise unittest.SkipTest( + 'this test requires Python 3.6b4 or greater') + + DONE = 0 + + class MyExc(BaseException): + pass + + async def main(): + nonlocal DONE + self.stop_soon(exc=MyExc) + try: + yield + except MyExc: + DONE += 1 + + asyncio.run_forever(main()) + self.assertEqual(DONE, 1)