diff --git a/awslambdaric/bootstrap.py b/awslambdaric/bootstrap.py index 5ad7bb5..a3da58c 100644 --- a/awslambdaric/bootstrap.py +++ b/awslambdaric/bootstrap.py @@ -17,15 +17,20 @@ _DATETIME_FORMAT, _DEFAULT_FRAME_TYPE, _JSON_FRAME_TYPES, + _TEXT_FRAME_TYPES, JsonFormatter, LogFormat, + _format_log_level, + _get_log_level_from_env_var, ) from .lambda_runtime_marshaller import to_json ERROR_LOG_LINE_TERMINATE = "\r" ERROR_LOG_IDENT = "\u00a0" # NO-BREAK SPACE U+00A0 _AWS_LAMBDA_LOG_FORMAT = LogFormat.from_str(os.environ.get("AWS_LAMBDA_LOG_FORMAT")) -_AWS_LAMBDA_LOG_LEVEL = os.environ.get("AWS_LAMBDA_LOG_LEVEL", "").upper() +_AWS_LAMBDA_LOG_LEVEL = _get_log_level_from_env_var( + os.environ.get("AWS_LAMBDA_LOG_LEVEL") +) def _get_handler(handler): @@ -122,7 +127,7 @@ def log_error(error_result, log_sink): ) else: - _ERROR_FRAME_TYPE = _DEFAULT_FRAME_TYPE + _ERROR_FRAME_TYPE = _TEXT_FRAME_TYPES[logging.ERROR] def log_error(error_result, log_sink): error_description = "[ERROR]" @@ -296,8 +301,22 @@ def __init__(self, log_sink): def emit(self, record): msg = self.format(record) + self.log_sink.log(msg) + - self.log_sink.log(msg, frame_type=getattr(record, "_frame_type", None)) +class LambdaLoggerHandlerWithFrameType(logging.Handler): + def __init__(self, log_sink): + super().__init__() + self.log_sink = log_sink + + def emit(self, record): + self.log_sink.log( + self.format(record), + frame_type=( + getattr(record, "_frame_type", None) + or _TEXT_FRAME_TYPES.get(_format_log_level(record)) + ), + ) class LambdaLoggerFilter(logging.Filter): @@ -416,13 +435,14 @@ def create_log_sink(): def _setup_logging(log_format, log_level, log_sink): logging.Formatter.converter = time.gmtime logger = logging.getLogger() - logger_handler = LambdaLoggerHandler(log_sink) + + if log_format == LogFormat.JSON or log_level: + logger_handler = LambdaLoggerHandlerWithFrameType(log_sink) + else: + logger_handler = LambdaLoggerHandler(log_sink) + if log_format == LogFormat.JSON: logger_handler.setFormatter(JsonFormatter()) - - logging.addLevelName(logging.DEBUG, "TRACE") - if log_level in logging._nameToLevel: - logger.setLevel(log_level) else: logger_handler.setFormatter( logging.Formatter( @@ -431,6 +451,9 @@ def _setup_logging(log_format, log_level, log_sink): ) ) + if log_level in logging._nameToLevel: + logger.setLevel(log_level) + logger_handler.addFilter(LambdaLoggerFilter()) logger.addHandler(logger_handler) diff --git a/awslambdaric/lambda_runtime_log_utils.py b/awslambdaric/lambda_runtime_log_utils.py index f140253..7ed9940 100644 --- a/awslambdaric/lambda_runtime_log_utils.py +++ b/awslambdaric/lambda_runtime_log_utils.py @@ -45,6 +45,10 @@ def from_str(cls, value: str): return cls.TEXT.value +def _get_log_level_from_env_var(log_level): + return {None: "", "TRACE": "DEBUG"}.get(log_level, log_level).upper() + + _JSON_FRAME_TYPES = { logging.NOTSET: 0xA55A0002.to_bytes(4, "big"), logging.DEBUG: 0xA55A000A.to_bytes(4, "big"), @@ -53,12 +57,24 @@ def from_str(cls, value: str): logging.ERROR: 0xA55A0016.to_bytes(4, "big"), logging.CRITICAL: 0xA55A001A.to_bytes(4, "big"), } -_DEFAULT_FRAME_TYPE = 0xA55A0003.to_bytes(4, "big") +_TEXT_FRAME_TYPES = { + logging.NOTSET: 0xA55A0003.to_bytes(4, "big"), + logging.DEBUG: 0xA55A000B.to_bytes(4, "big"), + logging.INFO: 0xA55A000F.to_bytes(4, "big"), + logging.WARNING: 0xA55A0013.to_bytes(4, "big"), + logging.ERROR: 0xA55A0017.to_bytes(4, "big"), + logging.CRITICAL: 0xA55A001B.to_bytes(4, "big"), +} +_DEFAULT_FRAME_TYPE = _TEXT_FRAME_TYPES[logging.NOTSET] _json_encoder = json.JSONEncoder(ensure_ascii=False) _encode_json = _json_encoder.encode +def _format_log_level(record: logging.LogRecord) -> int: + return min(50, max(0, record.levelno)) // 10 * 10 + + class JsonFormatter(logging.Formatter): def __init__(self): super().__init__(datefmt=_DATETIME_FORMAT) @@ -90,13 +106,9 @@ def __format_location(record: logging.LogRecord): return f"{record.pathname}:{record.funcName}:{record.lineno}" - @staticmethod - def __format_log_level(record: logging.LogRecord): - record.levelno = min(50, max(0, record.levelno)) // 10 * 10 - record.levelname = logging.getLevelName(record.levelno) - def format(self, record: logging.LogRecord) -> str: - self.__format_log_level(record) + record.levelno = _format_log_level(record) + record.levelname = logging.getLevelName(record.levelno) record._frame_type = _JSON_FRAME_TYPES.get( record.levelno, _JSON_FRAME_TYPES[logging.NOTSET] ) diff --git a/tests/test_bootstrap.py b/tests/test_bootstrap.py index 5614a2e..ca367fd 100644 --- a/tests/test_bootstrap.py +++ b/tests/test_bootstrap.py @@ -18,7 +18,7 @@ import awslambdaric.bootstrap as bootstrap from awslambdaric.lambda_runtime_exception import FaultException -from awslambdaric.lambda_runtime_log_utils import LogFormat +from awslambdaric.lambda_runtime_log_utils import LogFormat, _get_log_level_from_env_var from awslambdaric.lambda_runtime_marshaller import LambdaMarshaller @@ -927,7 +927,7 @@ def test_log_error_framed_log_sink(self): content = f.read() frame_type = int.from_bytes(content[:4], "big") - self.assertEqual(frame_type, 0xA55A0003) + self.assertEqual(frame_type, 0xA55A0017) length = int.from_bytes(content[4:8], "big") self.assertEqual(length, len(expected_logged_error.encode("utf8"))) @@ -973,7 +973,7 @@ def test_log_error_indentation_framed_log_sink(self): content = f.read() frame_type = int.from_bytes(content[:4], "big") - self.assertEqual(frame_type, 0xA55A0003) + self.assertEqual(frame_type, 0xA55A0017) length = int.from_bytes(content[4:8], "big") self.assertEqual(length, len(expected_logged_error.encode("utf8"))) @@ -1016,7 +1016,7 @@ def test_log_error_empty_stacktrace_line_framed_log_sink(self): content = f.read() frame_type = int.from_bytes(content[:4], "big") - self.assertEqual(frame_type, 0xA55A0003) + self.assertEqual(frame_type, 0xA55A0017) length = int.from_bytes(content[4:8], "big") self.assertEqual(length, len(expected_logged_error)) @@ -1053,7 +1053,7 @@ def test_log_error_invokeId_line_framed_log_sink(self): content = f.read() frame_type = int.from_bytes(content[:4], "big") - self.assertEqual(frame_type, 0xA55A0003) + self.assertEqual(frame_type, 0xA55A0017) length = int.from_bytes(content[4:8], "big") self.assertEqual(length, len(expected_logged_error)) @@ -1179,14 +1179,13 @@ def test_log_level(self) -> None: (LogFormat.JSON, "WARN", logging.WARNING), (LogFormat.JSON, "ERROR", logging.ERROR), (LogFormat.JSON, "FATAL", logging.CRITICAL), - # Log level is set only for Json format - (LogFormat.TEXT, "TRACE", logging.NOTSET), - (LogFormat.TEXT, "DEBUG", logging.NOTSET), - (LogFormat.TEXT, "INFO", logging.NOTSET), - (LogFormat.TEXT, "WARN", logging.NOTSET), - (LogFormat.TEXT, "ERROR", logging.NOTSET), - (LogFormat.TEXT, "FATAL", logging.NOTSET), - ("Unknown format", "INFO", logging.NOTSET), + (LogFormat.TEXT, "TRACE", logging.DEBUG), + (LogFormat.TEXT, "DEBUG", logging.DEBUG), + (LogFormat.TEXT, "INFO", logging.INFO), + (LogFormat.TEXT, "WARN", logging.WARN), + (LogFormat.TEXT, "ERROR", logging.ERROR), + (LogFormat.TEXT, "FATAL", logging.CRITICAL), + ("Unknown format", "INFO", logging.INFO), # if level is unknown fall back to default (LogFormat.JSON, "Unknown level", logging.NOTSET), ] @@ -1196,11 +1195,67 @@ def test_log_level(self) -> None: logging.getLogger().handlers.clear() logging.getLogger().level = logging.NOTSET - bootstrap._setup_logging(fmt, log_level, bootstrap.StandardLogSink()) + bootstrap._setup_logging( + fmt, + _get_log_level_from_env_var(log_level), + bootstrap.StandardLogSink(), + ) self.assertEqual(expected_level, logging.getLogger().level) +class TestLambdaLoggerHandlerSetup(unittest.TestCase): + @classmethod + def tearDownClass(cls): + importlib.reload(bootstrap) + logging.getLogger().handlers.clear() + logging.getLogger().level = logging.NOTSET + + def test_handler_setup(self, *_): + test_cases = [ + (62, 0xA55A0003, 46, {}), + (133, 0xA55A001A, 117, {"AWS_LAMBDA_LOG_FORMAT": "JSON"}), + (62, 0xA55A001B, 46, {"AWS_LAMBDA_LOG_LEVEL": "INFO"}), + ] + + for total_length, header, message_length, env_vars in test_cases: + with patch.dict( + os.environ, env_vars, clear=True + ), NamedTemporaryFile() as temp_file: + importlib.reload(bootstrap) + logging.getLogger().handlers.clear() + logging.getLogger().level = logging.NOTSET + + before = int(time.time_ns() / 1000) + with bootstrap.FramedTelemetryLogSink( + os.open(temp_file.name, os.O_CREAT | os.O_RDWR) + ) as ls: + bootstrap._setup_logging( + bootstrap._AWS_LAMBDA_LOG_FORMAT, + bootstrap._AWS_LAMBDA_LOG_LEVEL, + ls, + ) + logger = logging.getLogger() + logger.critical("critical") + after = int(time.time_ns() / 1000) + + content = open(temp_file.name, "rb").read() + self.assertEqual(len(content), total_length) + + pos = 0 + frame_type = int.from_bytes(content[pos : pos + 4], "big") + self.assertEqual(frame_type, header) + pos += 4 + + length = int.from_bytes(content[pos : pos + 4], "big") + self.assertEqual(length, message_length) + pos += 4 + + timestamp = int.from_bytes(content[pos : pos + 8], "big") + self.assertTrue(before <= timestamp) + self.assertTrue(timestamp <= after) + + class TestLogging(unittest.TestCase): @classmethod def setUpClass(cls) -> None: