From d0d0686520f7ff10e8b27c952c2617122fee5972 Mon Sep 17 00:00:00 2001 From: Victoria Hall Date: Mon, 1 Jul 2024 15:09:41 -0500 Subject: [PATCH 1/8] added optional context param for tasks --- azure_functions_worker/dispatcher.py | 7 ++-- .../create_task_with_context/function.json | 15 ++++++++ .../create_task_with_context/main.py | 22 ++++++++++++ .../create_task_without_context/function.json | 15 ++++++++ .../create_task_without_context/main.py | 17 +++++++++ .../http_functions_stein/function_app.py | 31 ++++++++++++++++ tests/unittests/test_dispatcher.py | 35 ++++++++++++++++++- tests/unittests/test_http_functions.py | 12 +++++++ 8 files changed, 150 insertions(+), 4 deletions(-) create mode 100644 tests/unittests/http_functions/create_task_with_context/function.json create mode 100644 tests/unittests/http_functions/create_task_with_context/main.py create mode 100644 tests/unittests/http_functions/create_task_without_context/function.json create mode 100644 tests/unittests/http_functions/create_task_without_context/main.py diff --git a/azure_functions_worker/dispatcher.py b/azure_functions_worker/dispatcher.py index 807985cc6..f96f522cd 100644 --- a/azure_functions_worker/dispatcher.py +++ b/azure_functions_worker/dispatcher.py @@ -155,7 +155,8 @@ async def dispatch_forever(self): # sourcery skip: swap-if-expression worker_id=self.worker_id))) self._loop.set_task_factory( - lambda loop, coro: ContextEnabledTask(coro, loop=loop)) + lambda loop, coro, context=None: ContextEnabledTask( + coro, loop=loop, context=context)) # Detach console logging before enabling GRPC channel logging logger.info('Detaching console logging.') @@ -1012,8 +1013,8 @@ def emit(self, record: LogRecord) -> None: class ContextEnabledTask(asyncio.Task): AZURE_INVOCATION_ID = '__azure_function_invocation_id__' - def __init__(self, coro, loop): - super().__init__(coro, loop=loop) + def __init__(self, coro, loop, context=None): + super().__init__(coro, loop=loop, context=context) current_task = asyncio.current_task(loop) if current_task is not None: diff --git a/tests/unittests/http_functions/create_task_with_context/function.json b/tests/unittests/http_functions/create_task_with_context/function.json new file mode 100644 index 000000000..5d4d8285f --- /dev/null +++ b/tests/unittests/http_functions/create_task_with_context/function.json @@ -0,0 +1,15 @@ +{ + "scriptFile": "main.py", + "bindings": [ + { + "type": "httpTrigger", + "direction": "in", + "name": "req" + }, + { + "type": "http", + "direction": "out", + "name": "$return" + } + ] +} diff --git a/tests/unittests/http_functions/create_task_with_context/main.py b/tests/unittests/http_functions/create_task_with_context/main.py new file mode 100644 index 000000000..6a5ad75d9 --- /dev/null +++ b/tests/unittests/http_functions/create_task_with_context/main.py @@ -0,0 +1,22 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +import asyncio +import contextvars + +import azure.functions + +num = contextvars.ContextVar('num') +num.set(5) +ctx = contextvars.copy_context() + + +async def count(name: str): + for i in range(ctx[num]): + await asyncio.sleep(0.5) + return f"Finished {name} in {ctx[num]}" + + +async def main(req: azure.functions.HttpRequest): + count_task = asyncio.create_task(count("Hello World"), context=ctx) + count_val = await count_task + return f'{count_val}' diff --git a/tests/unittests/http_functions/create_task_without_context/function.json b/tests/unittests/http_functions/create_task_without_context/function.json new file mode 100644 index 000000000..5d4d8285f --- /dev/null +++ b/tests/unittests/http_functions/create_task_without_context/function.json @@ -0,0 +1,15 @@ +{ + "scriptFile": "main.py", + "bindings": [ + { + "type": "httpTrigger", + "direction": "in", + "name": "req" + }, + { + "type": "http", + "direction": "out", + "name": "$return" + } + ] +} diff --git a/tests/unittests/http_functions/create_task_without_context/main.py b/tests/unittests/http_functions/create_task_without_context/main.py new file mode 100644 index 000000000..122247fb3 --- /dev/null +++ b/tests/unittests/http_functions/create_task_without_context/main.py @@ -0,0 +1,17 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +import asyncio + +import azure.functions + + +async def count(name: str, num: int): + for i in range(num): + await asyncio.sleep(0.5) + return f"Finished {name} in {num}" + + +async def main(req: azure.functions.HttpRequest): + count_task = asyncio.create_task(count("Hello World", 5)) + count_val = await count_task + return f'{count_val}' diff --git a/tests/unittests/http_functions/http_functions_stein/function_app.py b/tests/unittests/http_functions/http_functions_stein/function_app.py index 4dd703034..ff7e0d4b4 100644 --- a/tests/unittests/http_functions/http_functions_stein/function_app.py +++ b/tests/unittests/http_functions/http_functions_stein/function_app.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. import asyncio +import contextvars import hashlib import json import logging @@ -14,6 +15,22 @@ logger = logging.getLogger("my-function") +num = contextvars.ContextVar('num') +num.set(5) +ctx = contextvars.copy_context() + + +async def count_with_context(name: str): + for i in range(ctx[num]): + await asyncio.sleep(0.5) + return f"Finished {name} in {ctx[num]}" + + +async def count_without_context(name: str, number: int): + for i in range(number): + await asyncio.sleep(0.5) + return f"Finished {name} in {number}" + @app.route(route="return_str") def return_str(req: func.HttpRequest) -> str: @@ -404,3 +421,17 @@ def set_cookie_resp_header_empty( resp.headers.add("Set-Cookie", '') return resp + + +@app.route('create_task_with_context') +async def create_task_with_context(req: func.HttpRequest): + count_task = asyncio.create_task(count_with_context("Hello World"), context=ctx) + count_val = await count_task + return f'{count_val}' + + +@app.route('create_task_without_context') +async def create_task_without_context(req: func.HttpRequest): + count_task = asyncio.create_task(count_without_context("Hello World", 5)) + count_val = await count_task + return f'{count_val}' diff --git a/tests/unittests/test_dispatcher.py b/tests/unittests/test_dispatcher.py index 84d72e95b..e3990c8ed 100644 --- a/tests/unittests/test_dispatcher.py +++ b/tests/unittests/test_dispatcher.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import asyncio import collections as col +import contextvars import os import sys import unittest @@ -16,7 +17,7 @@ PYTHON_ENABLE_INIT_INDEXING, METADATA_PROPERTIES_WORKER_INDEXED, PYTHON_ENABLE_DEBUG_LOGGING) -from azure_functions_worker.dispatcher import Dispatcher +from azure_functions_worker.dispatcher import Dispatcher, ContextEnabledTask from azure_functions_worker.version import VERSION from tests.utils import testutils from tests.utils.testutils import UNIT_TESTS_ROOT @@ -980,3 +981,35 @@ def test_dispatcher_indexing_in_load_request_with_exception( self.assertEqual( response.function_load_response.result.exception.message, "Exception: Mocked Exception") + + +class TestContextEnabledTask(unittest.TestCase): + def setUp(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + + def test_init_with_context(self): + num = contextvars.ContextVar('num') + num.set(5) + ctx = contextvars.copy_context() + exception_raised = False + try: + self.loop.set_task_factory( + lambda loop, coro, context=None: ContextEnabledTask( + coro, loop=loop, context=ctx)) + except TypeError: + exception_raised = True + self.assertFalse(exception_raised) + + async def test_init_without_context(self): + exception_raised = False + try: + self.loop.set_task_factory( + lambda loop, coro, context=None: ContextEnabledTask( + coro, loop=loop)) + except TypeError: + exception_raised = True + self.assertFalse(exception_raised) diff --git a/tests/unittests/test_http_functions.py b/tests/unittests/test_http_functions.py index 109694d71..35550ae07 100644 --- a/tests/unittests/test_http_functions.py +++ b/tests/unittests/test_http_functions.py @@ -442,6 +442,18 @@ def check_log_hijack_current_event_loop(self, host_out: typing.List[str]): # System logs should not exist in host_out self.assertNotIn('parallelly_log_system at disguised_logger', host_out) + def test_create_task_with_context(self): + r = self.webhost.request('GET', 'create_task_with_context') + + self.assertEqual(r.status_code, 200) + self.assertEqual(r.text, 'Finished Hello World in 5') + + def test_create_task_without_context(self): + r = self.webhost.request('GET', 'create_task_without_context') + + self.assertEqual(r.status_code, 200) + self.assertEqual(r.text, 'Finished Hello World in 5') + class TestHttpFunctionsStein(TestHttpFunctions): From 6dfdabdb713ab64f1f9d0d99ae1c69d2398dc9e9 Mon Sep 17 00:00:00 2001 From: Victoria Hall Date: Mon, 1 Jul 2024 16:54:10 -0500 Subject: [PATCH 2/8] checks for 3.11 or lower --- azure_functions_worker/dispatcher.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/azure_functions_worker/dispatcher.py b/azure_functions_worker/dispatcher.py index f96f522cd..f375ac6fc 100644 --- a/azure_functions_worker/dispatcher.py +++ b/azure_functions_worker/dispatcher.py @@ -154,9 +154,17 @@ async def dispatch_forever(self): # sourcery skip: swap-if-expression start_stream=protos.StartStream( worker_id=self.worker_id))) - self._loop.set_task_factory( - lambda loop, coro, context=None: ContextEnabledTask( - coro, loop=loop, context=context)) + # In Python 3.11+, constructing a task has an optional context + # parameter + # https://github.com/Azure/azure-functions-python-worker/issues/1508 + if sys.version_info.minor >= 11: + self._loop.set_task_factory( + lambda loop, coro, context=None: ContextEnabledTask( + coro, loop=loop, context=context)) + else: + self._loop.set_task_factory( + lambda loop, coro: ContextEnabledTask(coro, loop=loop)) + # Detach console logging before enabling GRPC channel logging logger.info('Detaching console logging.') @@ -1014,7 +1022,10 @@ class ContextEnabledTask(asyncio.Task): AZURE_INVOCATION_ID = '__azure_function_invocation_id__' def __init__(self, coro, loop, context=None): - super().__init__(coro, loop=loop, context=context) + if sys.version_info.minor >= 11: + super().__init__(coro, loop=loop, context=context) + else: + super().__init__(coro, loop=loop) current_task = asyncio.current_task(loop) if current_task is not None: From a17ead5c6e751d5dd2615ed50af1a812e8a9b623 Mon Sep 17 00:00:00 2001 From: Victoria Hall Date: Mon, 1 Jul 2024 16:57:08 -0500 Subject: [PATCH 3/8] test fixes --- tests/unittests/test_dispatcher.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unittests/test_dispatcher.py b/tests/unittests/test_dispatcher.py index e3990c8ed..6df91ee47 100644 --- a/tests/unittests/test_dispatcher.py +++ b/tests/unittests/test_dispatcher.py @@ -999,7 +999,7 @@ def test_init_with_context(self): try: self.loop.set_task_factory( lambda loop, coro, context=None: ContextEnabledTask( - coro, loop=loop, context=ctx)) + coro, loop=loop, context=context)) except TypeError: exception_raised = True self.assertFalse(exception_raised) @@ -1008,7 +1008,7 @@ async def test_init_without_context(self): exception_raised = False try: self.loop.set_task_factory( - lambda loop, coro, context=None: ContextEnabledTask( + lambda loop, coro: ContextEnabledTask( coro, loop=loop)) except TypeError: exception_raised = True From a741da2a303d9df435aabbb59da5fb43a3f0a44f Mon Sep 17 00:00:00 2001 From: hallvictoria Date: Tue, 2 Jul 2024 09:08:07 -0500 Subject: [PATCH 4/8] lint & skipping tests --- azure_functions_worker/dispatcher.py | 1 - tests/unittests/test_dispatcher.py | 2 +- tests/unittests/test_http_functions.py | 2 ++ 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/azure_functions_worker/dispatcher.py b/azure_functions_worker/dispatcher.py index f375ac6fc..943655aed 100644 --- a/azure_functions_worker/dispatcher.py +++ b/azure_functions_worker/dispatcher.py @@ -165,7 +165,6 @@ async def dispatch_forever(self): # sourcery skip: swap-if-expression self._loop.set_task_factory( lambda loop, coro: ContextEnabledTask(coro, loop=loop)) - # Detach console logging before enabling GRPC channel logging logger.info('Detaching console logging.') disable_console_logging() diff --git a/tests/unittests/test_dispatcher.py b/tests/unittests/test_dispatcher.py index 6df91ee47..9cf20becd 100644 --- a/tests/unittests/test_dispatcher.py +++ b/tests/unittests/test_dispatcher.py @@ -999,7 +999,7 @@ def test_init_with_context(self): try: self.loop.set_task_factory( lambda loop, coro, context=None: ContextEnabledTask( - coro, loop=loop, context=context)) + coro, loop=loop, context=ctx)) except TypeError: exception_raised = True self.assertFalse(exception_raised) diff --git a/tests/unittests/test_http_functions.py b/tests/unittests/test_http_functions.py index 35550ae07..e2fd6ad78 100644 --- a/tests/unittests/test_http_functions.py +++ b/tests/unittests/test_http_functions.py @@ -442,6 +442,8 @@ def check_log_hijack_current_event_loop(self, host_out: typing.List[str]): # System logs should not exist in host_out self.assertNotIn('parallelly_log_system at disguised_logger', host_out) + @skipIf(sys.version_info.minor < 11, + "The context param is only available for 3.11+") def test_create_task_with_context(self): r = self.webhost.request('GET', 'create_task_with_context') From 1c502122f2079d1244972483927de98dd32da6c5 Mon Sep 17 00:00:00 2001 From: hallvictoria Date: Tue, 2 Jul 2024 11:01:55 -0500 Subject: [PATCH 5/8] only one check needed --- azure_functions_worker/dispatcher.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/azure_functions_worker/dispatcher.py b/azure_functions_worker/dispatcher.py index 943655aed..3bb7ffd58 100644 --- a/azure_functions_worker/dispatcher.py +++ b/azure_functions_worker/dispatcher.py @@ -157,13 +157,9 @@ async def dispatch_forever(self): # sourcery skip: swap-if-expression # In Python 3.11+, constructing a task has an optional context # parameter # https://github.com/Azure/azure-functions-python-worker/issues/1508 - if sys.version_info.minor >= 11: - self._loop.set_task_factory( + self._loop.set_task_factory( lambda loop, coro, context=None: ContextEnabledTask( coro, loop=loop, context=context)) - else: - self._loop.set_task_factory( - lambda loop, coro: ContextEnabledTask(coro, loop=loop)) # Detach console logging before enabling GRPC channel logging logger.info('Detaching console logging.') From 8b8c589aa6ab5c377f5eecba77fa5bfb2de40c85 Mon Sep 17 00:00:00 2001 From: hallvictoria Date: Tue, 2 Jul 2024 11:07:37 -0500 Subject: [PATCH 6/8] lint + comments --- azure_functions_worker/dispatcher.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/azure_functions_worker/dispatcher.py b/azure_functions_worker/dispatcher.py index 3bb7ffd58..f02804ac9 100644 --- a/azure_functions_worker/dispatcher.py +++ b/azure_functions_worker/dispatcher.py @@ -155,11 +155,11 @@ async def dispatch_forever(self): # sourcery skip: swap-if-expression worker_id=self.worker_id))) # In Python 3.11+, constructing a task has an optional context - # parameter + # parameter. Allow for this param to be passed to ContextEnabledTask # https://github.com/Azure/azure-functions-python-worker/issues/1508 self._loop.set_task_factory( - lambda loop, coro, context=None: ContextEnabledTask( - coro, loop=loop, context=context)) + lambda loop, coro, context=None: ContextEnabledTask( + coro, loop=loop, context=context)) # Detach console logging before enabling GRPC channel logging logger.info('Detaching console logging.') @@ -1017,6 +1017,8 @@ class ContextEnabledTask(asyncio.Task): AZURE_INVOCATION_ID = '__azure_function_invocation_id__' def __init__(self, coro, loop, context=None): + # The context param is only available for 3.11+. If + # not, it can't be sent in the init() call. if sys.version_info.minor >= 11: super().__init__(coro, loop=loop, context=context) else: From d6c4b7cee698588a07672614dec530f47a400b94 Mon Sep 17 00:00:00 2001 From: Victoria Hall Date: Tue, 13 Aug 2024 10:22:36 -0500 Subject: [PATCH 7/8] better tests --- .../create_task_with_context/main.py | 27 ++++++++++++---- .../create_task_without_context/main.py | 3 ++ .../http_functions_stein/function_app.py | 32 +++++++++++++++---- tests/unittests/test_dispatcher.py | 6 +++- tests/unittests/test_http_functions.py | 3 +- 5 files changed, 55 insertions(+), 16 deletions(-) diff --git a/tests/unittests/http_functions/create_task_with_context/main.py b/tests/unittests/http_functions/create_task_with_context/main.py index 6a5ad75d9..f603acd1b 100644 --- a/tests/unittests/http_functions/create_task_with_context/main.py +++ b/tests/unittests/http_functions/create_task_with_context/main.py @@ -6,17 +6,30 @@ import azure.functions num = contextvars.ContextVar('num') -num.set(5) -ctx = contextvars.copy_context() async def count(name: str): - for i in range(ctx[num]): + # The number of times the loop is executed + # depends on the val set in context + val = num.get() + for i in range(val): await asyncio.sleep(0.5) - return f"Finished {name} in {ctx[num]}" + return f"Finished {name} in {val}" async def main(req: azure.functions.HttpRequest): - count_task = asyncio.create_task(count("Hello World"), context=ctx) - count_val = await count_task - return f'{count_val}' + # Create first task with context num = 5 + num.set(5) + first_ctx = contextvars.copy_context() + first_count_task = asyncio.create_task(count("Hello World"), context=first_ctx) + + # Create second task with context num = 10 + num.set(10) + second_ctx = contextvars.copy_context() + second_count_task = asyncio.create_task(count("Hello World"), context=second_ctx) + + # Execute tasks + first_count_val = await first_count_task + second_count_val = await second_count_task + + return f'{first_count_val + " | " + second_count_val}' diff --git a/tests/unittests/http_functions/create_task_without_context/main.py b/tests/unittests/http_functions/create_task_without_context/main.py index 122247fb3..c7ee21f7b 100644 --- a/tests/unittests/http_functions/create_task_without_context/main.py +++ b/tests/unittests/http_functions/create_task_without_context/main.py @@ -6,12 +6,15 @@ async def count(name: str, num: int): + # The number of times the loop executes is decided by a + # user-defined param for i in range(num): await asyncio.sleep(0.5) return f"Finished {name} in {num}" async def main(req: azure.functions.HttpRequest): + # No context is being sent into asyncio.create_task count_task = asyncio.create_task(count("Hello World", 5)) count_val = await count_task return f'{count_val}' diff --git a/tests/unittests/http_functions/http_functions_stein/function_app.py b/tests/unittests/http_functions/http_functions_stein/function_app.py index ff7e0d4b4..112813de9 100644 --- a/tests/unittests/http_functions/http_functions_stein/function_app.py +++ b/tests/unittests/http_functions/http_functions_stein/function_app.py @@ -16,17 +16,20 @@ logger = logging.getLogger("my-function") num = contextvars.ContextVar('num') -num.set(5) -ctx = contextvars.copy_context() async def count_with_context(name: str): - for i in range(ctx[num]): + # The number of times the loop is executed + # depends on the val set in context + val = num.get() + for i in range(val): await asyncio.sleep(0.5) - return f"Finished {name} in {ctx[num]}" + return f"Finished {name} in {val}" async def count_without_context(name: str, number: int): + # The number of times the loop executes is decided by a + # user-defined param for i in range(number): await asyncio.sleep(0.5) return f"Finished {name} in {number}" @@ -425,13 +428,28 @@ def set_cookie_resp_header_empty( @app.route('create_task_with_context') async def create_task_with_context(req: func.HttpRequest): - count_task = asyncio.create_task(count_with_context("Hello World"), context=ctx) - count_val = await count_task - return f'{count_val}' + # Create first task with context num = 5 + num.set(5) + first_ctx = contextvars.copy_context() + first_count_task = asyncio.create_task( + count_with_context("Hello World"), context=first_ctx) + + # Create second task with context num = 10 + num.set(10) + second_ctx = contextvars.copy_context() + second_count_task = asyncio.create_task( + count_with_context("Hello World"), context=second_ctx) + + # Execute tasks + first_count_val = await first_count_task + second_count_val = await second_count_task + + return f'{first_count_val + " | " + second_count_val}' @app.route('create_task_without_context') async def create_task_without_context(req: func.HttpRequest): + # No context is being sent into asyncio.create_task count_task = asyncio.create_task(count_without_context("Hello World", 5)) count_val = await count_task return f'{count_val}' diff --git a/tests/unittests/test_dispatcher.py b/tests/unittests/test_dispatcher.py index e1aa589b2..32eca34bf 100644 --- a/tests/unittests/test_dispatcher.py +++ b/tests/unittests/test_dispatcher.py @@ -22,7 +22,7 @@ PYTHON_THREADPOOL_THREAD_COUNT_MAX_37, PYTHON_THREADPOOL_THREAD_COUNT_MIN, ) -from azure_functions_worker.dispatcher import Dispatcher +from azure_functions_worker.dispatcher import Dispatcher, ContextEnabledTask from azure_functions_worker.version import VERSION SysVersionInfo = col.namedtuple("VersionInfo", ["major", "minor", "micro", @@ -1001,6 +1001,8 @@ def tearDown(self): self.loop.close() def test_init_with_context(self): + # Since ContextEnabledTask accepts the context param, + # no errors will be thrown here num = contextvars.ContextVar('num') num.set(5) ctx = contextvars.copy_context() @@ -1014,6 +1016,8 @@ def test_init_with_context(self): self.assertFalse(exception_raised) async def test_init_without_context(self): + # If the context param is not defined, + # no errors will be thrown for backwards compatibility exception_raised = False try: self.loop.set_task_factory( diff --git a/tests/unittests/test_http_functions.py b/tests/unittests/test_http_functions.py index fe1f5b8aa..03e0b5806 100644 --- a/tests/unittests/test_http_functions.py +++ b/tests/unittests/test_http_functions.py @@ -452,7 +452,8 @@ def test_create_task_with_context(self): r = self.webhost.request('GET', 'create_task_with_context') self.assertEqual(r.status_code, 200) - self.assertEqual(r.text, 'Finished Hello World in 5') + self.assertEqual(r.text, 'Finished Hello World in 5' + ' | Finished Hello World in 10') def test_create_task_without_context(self): r = self.webhost.request('GET', 'create_task_without_context') From dc616e2ca0cf4f4a86314291657f848bc0762359 Mon Sep 17 00:00:00 2001 From: Victoria Hall Date: Wed, 14 Aug 2024 10:42:48 -0500 Subject: [PATCH 8/8] removed comment --- azure_functions_worker/dispatcher.py | 1 - 1 file changed, 1 deletion(-) diff --git a/azure_functions_worker/dispatcher.py b/azure_functions_worker/dispatcher.py index 0ea9b3b95..820c328ff 100644 --- a/azure_functions_worker/dispatcher.py +++ b/azure_functions_worker/dispatcher.py @@ -173,7 +173,6 @@ async def dispatch_forever(self): # sourcery skip: swap-if-expression # In Python 3.11+, constructing a task has an optional context # parameter. Allow for this param to be passed to ContextEnabledTask - # https://github.com/Azure/azure-functions-python-worker/issues/1508 self._loop.set_task_factory( lambda loop, coro, context=None: ContextEnabledTask( coro, loop=loop, context=context))