Skip to content

fix: added optional context param for tasks #1523

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions azure_functions_worker/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,11 @@ async def dispatch_forever(self): # sourcery skip: swap-if-expression
start_stream=protos.StartStream(
worker_id=self.worker_id)))

# In Python 3.11+, constructing a task has an optional context
# parameter. Allow for this param to be passed to ContextEnabledTask
self._loop.set_task_factory(
lambda loop, coro: ContextEnabledTask(coro, loop=loop))
lambda loop, coro, context=None: ContextEnabledTask(
coro, loop=loop, context=context))

# Detach console logging before enabling GRPC channel logging
logger.info('Detaching console logging.')
Expand Down Expand Up @@ -1068,8 +1071,13 @@ def emit(self, record: LogRecord) -> None:
class ContextEnabledTask(asyncio.Task):
AZURE_INVOCATION_ID = '__azure_function_invocation_id__'

def __init__(self, coro, loop):
super().__init__(coro, loop=loop)
def __init__(self, coro, loop, context=None):
# The context param is only available for 3.11+. If
# not, it can't be sent in the init() call.
if sys.version_info.minor >= 11:
super().__init__(coro, loop=loop, context=context)
else:
super().__init__(coro, loop=loop)

current_task = asyncio.current_task(loop)
if current_task is not None:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{
"scriptFile": "main.py",
"bindings": [
{
"type": "httpTrigger",
"direction": "in",
"name": "req"
},
{
"type": "http",
"direction": "out",
"name": "$return"
}
]
}
35 changes: 35 additions & 0 deletions tests/unittests/http_functions/create_task_with_context/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import asyncio
import contextvars

import azure.functions

num = contextvars.ContextVar('num')


async def count(name: str):
# The number of times the loop is executed
# depends on the val set in context
val = num.get()
for i in range(val):
await asyncio.sleep(0.5)
return f"Finished {name} in {val}"


async def main(req: azure.functions.HttpRequest):
# Create first task with context num = 5
num.set(5)
first_ctx = contextvars.copy_context()
first_count_task = asyncio.create_task(count("Hello World"), context=first_ctx)

# Create second task with context num = 10
num.set(10)
second_ctx = contextvars.copy_context()
second_count_task = asyncio.create_task(count("Hello World"), context=second_ctx)

# Execute tasks
first_count_val = await first_count_task
second_count_val = await second_count_task

return f'{first_count_val + " | " + second_count_val}'
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{
"scriptFile": "main.py",
"bindings": [
{
"type": "httpTrigger",
"direction": "in",
"name": "req"
},
{
"type": "http",
"direction": "out",
"name": "$return"
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import asyncio

import azure.functions


async def count(name: str, num: int):
# The number of times the loop executes is decided by a
# user-defined param
for i in range(num):
await asyncio.sleep(0.5)
return f"Finished {name} in {num}"


async def main(req: azure.functions.HttpRequest):
# No context is being sent into asyncio.create_task
count_task = asyncio.create_task(count("Hello World", 5))
count_val = await count_task
return f'{count_val}'
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import asyncio
import contextvars
import hashlib
import json
import logging
Expand All @@ -14,6 +15,25 @@

logger = logging.getLogger("my-function")

num = contextvars.ContextVar('num')


async def count_with_context(name: str):
# The number of times the loop is executed
# depends on the val set in context
val = num.get()
for i in range(val):
await asyncio.sleep(0.5)
return f"Finished {name} in {val}"


async def count_without_context(name: str, number: int):
# The number of times the loop executes is decided by a
# user-defined param
for i in range(number):
await asyncio.sleep(0.5)
return f"Finished {name} in {number}"


@app.route(route="return_str")
def return_str(req: func.HttpRequest) -> str:
Expand Down Expand Up @@ -404,3 +424,32 @@ def set_cookie_resp_header_empty(
resp.headers.add("Set-Cookie", '')

return resp


@app.route('create_task_with_context')
async def create_task_with_context(req: func.HttpRequest):
# Create first task with context num = 5
num.set(5)
first_ctx = contextvars.copy_context()
first_count_task = asyncio.create_task(
count_with_context("Hello World"), context=first_ctx)

# Create second task with context num = 10
num.set(10)
second_ctx = contextvars.copy_context()
second_count_task = asyncio.create_task(
count_with_context("Hello World"), context=second_ctx)

# Execute tasks
first_count_val = await first_count_task
second_count_val = await second_count_task

return f'{first_count_val + " | " + second_count_val}'


@app.route('create_task_without_context')
async def create_task_without_context(req: func.HttpRequest):
# No context is being sent into asyncio.create_task
count_task = asyncio.create_task(count_without_context("Hello World", 5))
count_val = await count_task
return f'{count_val}'
39 changes: 38 additions & 1 deletion tests/unittests/test_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Licensed under the MIT License.
import asyncio
import collections as col
import contextvars
import os
import sys
import unittest
Expand All @@ -21,7 +22,7 @@
PYTHON_THREADPOOL_THREAD_COUNT_MAX_37,
PYTHON_THREADPOOL_THREAD_COUNT_MIN,
)
from azure_functions_worker.dispatcher import Dispatcher
from azure_functions_worker.dispatcher import Dispatcher, ContextEnabledTask
from azure_functions_worker.version import VERSION

SysVersionInfo = col.namedtuple("VersionInfo", ["major", "minor", "micro",
Expand Down Expand Up @@ -989,3 +990,39 @@ def test_dispatcher_indexing_in_load_request_with_exception(
self.assertEqual(
response.function_load_response.result.exception.message,
"Exception: Mocked Exception")


class TestContextEnabledTask(unittest.TestCase):
def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)

def tearDown(self):
self.loop.close()

def test_init_with_context(self):
# Since ContextEnabledTask accepts the context param,
# no errors will be thrown here
num = contextvars.ContextVar('num')
num.set(5)
ctx = contextvars.copy_context()
exception_raised = False
try:
self.loop.set_task_factory(
lambda loop, coro, context=None: ContextEnabledTask(
coro, loop=loop, context=ctx))
except TypeError:
exception_raised = True
self.assertFalse(exception_raised)

async def test_init_without_context(self):
# If the context param is not defined,
# no errors will be thrown for backwards compatibility
exception_raised = False
try:
self.loop.set_task_factory(
lambda loop, coro: ContextEnabledTask(
coro, loop=loop))
except TypeError:
exception_raised = True
self.assertFalse(exception_raised)
15 changes: 15 additions & 0 deletions tests/unittests/test_http_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,21 @@ def check_log_hijack_current_event_loop(self, host_out: typing.List[str]):
# System logs should not exist in host_out
self.assertNotIn('parallelly_log_system at disguised_logger', host_out)

@skipIf(sys.version_info.minor < 11,
"The context param is only available for 3.11+")
def test_create_task_with_context(self):
r = self.webhost.request('GET', 'create_task_with_context')

self.assertEqual(r.status_code, 200)
self.assertEqual(r.text, 'Finished Hello World in 5'
' | Finished Hello World in 10')

def test_create_task_without_context(self):
r = self.webhost.request('GET', 'create_task_without_context')

self.assertEqual(r.status_code, 200)
self.assertEqual(r.text, 'Finished Hello World in 5')


class TestHttpFunctionsStein(TestHttpFunctions):

Expand Down
Loading