Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions awslambdaric/lambda_runtime_marshaller.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import decimal
import math

import os
import simplejson as json

from .lambda_runtime_exception import FaultException
Expand All @@ -15,7 +15,10 @@
# We also set 'ensure_ascii=False' so that the encoded json contains unicode characters instead of unicode escape sequences
class Encoder(json.JSONEncoder):
def __init__(self):
super().__init__(use_decimal=False, ensure_ascii=False)
if os.environ.get("AWS_EXECUTION_ENV") in {"AWS_Lambda_python3.12"}:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpicks:

if os.environ.get("AWS_EXECUTION_ENV") == "AWS_Lambda_python3.12":

to avoid extra set object creation

super().__init__(use_decimal=False, ensure_ascii=False)
else:
super().__init__(use_decimal=False)

def default(self, obj):
if isinstance(obj, decimal.Decimal):
Expand Down
1 change: 1 addition & 0 deletions requirements/dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ bandit>=1.6.2
# Test requirements
pytest>=3.0.7
mock>=2.0.0
parameterized>=0.9.0
42 changes: 40 additions & 2 deletions tests/test_lambda_runtime_marshaller.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,33 @@

import decimal
import unittest

import os
from parameterized import parameterized
from awslambdaric.lambda_runtime_marshaller import to_json

execution_envs = (
"AWS_Lambda_python3.12",
"AWS_Lambda_python3.11",
"AWS_Lambda_python3.10",
"AWS_Lambda_python3.9",
)
envs_is_enabled_lambda_marshaller_ensure_ascii_false = {"AWS_Lambda_python3.12"}

execution_envs_not_enabled_lambda_marshaller_ensure_ascii_false = tuple(
set(execution_envs).difference(envs_is_enabled_lambda_marshaller_ensure_ascii_false)
)
execution_envs_is_enabled_lambda_marshaller_ensure_ascii_false = tuple(
envs_is_enabled_lambda_marshaller_ensure_ascii_false
)


class TestLambdaRuntimeMarshaller(unittest.TestCase):
def setUp(self):
self.org_os_environ = os.environ

def tearDown(self):
os.environ = self.org_os_environ

def test_to_json_decimal_encoding(self):
response = to_json({"pi": decimal.Decimal("3.14159")})
self.assertEqual('{"pi": 3.14159}', response)
Expand Down Expand Up @@ -38,10 +60,26 @@ def test_json_serializer_is_not_default_json(self):
self.assertFalse(hasattr(stock_json, "YOLO"))
self.assertTrue(hasattr(simplejson, "YOLO"))

def test_to_json_unicode_encoding(self):
@parameterized.expand(
execution_envs_is_enabled_lambda_marshaller_ensure_ascii_false
)
def test_to_json_unicode_not_escaped_encoding(self, execution_env):
os.environ = {"AWS_EXECUTION_ENV": execution_env}
response = to_json({"price": "£1.00"})
self.assertEqual('{"price": "£1.00"}', response)
self.assertNotEqual('{"price": "\\u00a31.00"}', response)
self.assertEqual(
19, len(response.encode("utf-8"))
) # would be 23 bytes if a unicode escape was returned

@parameterized.expand(
execution_envs_not_enabled_lambda_marshaller_ensure_ascii_false
)
def test_to_json_unicode_is_escaped_encoding(self, execution_env):
os.environ = {"AWS_EXECUTION_ENV": execution_env}
response = to_json({"price": "£1.00"})
self.assertEqual('{"price": "\\u00a31.00"}', response)
self.assertNotEqual('{"price": "£1.00"}', response)
self.assertEqual(
23, len(response.encode("utf-8"))
) # would be 19 bytes if a escaped was returned