diff --git a/awslambdaric/lambda_runtime_marshaller.py b/awslambdaric/lambda_runtime_marshaller.py index b591e69..42ee127 100644 --- a/awslambdaric/lambda_runtime_marshaller.py +++ b/awslambdaric/lambda_runtime_marshaller.py @@ -4,7 +4,7 @@ import decimal import math - +import os import simplejson as json from .lambda_runtime_exception import FaultException @@ -15,7 +15,10 @@ # We also set 'ensure_ascii=False' so that the encoded json contains unicode characters instead of unicode escape sequences class Encoder(json.JSONEncoder): def __init__(self): - super().__init__(use_decimal=False, ensure_ascii=False) + if os.environ.get("AWS_EXECUTION_ENV") == "AWS_Lambda_python3.12": + super().__init__(use_decimal=False, ensure_ascii=False) + else: + super().__init__(use_decimal=False) def default(self, obj): if isinstance(obj, decimal.Decimal): diff --git a/requirements/dev.txt b/requirements/dev.txt index c432413..68377ce 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -9,3 +9,4 @@ bandit>=1.6.2 # Test requirements pytest>=3.0.7 mock>=2.0.0 +parameterized>=0.9.0 \ No newline at end of file diff --git a/tests/test_lambda_runtime_marshaller.py b/tests/test_lambda_runtime_marshaller.py index eb8b848..7cd73b4 100644 --- a/tests/test_lambda_runtime_marshaller.py +++ b/tests/test_lambda_runtime_marshaller.py @@ -3,12 +3,35 @@ """ import decimal +import os import unittest - +from parameterized import parameterized from awslambdaric.lambda_runtime_marshaller import to_json class TestLambdaRuntimeMarshaller(unittest.TestCase): + execution_envs = ( + "AWS_Lambda_python3.12", + "AWS_Lambda_python3.11", + "AWS_Lambda_python3.10", + "AWS_Lambda_python3.9", + ) + + envs_lambda_marshaller_ensure_ascii_false = {"AWS_Lambda_python3.12"} + + execution_envs_lambda_marshaller_ensure_ascii_true = tuple( + set(execution_envs).difference(envs_lambda_marshaller_ensure_ascii_false) + ) + execution_envs_lambda_marshaller_ensure_ascii_false = tuple( + envs_lambda_marshaller_ensure_ascii_false + ) + + def setUp(self): + self.org_os_environ = os.environ + + def tearDown(self): + os.environ = self.org_os_environ + def test_to_json_decimal_encoding(self): response = to_json({"pi": decimal.Decimal("3.14159")}) self.assertEqual('{"pi": 3.14159}', response) @@ -38,10 +61,22 @@ def test_json_serializer_is_not_default_json(self): self.assertFalse(hasattr(stock_json, "YOLO")) self.assertTrue(hasattr(simplejson, "YOLO")) - def test_to_json_unicode_encoding(self): + @parameterized.expand(execution_envs_lambda_marshaller_ensure_ascii_false) + def test_to_json_unicode_not_escaped_encoding(self, execution_env): + os.environ = {"AWS_EXECUTION_ENV": execution_env} response = to_json({"price": "£1.00"}) self.assertEqual('{"price": "£1.00"}', response) self.assertNotEqual('{"price": "\\u00a31.00"}', response) self.assertEqual( 19, len(response.encode("utf-8")) ) # would be 23 bytes if a unicode escape was returned + + @parameterized.expand(execution_envs_lambda_marshaller_ensure_ascii_true) + def test_to_json_unicode_is_escaped_encoding(self, execution_env): + os.environ = {"AWS_EXECUTION_ENV": execution_env} + response = to_json({"price": "£1.00"}) + self.assertEqual('{"price": "\\u00a31.00"}', response) + self.assertNotEqual('{"price": "£1.00"}', response) + self.assertEqual( + 23, len(response.encode("utf-8")) + ) # would be 19 bytes if a escaped was returned