diff --git a/azure_functions_worker/bindings/meta.py b/azure_functions_worker/bindings/meta.py index 7d4be3273..ae40ce398 100644 --- a/azure_functions_worker/bindings/meta.py +++ b/azure_functions_worker/bindings/meta.py @@ -156,6 +156,7 @@ def from_incoming_proto( pytype: typing.Optional[type], trigger_metadata: typing.Optional[typing.Dict[str, protos.TypedData]], shmem_mgr: SharedMemoryManager, + function_name: str, is_deferred_binding: typing.Optional[bool] = False) -> typing.Any: binding = get_binding(binding, is_deferred_binding) if trigger_metadata: @@ -184,7 +185,8 @@ def from_incoming_proto( pb=pb, pytype=pytype, datum=datum, - metadata=metadata) + metadata=metadata, + function_name=function_name) return binding.decode(datum, trigger_metadata=metadata) except NotImplementedError: # Binding does not support the data. @@ -281,29 +283,40 @@ def deferred_bindings_decode(binding: typing.Any, pb: protos.ParameterBinding, *, pytype: typing.Optional[type], datum: typing.Any, - metadata: typing.Any): + metadata: typing.Any, + function_name: str): """ This cache holds deferred binding types (ie. BlobClient, ContainerClient) That have already been created, so that the worker can reuse the Previously created type without creating a new one. + For async types, the function_name is needed as a key to differentiate. + This prevents a known SDK issue where reusing a client across functions + can lose the session context and cause an error. + + The cache key is based on: param name, type, resource, function_name + If cache is empty or key doesn't exist, deferred_binding_type is None """ global deferred_bindings_cache if deferred_bindings_cache.get((pb.name, pytype, - datum.value.content), None) is not None: + datum.value.content, + function_name), None) is not None: return deferred_bindings_cache.get((pb.name, pytype, - datum.value.content)) + datum.value.content, + function_name)) else: deferred_binding_type = binding.decode(datum, trigger_metadata=metadata, pytype=pytype) + deferred_bindings_cache[(pb.name, pytype, - datum.value.content)] = deferred_binding_type + datum.value.content, + function_name)] = deferred_binding_type return deferred_binding_type diff --git a/azure_functions_worker/dispatcher.py b/azure_functions_worker/dispatcher.py index 820c328ff..a9849b289 100644 --- a/azure_functions_worker/dispatcher.py +++ b/azure_functions_worker/dispatcher.py @@ -632,6 +632,8 @@ async def _handle__invocation_request(self, request): trigger_metadata=trigger_metadata, pytype=pb_type_info.pytype, shmem_mgr=self._shmem_mgr, + function_name=self._functions.get_function( + function_id).name, is_deferred_binding=pb_type_info.deferred_bindings_enabled) if http_v2_enabled: diff --git a/tests/extension_tests/deferred_bindings_tests/deferred_bindings_blob_functions/function_app.py b/tests/extension_tests/deferred_bindings_tests/deferred_bindings_blob_functions/function_app.py index 1a8062aa0..075d8a78a 100644 --- a/tests/extension_tests/deferred_bindings_tests/deferred_bindings_blob_functions/function_app.py +++ b/tests/extension_tests/deferred_bindings_tests/deferred_bindings_blob_functions/function_app.py @@ -256,7 +256,32 @@ def put_blob_bytes(req: func.HttpRequest, file: func.Out[bytes]) -> str: @app.route(route="blob_cache") def blob_cache(req: func.HttpRequest, cachedClient: blob.BlobClient) -> str: - return cachedClient.download_blob(encoding='utf-8').readall() + return func.HttpResponse(repr(cachedClient)) + + +@app.function_name(name="blob_cache2") +@app.blob_input(arg_name="cachedClient", + path="python-worker-tests/test-blobclient-triggered.txt", + connection="AzureWebJobsStorage") +@app.route(route="blob_cache2") +def blob_cache2(req: func.HttpRequest, + cachedClient: blob.BlobClient) -> func.HttpResponse: + return func.HttpResponse(repr(cachedClient)) + + +@app.function_name(name="blob_cache3") +@app.blob_input(arg_name="cachedClient", + path="python-worker-tests/test-blobclient-triggered.txt", + connection="AzureWebJobsStorage") +@app.blob_input(arg_name="cachedClient2", + path="python-worker-tests/test-blobclient-triggered.txt", + connection="AzureWebJobsStorage") +@app.route(route="blob_cache3") +def blob_cache3(req: func.HttpRequest, + cachedClient: blob.BlobClient, + cachedClient2: blob.BlobClient) -> func.HttpResponse: + return func.HttpResponse("Client 1: " + repr(cachedClient) + + " | Client 2: " + repr(cachedClient2)) @app.function_name(name="invalid_connection_info") @@ -265,5 +290,5 @@ def blob_cache(req: func.HttpRequest, connection="NotARealConnectionString") @app.route(route="invalid_connection_info") def invalid_connection_info(req: func.HttpRequest, - client: blob.BlobClient) -> str: - return client.download_blob(encoding='utf-8').readall() + client: blob.BlobClient) -> func.HttpResponse: + return func.HttpResponse(repr(client)) diff --git a/tests/extension_tests/deferred_bindings_tests/test_deferred_bindings.py b/tests/extension_tests/deferred_bindings_tests/test_deferred_bindings.py index 63e952682..1899f9e75 100644 --- a/tests/extension_tests/deferred_bindings_tests/test_deferred_bindings.py +++ b/tests/extension_tests/deferred_bindings_tests/test_deferred_bindings.py @@ -153,7 +153,8 @@ def test_deferred_bindings_enabled_decode(self): datum = datumdef.Datum(value=sample_mbd, type='model_binding_data') obj = meta.deferred_bindings_decode(binding=binding, pb=pb, - pytype=BlobClient, datum=datum, metadata={}) + pytype=BlobClient, datum=datum, metadata={}, + function_name="test_function") self.assertIsNotNone(obj) diff --git a/tests/extension_tests/deferred_bindings_tests/test_deferred_bindings_blob_functions.py b/tests/extension_tests/deferred_bindings_tests/test_deferred_bindings_blob_functions.py index 0a90f2075..ed441a077 100644 --- a/tests/extension_tests/deferred_bindings_tests/test_deferred_bindings_blob_functions.py +++ b/tests/extension_tests/deferred_bindings_tests/test_deferred_bindings_blob_functions.py @@ -6,8 +6,6 @@ from tests.utils import testutils -from azure_functions_worker.bindings import meta - @unittest.skipIf(sys.version_info.minor <= 8, "The base extension" "is only supported for 3.9+.") @@ -174,16 +172,58 @@ def test_type_undefined(self): self.assertEqual(r.text, 'test-data') def test_caching(self): - # Cache is empty at the start - self.assertEqual(meta.deferred_bindings_cache, {}) + ''' + The cache returns the same type based on resource and function name. + Two different functions with clients that access the same resource + will have two different clients. This tests that the same client + is returned for each invocation and that the clients are different + between the two functions. + ''' + r = self.webhost.request('GET', 'blob_cache') + r2 = self.webhost.request('GET', 'blob_cache2') self.assertEqual(r.status_code, 200) + self.assertEqual(r2.status_code, 200) + client = r.text + client2 = r2.text + self.assertNotEqual(client, client2) r = self.webhost.request('GET', 'blob_cache') + r2 = self.webhost.request('GET', 'blob_cache2') self.assertEqual(r.status_code, 200) + self.assertEqual(r2.status_code, 200) + self.assertEqual(r.text, client) + self.assertEqual(r2.text, client2) + self.assertNotEqual(r.text, r2.text) r = self.webhost.request('GET', 'blob_cache') - self.assertEqual(r.status_code, 200) + r2 = self.webhost.request('GET', 'blob_cache2') + self.assertEqual(r.status_code, 200) + self.assertEqual(r2.status_code, 200) + self.assertEqual(r.text, client) + self.assertEqual(r2.text, client2) + self.assertNotEqual(r.text, r2.text) + + def test_caching_same_resource(self): + ''' + The cache returns the same type based on param name. + One functions with two clients that access the same resource + will have two different clients. This tests that the same clients + are returned for each invocation and that the clients are different + between the two bindings. + ''' + + r = self.webhost.request('GET', 'blob_cache3') + self.assertEqual(r.status_code, 200) + clients = r.text.split(" | ") + self.assertNotEqual(clients[0], clients[1]) + + r2 = self.webhost.request('GET', 'blob_cache3') + self.assertEqual(r2.status_code, 200) + clients_second_call = r2.text.split(" | ") + self.assertEqual(clients[0], clients_second_call[0]) + self.assertEqual(clients[1], clients_second_call[1]) + self.assertNotEqual(clients_second_call[0], clients_second_call[1]) def test_failed_client_creation(self): r = self.webhost.request('GET', 'invalid_connection_info')