Skip to content

Commit eda9826

Browse files
hallvictoriaVictoria Hall
and
Victoria Hall
authored
fix: unique clients per function (#1490)
* unique cache for fx name * lint * added test for same resource different binding --------- Co-authored-by: Victoria Hall <[email protected]>
1 parent f4c9c2d commit eda9826

File tree

5 files changed

+95
-14
lines changed

5 files changed

+95
-14
lines changed

azure_functions_worker/bindings/meta.py

+18-5
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ def from_incoming_proto(
156156
pytype: typing.Optional[type],
157157
trigger_metadata: typing.Optional[typing.Dict[str, protos.TypedData]],
158158
shmem_mgr: SharedMemoryManager,
159+
function_name: str,
159160
is_deferred_binding: typing.Optional[bool] = False) -> typing.Any:
160161
binding = get_binding(binding, is_deferred_binding)
161162
if trigger_metadata:
@@ -184,7 +185,8 @@ def from_incoming_proto(
184185
pb=pb,
185186
pytype=pytype,
186187
datum=datum,
187-
metadata=metadata)
188+
metadata=metadata,
189+
function_name=function_name)
188190
return binding.decode(datum, trigger_metadata=metadata)
189191
except NotImplementedError:
190192
# Binding does not support the data.
@@ -281,29 +283,40 @@ def deferred_bindings_decode(binding: typing.Any,
281283
pb: protos.ParameterBinding, *,
282284
pytype: typing.Optional[type],
283285
datum: typing.Any,
284-
metadata: typing.Any):
286+
metadata: typing.Any,
287+
function_name: str):
285288
"""
286289
This cache holds deferred binding types (ie. BlobClient, ContainerClient)
287290
That have already been created, so that the worker can reuse the
288291
Previously created type without creating a new one.
289292
293+
For async types, the function_name is needed as a key to differentiate.
294+
This prevents a known SDK issue where reusing a client across functions
295+
can lose the session context and cause an error.
296+
297+
The cache key is based on: param name, type, resource, function_name
298+
290299
If cache is empty or key doesn't exist, deferred_binding_type is None
291300
"""
292301
global deferred_bindings_cache
293302

294303
if deferred_bindings_cache.get((pb.name,
295304
pytype,
296-
datum.value.content), None) is not None:
305+
datum.value.content,
306+
function_name), None) is not None:
297307
return deferred_bindings_cache.get((pb.name,
298308
pytype,
299-
datum.value.content))
309+
datum.value.content,
310+
function_name))
300311
else:
301312
deferred_binding_type = binding.decode(datum,
302313
trigger_metadata=metadata,
303314
pytype=pytype)
315+
304316
deferred_bindings_cache[(pb.name,
305317
pytype,
306-
datum.value.content)] = deferred_binding_type
318+
datum.value.content,
319+
function_name)] = deferred_binding_type
307320
return deferred_binding_type
308321

309322

azure_functions_worker/dispatcher.py

+2
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,8 @@ async def _handle__invocation_request(self, request):
632632
trigger_metadata=trigger_metadata,
633633
pytype=pb_type_info.pytype,
634634
shmem_mgr=self._shmem_mgr,
635+
function_name=self._functions.get_function(
636+
function_id).name,
635637
is_deferred_binding=pb_type_info.deferred_bindings_enabled)
636638

637639
if http_v2_enabled:

tests/extension_tests/deferred_bindings_tests/deferred_bindings_blob_functions/function_app.py

+28-3
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,32 @@ def put_blob_bytes(req: func.HttpRequest, file: func.Out[bytes]) -> str:
256256
@app.route(route="blob_cache")
257257
def blob_cache(req: func.HttpRequest,
258258
cachedClient: blob.BlobClient) -> str:
259-
return cachedClient.download_blob(encoding='utf-8').readall()
259+
return func.HttpResponse(repr(cachedClient))
260+
261+
262+
@app.function_name(name="blob_cache2")
263+
@app.blob_input(arg_name="cachedClient",
264+
path="python-worker-tests/test-blobclient-triggered.txt",
265+
connection="AzureWebJobsStorage")
266+
@app.route(route="blob_cache2")
267+
def blob_cache2(req: func.HttpRequest,
268+
cachedClient: blob.BlobClient) -> func.HttpResponse:
269+
return func.HttpResponse(repr(cachedClient))
270+
271+
272+
@app.function_name(name="blob_cache3")
273+
@app.blob_input(arg_name="cachedClient",
274+
path="python-worker-tests/test-blobclient-triggered.txt",
275+
connection="AzureWebJobsStorage")
276+
@app.blob_input(arg_name="cachedClient2",
277+
path="python-worker-tests/test-blobclient-triggered.txt",
278+
connection="AzureWebJobsStorage")
279+
@app.route(route="blob_cache3")
280+
def blob_cache3(req: func.HttpRequest,
281+
cachedClient: blob.BlobClient,
282+
cachedClient2: blob.BlobClient) -> func.HttpResponse:
283+
return func.HttpResponse("Client 1: " + repr(cachedClient)
284+
+ " | Client 2: " + repr(cachedClient2))
260285

