Rewrite type annotations #119

Merged (2 commits) on Jun 17, 2024
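This PR replaces the comment-style type hints with inline annotations and drops the import shims (the typing fallbacks, the collections_abc alias, and the lru_cache backport) that only older Python versions needed. A minimal sketch of the pattern being applied, using a simplified standalone function rather than code copied from the diff:

import logging
from typing import Optional


# Old style: the hint lives in a "# type:" comment. Type checkers read it,
# but the interpreter ignores it and it is invisible at runtime.
def log_level_old(record):
    # type: (logging.LogRecord) -> Optional[str]
    return record.levelname.lower() if record.levelname else None


# New style: the same information as inline annotations, introspectable at
# runtime via log_level_new.__annotations__ or typing.get_type_hints().
def log_level_new(record: logging.LogRecord) -> Optional[str]:
    return record.levelname.lower() if record.levelname else None
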
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -33,7 +33,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.7'
python-version: '3.10'
- run: .ci/scripts/lint.sh

test:
69 changes: 30 additions & 39 deletions ecs_logging/_stdlib.py
@@ -15,29 +15,27 @@
# specific language governing permissions and limitations
# under the License.

import collections.abc
import logging
import sys
import time
from functools import lru_cache
from traceback import format_tb

from ._meta import ECS_VERSION
from ._utils import (
TYPE_CHECKING,
collections_abc,
de_dot,
flatten_dict,
json_dumps,
lru_cache,
merge_dicts,
)

if TYPE_CHECKING:
from typing import Any, Callable, Dict, Optional, Sequence
from typing import Any, Callable, Dict, Optional, Sequence, Union

try:
from typing import Literal, Union # type: ignore
except ImportError:
from typing_extensions import Literal, Union # type: ignore
try:
from typing import Literal # type: ignore
except ImportError:
from typing_extensions import Literal # type: ignore


# Load the attributes of a LogRecord so if some are
@@ -78,16 +76,15 @@ class StdlibFormatter(logging.Formatter):
converter = time.gmtime

def __init__(
self, # type: Any
fmt=None, # type: Optional[str]
datefmt=None, # type: Optional[str]
style="%", # type: Union[Literal["%"], Literal["{"], Literal["$"]]
validate=None, # type: Optional[bool]
stack_trace_limit=None, # type: Optional[int]
extra=None, # type: Optional[Dict[str, Any]]
exclude_fields=(), # type: Sequence[str]
):
# type: (...) -> None
self,
fmt: Optional[str] = None,
datefmt: Optional[str] = None,
style: Union[Literal["%"], Literal["{"], Literal["$"]] = "%",
validate: Optional[bool] = None,
stack_trace_limit: Optional[int] = None,
extra: Optional[Dict[str, Any]] = None,
exclude_fields: Sequence[str] = (),
) -> None:
"""Initialize the ECS formatter.

:param int stack_trace_limit:
Expand Down Expand Up @@ -127,7 +124,7 @@ def __init__(
)

if (
not isinstance(exclude_fields, collections_abc.Sequence)
not isinstance(exclude_fields, collections.abc.Sequence)
or isinstance(exclude_fields, str)
or any(not isinstance(item, str) for item in exclude_fields)
):
@@ -137,8 +134,7 @@ def __init__(
self._exclude_fields = frozenset(exclude_fields)
self._stack_trace_limit = stack_trace_limit

def _record_error_type(self, record):
# type: (logging.LogRecord) -> Optional[str]
def _record_error_type(self, record: logging.LogRecord) -> Optional[str]:
exc_info = record.exc_info
if not exc_info:
# exc_info is either an iterable or bool. If it doesn't
@@ -151,8 +147,7 @@ def _record_error_message(self, record):
return exc_info[0].__name__
return None

def _record_error_message(self, record):
# type: (logging.LogRecord) -> Optional[str]
def _record_error_message(self, record: logging.LogRecord) -> Optional[str]:
exc_info = record.exc_info
if not exc_info:
# exc_info is either an iterable or bool. If it doesn't
@@ -165,13 +160,11 @@ def _record_error_message(self, record):
return str(exc_info[1])
return None

def format(self, record):
# type: (logging.LogRecord) -> str
def format(self, record: logging.LogRecord) -> str:
result = self.format_to_ecs(record)
return json_dumps(result)

def format_to_ecs(self, record):
# type: (logging.LogRecord) -> Dict[str, Any]
def format_to_ecs(self, record: logging.LogRecord) -> Dict[str, Any]:
"""Function that can be overridden to add additional fields to
(or remove fields from) the JSON before being dumped into a string.

@@ -185,7 +178,7 @@ def format_to_ecs(self, record):
return result
"""

extractors = {
extractors: Dict[str, Callable[[logging.LogRecord], Any]] = {
"@timestamp": self._record_timestamp,
"ecs.version": lambda _: ECS_VERSION,
"log.level": lambda r: (r.levelname.lower() if r.levelname else None),
@@ -201,9 +194,9 @@ def format_to_ecs(self, record):
"error.type": self._record_error_type,
"error.message": self._record_error_message,
"error.stack_trace": self._record_error_stack_trace,
} # type: Dict[str, Callable[[logging.LogRecord],Any]]
}

result = {} # type: Dict[str, Any]
result: Dict[str, Any] = {}
for field in set(extractors.keys()).difference(self._exclude_fields):
if self._is_field_excluded(field):
continue
@@ -262,28 +255,26 @@ def format_to_ecs(self, record):
return result

@lru_cache()
def _is_field_excluded(self, field):
# type: (str) -> bool
def _is_field_excluded(self, field: str) -> bool:
field_path = []
for path in field.split("."):
field_path.append(path)
if ".".join(field_path) in self._exclude_fields:
return True
return False

def _record_timestamp(self, record):
# type: (logging.LogRecord) -> str
def _record_timestamp(self, record: logging.LogRecord) -> str:
return "%s.%03dZ" % (
self.formatTime(record, datefmt="%Y-%m-%dT%H:%M:%S"),
record.msecs,
)

def _record_attribute(self, attribute):
# type: (str) -> Callable[[logging.LogRecord], Optional[Any]]
def _record_attribute(
self, attribute: str
) -> Callable[[logging.LogRecord], Optional[Any]]:
return lambda r: getattr(r, attribute, None)

def _record_error_stack_trace(self, record):
# type: (logging.LogRecord) -> Optional[str]
def _record_error_stack_trace(self, record: logging.LogRecord) -> Optional[str]:
# Using stack_info=True will add 'error.stack_trace' even
# if the type is not 'error', exc_info=True only gathers
# when there's an active exception.
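With inline annotations on the constructor above, the parameters are visible to type checkers and IDEs without reading type comments. A short usage sketch; the package-level StdlibFormatter import and the example field values are illustrative, not part of this diff:

import logging
import sys

from ecs_logging import StdlibFormatter

# The keyword arguments below correspond to the inline-annotated parameters
# of StdlibFormatter.__init__ shown in this file.
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(
    StdlibFormatter(
        stack_trace_limit=3,
        exclude_fields=["process"],
        extra={"service.name": "example-service"},
    )
)

logger = logging.getLogger("app")
logger.addHandler(handler)
logger.setLevel(logging.INFO)
logger.info("hello world")  # emitted as a single ECS-formatted JSON line
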
16 changes: 6 additions & 10 deletions ecs_logging/_structlog.py
@@ -17,18 +17,16 @@

import time
import datetime
from ._meta import ECS_VERSION
from ._utils import json_dumps, normalize_dict, TYPE_CHECKING
from typing import Any, Dict

if TYPE_CHECKING:
from typing import Any, Dict
from ._meta import ECS_VERSION
from ._utils import json_dumps, normalize_dict


class StructlogFormatter:
"""ECS formatter for the ``structlog`` module"""

def __call__(self, _, name, event_dict):
# type: (Any, str, Dict[str, Any]) -> str
def __call__(self, _: Any, name: str, event_dict: Dict[str, Any]) -> str:

# Handle event -> message now so that stuff like `event.dataset` doesn't
# cause problems down the line
@@ -38,8 +36,7 @@ def __call__(self, _, name, event_dict):
event_dict = self.format_to_ecs(event_dict)
return self._json_dumps(event_dict)

def format_to_ecs(self, event_dict):
# type: (Dict[str, Any]) -> Dict[str, Any]
def format_to_ecs(self, event_dict: Dict[str, Any]) -> Dict[str, Any]:
if "@timestamp" not in event_dict:
event_dict["@timestamp"] = (
datetime.datetime.fromtimestamp(
@@ -58,6 +55,5 @@ def format_to_ecs(self, event_dict):
event_dict.setdefault("ecs.version", ECS_VERSION)
return event_dict

def _json_dumps(self, value):
# type: (Dict[str, Any]) -> str
def _json_dumps(self, value: Dict[str, Any]) -> str:
return json_dumps(value=value)
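The formatter is a plain callable, so the annotated signature can be exercised directly. A minimal sketch, with made-up event values and assuming the usual package-level export:

from ecs_logging import StructlogFormatter

formatter = StructlogFormatter()

# structlog passes (wrapped_logger, method_name, event_dict); calling the
# formatter by hand shows "event" renamed to "message" and "@timestamp" /
# "ecs.version" filled in by format_to_ecs().
line = formatter(None, "info", {"event": "user logged in", "user.name": "alice"})
print(line)
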
47 changes: 9 additions & 38 deletions ecs_logging/_utils.py
@@ -15,52 +15,28 @@
# specific language governing permissions and limitations
# under the License.

import collections.abc
import json
import functools

try:
import typing

TYPE_CHECKING = typing.TYPE_CHECKING
except ImportError:
typing = None # type: ignore
TYPE_CHECKING = False

if TYPE_CHECKING:
from typing import Any, Dict

try:
import collections.abc as collections_abc
except ImportError:
import collections as collections_abc # type: ignore

try:
from functools import lru_cache
except ImportError:
from backports.functools_lru_cache import lru_cache # type: ignore
from typing import Any, Dict, Mapping


__all__ = [
"collections_abc",
"normalize_dict",
"de_dot",
"merge_dicts",
"json_dumps",
"TYPE_CHECKING",
"typing",
"lru_cache",
]


def flatten_dict(value):
# type: (typing.Mapping[str, Any]) -> Dict[str, Any]
def flatten_dict(value: Mapping[str, Any]) -> Dict[str, Any]:
"""Adds dots to all nested fields in dictionaries.
Raises an error if there are entries which are represented
with different forms of nesting. (ie {"a": {"b": 1}, "a.b": 2})
"""
top_level = {}
for key, val in value.items():
if not isinstance(val, collections_abc.Mapping):
if not isinstance(val, collections.abc.Mapping):
if key in top_level:
raise ValueError(f"Duplicate entry for '{key}' with different nesting")
top_level[key] = val
@@ -77,8 +53,7 @@ def flatten_dict(value):
return top_level


def normalize_dict(value):
# type: (Dict[str, Any]) -> Dict[str, Any]
def normalize_dict(value: Dict[str, Any]) -> Dict[str, Any]:
"""Expands all dotted names to nested dictionaries"""
if not isinstance(value, dict):
return value
@@ -94,8 +69,7 @@ def normalize_dict(value):
return value


def de_dot(dot_string, msg):
# type: (str, Any) -> Dict[str, Any]
def de_dot(dot_string: str, msg: Any) -> Dict[str, Any]:
"""Turn value and dotted string key into a nested dictionary"""
arr = dot_string.split(".")
ret = {arr[-1]: msg}
@@ -104,8 +78,7 @@ def de_dot(dot_string, msg):
return ret


def merge_dicts(from_, into):
# type: (Dict[Any, Any], Dict[Any, Any]) -> Dict[Any, Any]
def merge_dicts(from_: Dict[Any, Any], into: Dict[Any, Any]) -> Dict[Any, Any]:
"""Merge deeply nested dictionary structures.
When called has side-effects within 'destination'.
"""
@@ -125,8 +98,7 @@ def merge_dicts(from_, into):
return into


def json_dumps(value):
# type: (Dict[str, Any]) -> str
def json_dumps(value: Dict[str, Any]) -> str:

# Ensure that the first three fields are '@timestamp',
# 'log.level', and 'message' per ECS spec
@@ -175,8 +147,7 @@ def json_dumps(value):
return json_dumps(value)


def _json_dumps_fallback(value):
# type: (Any) -> Any
def _json_dumps_fallback(value: Any) -> Any:
"""
Fallback handler for json.dumps to handle objects json doesn't know how to
serialize.
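The helper functions in this module keep their behavior; only the annotation style and the imports change. A quick sketch of what they do, importing from the private module shown above with illustrative values:

from ecs_logging._utils import de_dot, flatten_dict, merge_dicts

# de_dot() expands one dotted key into a nested dict.
nested = de_dot("error.stack_trace", "Traceback (most recent call last): ...")
assert nested == {"error": {"stack_trace": "Traceback (most recent call last): ..."}}

# merge_dicts() deep-merges 'from_' into 'into', mutating and returning 'into'.
into = {"error": {"type": "ValueError"}}
merge_dicts(nested, into)
assert into["error"] == {
    "type": "ValueError",
    "stack_trace": "Traceback (most recent call last): ...",
}

# flatten_dict() goes the other way: nested mappings back to dotted keys.
assert flatten_dict({"error": {"type": "ValueError"}}) == {"error.type": "ValueError"}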