diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4715352..8f41b84 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -33,7 +33,7 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: - python-version: '3.7' + python-version: '3.10' - run: .ci/scripts/lint.sh test: diff --git a/ecs_logging/_stdlib.py b/ecs_logging/_stdlib.py index 9eff74d..7839a7d 100644 --- a/ecs_logging/_stdlib.py +++ b/ecs_logging/_stdlib.py @@ -15,29 +15,27 @@ # specific language governing permissions and limitations # under the License. +import collections.abc import logging import sys import time +from functools import lru_cache from traceback import format_tb from ._meta import ECS_VERSION from ._utils import ( - TYPE_CHECKING, - collections_abc, de_dot, flatten_dict, json_dumps, - lru_cache, merge_dicts, ) -if TYPE_CHECKING: - from typing import Any, Callable, Dict, Optional, Sequence +from typing import Any, Callable, Dict, Optional, Sequence, Union - try: - from typing import Literal, Union # type: ignore - except ImportError: - from typing_extensions import Literal, Union # type: ignore +try: + from typing import Literal # type: ignore +except ImportError: + from typing_extensions import Literal # type: ignore # Load the attributes of a LogRecord so if some are @@ -78,16 +76,15 @@ class StdlibFormatter(logging.Formatter): converter = time.gmtime def __init__( - self, # type: Any - fmt=None, # type: Optional[str] - datefmt=None, # type: Optional[str] - style="%", # type: Union[Literal["%"], Literal["{"], Literal["$"]] - validate=None, # type: Optional[bool] - stack_trace_limit=None, # type: Optional[int] - extra=None, # type: Optional[Dict[str, Any]] - exclude_fields=(), # type: Sequence[str] - ): - # type: (...) -> None + self, + fmt: Optional[str] = None, + datefmt: Optional[str] = None, + style: Union[Literal["%"], Literal["{"], Literal["$"]] = "%", + validate: Optional[bool] = None, + stack_trace_limit: Optional[int] = None, + extra: Optional[Dict[str, Any]] = None, + exclude_fields: Sequence[str] = (), + ) -> None: """Initialize the ECS formatter. :param int stack_trace_limit: @@ -127,7 +124,7 @@ def __init__( ) if ( - not isinstance(exclude_fields, collections_abc.Sequence) + not isinstance(exclude_fields, collections.abc.Sequence) or isinstance(exclude_fields, str) or any(not isinstance(item, str) for item in exclude_fields) ): @@ -137,8 +134,7 @@ def __init__( self._exclude_fields = frozenset(exclude_fields) self._stack_trace_limit = stack_trace_limit - def _record_error_type(self, record): - # type: (logging.LogRecord) -> Optional[str] + def _record_error_type(self, record: logging.LogRecord) -> Optional[str]: exc_info = record.exc_info if not exc_info: # exc_info is either an iterable or bool. If it doesn't @@ -151,8 +147,7 @@ def _record_error_type(self, record): return exc_info[0].__name__ return None - def _record_error_message(self, record): - # type: (logging.LogRecord) -> Optional[str] + def _record_error_message(self, record: logging.LogRecord) -> Optional[str]: exc_info = record.exc_info if not exc_info: # exc_info is either an iterable or bool. If it doesn't @@ -165,13 +160,11 @@ def _record_error_message(self, record): return str(exc_info[1]) return None - def format(self, record): - # type: (logging.LogRecord) -> str + def format(self, record: logging.LogRecord) -> str: result = self.format_to_ecs(record) return json_dumps(result) - def format_to_ecs(self, record): - # type: (logging.LogRecord) -> Dict[str, Any] + def format_to_ecs(self, record: logging.LogRecord) -> Dict[str, Any]: """Function that can be overridden to add additional fields to (or remove fields from) the JSON before being dumped into a string. @@ -185,7 +178,7 @@ def format_to_ecs(self, record): return result """ - extractors = { + extractors: Dict[str, Callable[[logging.LogRecord], Any]] = { "@timestamp": self._record_timestamp, "ecs.version": lambda _: ECS_VERSION, "log.level": lambda r: (r.levelname.lower() if r.levelname else None), @@ -201,9 +194,9 @@ def format_to_ecs(self, record): "error.type": self._record_error_type, "error.message": self._record_error_message, "error.stack_trace": self._record_error_stack_trace, - } # type: Dict[str, Callable[[logging.LogRecord],Any]] + } - result = {} # type: Dict[str, Any] + result: Dict[str, Any] = {} for field in set(extractors.keys()).difference(self._exclude_fields): if self._is_field_excluded(field): continue @@ -262,8 +255,7 @@ def format_to_ecs(self, record): return result @lru_cache() - def _is_field_excluded(self, field): - # type: (str) -> bool + def _is_field_excluded(self, field: str) -> bool: field_path = [] for path in field.split("."): field_path.append(path) @@ -271,19 +263,18 @@ def _is_field_excluded(self, field): return True return False - def _record_timestamp(self, record): - # type: (logging.LogRecord) -> str + def _record_timestamp(self, record: logging.LogRecord) -> str: return "%s.%03dZ" % ( self.formatTime(record, datefmt="%Y-%m-%dT%H:%M:%S"), record.msecs, ) - def _record_attribute(self, attribute): - # type: (str) -> Callable[[logging.LogRecord], Optional[Any]] + def _record_attribute( + self, attribute: str + ) -> Callable[[logging.LogRecord], Optional[Any]]: return lambda r: getattr(r, attribute, None) - def _record_error_stack_trace(self, record): - # type: (logging.LogRecord) -> Optional[str] + def _record_error_stack_trace(self, record: logging.LogRecord) -> Optional[str]: # Using stack_info=True will add 'error.stack_trace' even # if the type is not 'error', exc_info=True only gathers # when there's an active exception. diff --git a/ecs_logging/_structlog.py b/ecs_logging/_structlog.py index 84877d7..5bc65e5 100644 --- a/ecs_logging/_structlog.py +++ b/ecs_logging/_structlog.py @@ -17,18 +17,16 @@ import time import datetime -from ._meta import ECS_VERSION -from ._utils import json_dumps, normalize_dict, TYPE_CHECKING +from typing import Any, Dict -if TYPE_CHECKING: - from typing import Any, Dict +from ._meta import ECS_VERSION +from ._utils import json_dumps, normalize_dict class StructlogFormatter: """ECS formatter for the ``structlog`` module""" - def __call__(self, _, name, event_dict): - # type: (Any, str, Dict[str, Any]) -> str + def __call__(self, _: Any, name: str, event_dict: Dict[str, Any]) -> str: # Handle event -> message now so that stuff like `event.dataset` doesn't # cause problems down the line @@ -38,8 +36,7 @@ def __call__(self, _, name, event_dict): event_dict = self.format_to_ecs(event_dict) return self._json_dumps(event_dict) - def format_to_ecs(self, event_dict): - # type: (Dict[str, Any]) -> Dict[str, Any] + def format_to_ecs(self, event_dict: Dict[str, Any]) -> Dict[str, Any]: if "@timestamp" not in event_dict: event_dict["@timestamp"] = ( datetime.datetime.fromtimestamp( @@ -58,6 +55,5 @@ def format_to_ecs(self, event_dict): event_dict.setdefault("ecs.version", ECS_VERSION) return event_dict - def _json_dumps(self, value): - # type: (Dict[str, Any]) -> str + def _json_dumps(self, value: Dict[str, Any]) -> str: return json_dumps(value=value) diff --git a/ecs_logging/_utils.py b/ecs_logging/_utils.py index be55e87..ee5dc6b 100644 --- a/ecs_logging/_utils.py +++ b/ecs_logging/_utils.py @@ -15,52 +15,28 @@ # specific language governing permissions and limitations # under the License. +import collections.abc import json import functools - -try: - import typing - - TYPE_CHECKING = typing.TYPE_CHECKING -except ImportError: - typing = None # type: ignore - TYPE_CHECKING = False - -if TYPE_CHECKING: - from typing import Any, Dict - -try: - import collections.abc as collections_abc -except ImportError: - import collections as collections_abc # type: ignore - -try: - from functools import lru_cache -except ImportError: - from backports.functools_lru_cache import lru_cache # type: ignore +from typing import Any, Dict, Mapping __all__ = [ - "collections_abc", "normalize_dict", "de_dot", "merge_dicts", "json_dumps", - "TYPE_CHECKING", - "typing", - "lru_cache", ] -def flatten_dict(value): - # type: (typing.Mapping[str, Any]) -> Dict[str, Any] +def flatten_dict(value: Mapping[str, Any]) -> Dict[str, Any]: """Adds dots to all nested fields in dictionaries. Raises an error if there are entries which are represented with different forms of nesting. (ie {"a": {"b": 1}, "a.b": 2}) """ top_level = {} for key, val in value.items(): - if not isinstance(val, collections_abc.Mapping): + if not isinstance(val, collections.abc.Mapping): if key in top_level: raise ValueError(f"Duplicate entry for '{key}' with different nesting") top_level[key] = val @@ -77,8 +53,7 @@ def flatten_dict(value): return top_level -def normalize_dict(value): - # type: (Dict[str, Any]) -> Dict[str, Any] +def normalize_dict(value: Dict[str, Any]) -> Dict[str, Any]: """Expands all dotted names to nested dictionaries""" if not isinstance(value, dict): return value @@ -94,8 +69,7 @@ def normalize_dict(value): return value -def de_dot(dot_string, msg): - # type: (str, Any) -> Dict[str, Any] +def de_dot(dot_string: str, msg: Any) -> Dict[str, Any]: """Turn value and dotted string key into a nested dictionary""" arr = dot_string.split(".") ret = {arr[-1]: msg} @@ -104,8 +78,7 @@ def de_dot(dot_string, msg): return ret -def merge_dicts(from_, into): - # type: (Dict[Any, Any], Dict[Any, Any]) -> Dict[Any, Any] +def merge_dicts(from_: Dict[Any, Any], into: Dict[Any, Any]) -> Dict[Any, Any]: """Merge deeply nested dictionary structures. When called has side-effects within 'destination'. """ @@ -125,8 +98,7 @@ def merge_dicts(from_, into): return into -def json_dumps(value): - # type: (Dict[str, Any]) -> str +def json_dumps(value: Dict[str, Any]) -> str: # Ensure that the first three fields are '@timestamp', # 'log.level', and 'message' per ECS spec @@ -175,8 +147,7 @@ def json_dumps(value): return json_dumps(value) -def _json_dumps_fallback(value): - # type: (Any) -> Any +def _json_dumps_fallback(value: Any) -> Any: """ Fallback handler for json.dumps to handle objects json doesn't know how to serialize.