261286

262287
@app.function_name(name="invalid_connection_info")
@@ -265,5 +290,5 @@ def blob_cache(req: func.HttpRequest,
265290
connection="NotARealConnectionString")
266291
@app.route(route="invalid_connection_info")
267292
def invalid_connection_info(req: func.HttpRequest,
268-
client: blob.BlobClient) -> str:
269-
return client.download_blob(encoding='utf-8').readall()
293+
client: blob.BlobClient) -> func.HttpResponse:
294+
return func.HttpResponse(repr(client))

tests/extension_tests/deferred_bindings_tests/test_deferred_bindings.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,8 @@ def test_deferred_bindings_enabled_decode(self):
153153
datum = datumdef.Datum(value=sample_mbd, type='model_binding_data')
154154

155155
obj = meta.deferred_bindings_decode(binding=binding, pb=pb,
156-
pytype=BlobClient, datum=datum, metadata={})
156+
pytype=BlobClient, datum=datum, metadata={},
157+
function_name="test_function")
157158

158159
self.assertIsNotNone(obj)
159160

tests/extension_tests/deferred_bindings_tests/test_deferred_bindings_blob_functions.py

+45-5
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66

77
from tests.utils import testutils
88

9-
from azure_functions_worker.bindings import meta
10-
119

1210
@unittest.skipIf(sys.version_info.minor <= 8, "The base extension"
1311
"is only supported for 3.9+.")
@@ -174,16 +172,58 @@ def test_type_undefined(self):
174172
self.assertEqual(r.text, 'test-data')
175173

176174
def test_caching(self):
177-
# Cache is empty at the start
178-
self.assertEqual(meta.deferred_bindings_cache, {})
175+
'''
176+
The cache returns the same type based on resource and function name.
177+
Two different functions with clients that access the same resource
178+
will have two different clients. This tests that the same client
179+
is returned for each invocation and that the clients are different
180+
between the two functions.
181+
'''
182+
179183
r = self.webhost.request('GET', 'blob_cache')
184+
r2 = self.webhost.request('GET', 'blob_cache2')
180185
self.assertEqual(r.status_code, 200)
186+
self.assertEqual(r2.status_code, 200)
187+
client = r.text
188+
client2 = r2.text
189+
self.assertNotEqual(client, client2)
181190

182191
r = self.webhost.request('GET', 'blob_cache')
192+
r2 = self.webhost.request('GET', 'blob_cache2')
183193
self.assertEqual(r.status_code, 200)
194+
self.assertEqual(r2.status_code, 200)
195+
self.assertEqual(r.text, client)
196+
self.assertEqual(r2.text, client2)
197+
self.assertNotEqual(r.text, r2.text)
184198

185199
r = self.webhost.request('GET', 'blob_cache')
186-
self.assertEqual(r.status_code, 200)
200+
r2 = self.webhost.request('GET', 'blob_cache2')
201+
self.assertEqual(r.status_code, 200)
202+
self.assertEqual(r2.status_code, 200)
203+
self.assertEqual(r.text, client)
204+
self.assertEqual(r2.text, client2)
205+
self.assertNotEqual(r.text, r2.text)
206+
207+
def test_caching_same_resource(self):
208+
'''
209+
The cache returns the same type based on param name.
210+
One functions with two clients that access the same resource
211+
will have two different clients. This tests that the same clients
212+
are returned for each invocation and that the clients are different
213+
between the two bindings.
214+
'''
215+
216+
r = self.webhost.request('GET', 'blob_cache3')
217+
self.assertEqual(r.status_code, 200)
218+
clients = r.text.split(" | ")
219+
self.assertNotEqual(clients[0], clients[1])
220+
221+
r2 = self.webhost.request('GET', 'blob_cache3')
222+
self.assertEqual(r2.status_code, 200)
223+
clients_second_call = r2.text.split(" | ")
224+
self.assertEqual(clients[0], clients_second_call[0])
225+
self.assertEqual(clients[1], clients_second_call[1])
226+
self.assertNotEqual(clients_second_call[0], clients_second_call[1])
187227

188228
def test_failed_client_creation(self):
189229
r = self.webhost.request('GET', 'invalid_connection_info')

0 commit comments

Comments
 (0)