diff --git a/awslambdaric/lambda_runtime_client.py b/awslambdaric/lambda_runtime_client.py index 91ebd4c..ba85902 100644 --- a/awslambdaric/lambda_runtime_client.py +++ b/awslambdaric/lambda_runtime_client.py @@ -52,6 +52,12 @@ class LambdaRuntimeClient(object): def __init__(self, lambda_runtime_address, use_thread_for_polling_next=False): self.lambda_runtime_address = lambda_runtime_address self.use_thread_for_polling_next = use_thread_for_polling_next + if self.use_thread_for_polling_next: + # Conditionally import only for the case when TPE is used in this class. + from concurrent.futures import ThreadPoolExecutor + + # Not defining symbol as global to avoid relying on TPE being imported unconditionally. + self.ThreadPoolExecutor = ThreadPoolExecutor def post_init_error(self, error_response_data): # These imports are heavy-weight. They implicitly trigger `import ssl, hashlib`. @@ -74,9 +80,8 @@ def wait_next_invocation(self): # which can then process signals. if self.use_thread_for_polling_next: try: - from concurrent.futures import ThreadPoolExecutor - - with ThreadPoolExecutor(max_workers=1) as executor: + # TPE class is supposed to be registered at construction time and be ready to use. + with self.ThreadPoolExecutor(max_workers=1) as executor: future = executor.submit(runtime_client.next) response_body, headers = future.result() except Exception as e: diff --git a/tests/test_bootstrap.py b/tests/test_bootstrap.py index ca367fd..83d31ee 100644 --- a/tests/test_bootstrap.py +++ b/tests/test_bootstrap.py @@ -589,10 +589,11 @@ def raise_exception_handler(json_input, lambda_context): self.assertEqual(mock_stdout.getvalue(), error_logs) - @patch("sys.stdout", new_callable=StringIO) + # The order of patches matter. Using MagicMock resets sys.stdout to the default. @patch("importlib.import_module") + @patch("sys.stdout", new_callable=StringIO) def test_handle_event_request_fault_exception_logging_syntax_error( - self, mock_import_module, mock_stdout + self, mock_stdout, mock_import_module ): try: eval("-")