Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion awslambdaric/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,10 @@ def run(app_root, handler, lambda_runtime_api_addr):
sys.stdout = Unbuffered(sys.stdout)
sys.stderr = Unbuffered(sys.stderr)

use_thread_for_polling_next = (
os.environ.get("AWS_EXECUTION_ENV") == "AWS_Lambda_python3.12"
)

with create_log_sink() as log_sink:
lambda_runtime_client = LambdaRuntimeClient(lambda_runtime_api_addr)

Expand All @@ -479,7 +483,9 @@ def run(app_root, handler, lambda_runtime_api_addr):
sys.exit(1)

while True:
event_request = lambda_runtime_client.wait_next_invocation()
event_request = lambda_runtime_client.wait_next_invocation(
use_thread_for_polling_next
)

_GLOBAL_AWS_REQUEST_ID = event_request.invoke_id

Expand Down
24 changes: 19 additions & 5 deletions awslambdaric/lambda_runtime_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
"""

import sys
from concurrent.futures import ThreadPoolExecutor
from awslambdaric import __version__


Expand Down Expand Up @@ -68,10 +67,25 @@ def post_init_error(self, error_response_data):
if response.code != http.HTTPStatus.ACCEPTED:
raise LambdaRuntimeClientError(endpoint, response.code, response_body)

def wait_next_invocation(self):
with ThreadPoolExecutor() as e:
fut = e.submit(runtime_client.next)
response_body, headers = fut.result()
def wait_next_invocation(self, use_thread_for_polling_next=False):
# Calling runtime_client.next() from a separate thread unblocks the main thread,
# which can then process signals.
if use_thread_for_polling_next:
from concurrent.futures import ThreadPoolExecutor
from .lambda_runtime_exception import FaultException

try:
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(runtime_client.next)
response_body, headers = future.result()
except Exception as e:
raise FaultException(
FaultException.LAMBDA_RUNTIME_CLIENT_ERROR,
"LAMBDA_RUNTIME Failed to get next invocation: {}".format(str(e)),
None,
)
else:
response_body, headers = runtime_client.next()
return InvocationRequest(
invoke_id=headers.get("Lambda-Runtime-Aws-Request-Id"),
x_amzn_trace_id=headers.get("Lambda-Runtime-Trace-Id"),
Expand Down
1 change: 1 addition & 0 deletions awslambdaric/lambda_runtime_exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class FaultException(Exception):
BUILT_IN_MODULE_CONFLICT = "Runtime.BuiltInModuleConflict"
MALFORMED_HANDLER_NAME = "Runtime.MalformedHandlerName"
LAMBDA_CONTEXT_UNMARSHAL_ERROR = "Runtime.LambdaContextUnmarshalError"
LAMBDA_RUNTIME_CLIENT_ERROR = "Runtime.LambdaRuntimeClientError"

def __init__(self, exception_type, msg, trace=None):
self.msg = msg
Expand Down
17 changes: 16 additions & 1 deletion tests/test_lambda_runtime_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,22 @@ def test_wait_next_invocation(self, mock_runtime_client):
mock_runtime_client.next.return_value = response_body, headears
runtime_client = LambdaRuntimeClient("localhost:1234")

event_request = runtime_client.wait_next_invocation()
use_thread_for_polling_next = True
event_request = runtime_client.wait_next_invocation(use_thread_for_polling_next)

self.assertIsNotNone(event_request)
self.assertEqual(event_request.invoke_id, "RID1234")
self.assertEqual(event_request.x_amzn_trace_id, "TID1234")
self.assertEqual(event_request.invoked_function_arn, "FARN1234")
self.assertEqual(event_request.deadline_time_in_ms, 12)
self.assertEqual(event_request.client_context, "client_context")
self.assertEqual(event_request.cognito_identity, "cognito_identity")
self.assertEqual(event_request.content_type, "application/json")
self.assertEqual(event_request.event_body, response_body)

# Using ThreadPoolExecutor to polling next()
use_thread_for_polling_next = False
event_request = runtime_client.wait_next_invocation(use_thread_for_polling_next)

self.assertIsNotNone(event_request)
self.assertEqual(event_request.invoke_id, "RID1234")
Expand Down