Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion awslambdaric/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,8 +462,14 @@ def run(app_root, handler, lambda_runtime_api_addr):
sys.stdout = Unbuffered(sys.stdout)
sys.stderr = Unbuffered(sys.stderr)

use_thread_for_polling_next = (
os.environ.get("AWS_EXECUTION_ENV") == "AWS_Lambda_python3.12"
)

with create_log_sink() as log_sink:
lambda_runtime_client = LambdaRuntimeClient(lambda_runtime_api_addr)
lambda_runtime_client = LambdaRuntimeClient(
lambda_runtime_api_addr, use_thread_for_polling_next
)

try:
_setup_logging(_AWS_LAMBDA_LOG_FORMAT, _AWS_LAMBDA_LOG_LEVEL, log_sink)
Expand Down
25 changes: 20 additions & 5 deletions awslambdaric/lambda_runtime_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
"""

import sys
from concurrent.futures import ThreadPoolExecutor
from awslambdaric import __version__


Expand Down Expand Up @@ -49,8 +48,9 @@ class LambdaRuntimeClient(object):
and response. It allows for function authors to override the the default implementation, LambdaMarshaller which
unmarshals and marshals JSON, to an instance of a class that implements the same interface."""

def __init__(self, lambda_runtime_address):
def __init__(self, lambda_runtime_address, use_thread_for_polling_next=False):
self.lambda_runtime_address = lambda_runtime_address
self.use_thread_for_polling_next = use_thread_for_polling_next

def post_init_error(self, error_response_data):
# These imports are heavy-weight. They implicitly trigger `import ssl, hashlib`.
Expand All @@ -69,9 +69,24 @@ def post_init_error(self, error_response_data):
raise LambdaRuntimeClientError(endpoint, response.code, response_body)

def wait_next_invocation(self):
with ThreadPoolExecutor() as e:
fut = e.submit(runtime_client.next)
response_body, headers = fut.result()
# Calling runtime_client.next() from a separate thread unblocks the main thread,
# which can then process signals.
if self.use_thread_for_polling_next:
from concurrent.futures import ThreadPoolExecutor
from .lambda_runtime_exception import FaultException

try:
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(runtime_client.next)
response_body, headers = future.result()
except Exception as e:
raise FaultExceptions(
FaultExceptions.LAMBDA_RUNTIME_CLIENT_ERROR,
"LAMBDA_RUNTIME Failed to get next invocation: {}".format(str(e)),
None,
)
else:
response_body, headers = runtime_client.next()
return InvocationRequest(
invoke_id=headers.get("Lambda-Runtime-Aws-Request-Id"),
x_amzn_trace_id=headers.get("Lambda-Runtime-Trace-Id"),
Expand Down
1 change: 1 addition & 0 deletions awslambdaric/lambda_runtime_exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class FaultException(Exception):
BUILT_IN_MODULE_CONFLICT = "Runtime.BuiltInModuleConflict"
MALFORMED_HANDLER_NAME = "Runtime.MalformedHandlerName"
LAMBDA_CONTEXT_UNMARSHAL_ERROR = "Runtime.LambdaContextUnmarshalError"
LAMBDA_RUNTIME_CLIENT_ERROR = "Runtime.LambdaRuntimeClientError"

def __init__(self, exception_type, msg, trace=None):
self.msg = msg
Expand Down
15 changes: 15 additions & 0 deletions tests/test_lambda_runtime_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,21 @@ def test_wait_next_invocation(self, mock_runtime_client):
self.assertEqual(event_request.content_type, "application/json")
self.assertEqual(event_request.event_body, response_body)

# Using ThreadPoolExecutor to polling next()
runtime_client = LambdaRuntimeClient("localhost:1234", True)

event_request = runtime_client.wait_next_invocation()

self.assertIsNotNone(event_request)
self.assertEqual(event_request.invoke_id, "RID1234")
self.assertEqual(event_request.x_amzn_trace_id, "TID1234")
self.assertEqual(event_request.invoked_function_arn, "FARN1234")
self.assertEqual(event_request.deadline_time_in_ms, 12)
self.assertEqual(event_request.client_context, "client_context")
self.assertEqual(event_request.cognito_identity, "cognito_identity")
self.assertEqual(event_request.content_type, "application/json")
self.assertEqual(event_request.event_body, response_body)

@patch("http.client.HTTPConnection", autospec=http.client.HTTPConnection)
def test_post_init_error(self, MockHTTPConnection):
mock_conn = MockHTTPConnection.return_value
Expand Down