Skip to content

Commit b816762

Browse files
fix(core): mypy (#810)
makes mypy happier --------- Co-authored-by: Roy Moore <[email protected]>
1 parent ff6a32d commit b816762

19 files changed

+242
-129
lines changed
Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
11
# flake8: noqa: F401
22
from testcontainers.compose.compose import (
33
ComposeContainer,
4-
ContainerIsNotRunning,
54
DockerCompose,
6-
NoSuchPortExposed,
7-
PublishedPort,
5+
PublishedPortModel,
86
)
7+
from testcontainers.core.exceptions import ContainerIsNotRunning, NoSuchPortExposed
8+
9+
__all__ = [
10+
"ComposeContainer",
11+
"ContainerIsNotRunning",
12+
"DockerCompose",
13+
"NoSuchPortExposed",
14+
"PublishedPortModel",
15+
]

core/testcontainers/compose/compose.py

Lines changed: 41 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from dataclasses import asdict, dataclass, field, fields
1+
from dataclasses import asdict, dataclass, field, fields, is_dataclass
22
from functools import cached_property
33
from json import loads
44
from logging import warning
@@ -7,6 +7,7 @@
77
from re import split
88
from subprocess import CompletedProcess
99
from subprocess import run as subprocess_run
10+
from types import TracebackType
1011
from typing import Any, Callable, Literal, Optional, TypeVar, Union, cast
1112
from urllib.error import HTTPError, URLError
1213
from urllib.request import urlopen
@@ -18,35 +19,37 @@
1819
_WARNINGS = {"DOCKER_COMPOSE_GET_CONFIG": "get_config is experimental, see testcontainers/testcontainers-python#669"}
1920

2021

21-
def _ignore_properties(cls: type[_IPT], dict_: any) -> _IPT:
22+
def _ignore_properties(cls: type[_IPT], dict_: Any) -> _IPT:
2223
"""omits extra fields like @JsonIgnoreProperties(ignoreUnknown = true)
2324
2425
https://gist.github.com/alexanderankin/2a4549ac03554a31bef6eaaf2eaf7fd5"""
2526
if isinstance(dict_, cls):
2627
return dict_
28+
if not is_dataclass(cls):
29+
raise TypeError(f"Expected a dataclass type, got {cls}")
2730
class_fields = {f.name for f in fields(cls)}
2831
filtered = {k: v for k, v in dict_.items() if k in class_fields}
29-
return cls(**filtered)
32+
return cast("_IPT", cls(**filtered))
3033

3134

3235
@dataclass
33-
class PublishedPort:
36+
class PublishedPortModel:
3437
"""
3538
Class that represents the response we get from compose when inquiring status
3639
via `DockerCompose.get_running_containers()`.
3740
"""
3841

3942
URL: Optional[str] = None
40-
TargetPort: Optional[str] = None
41-
PublishedPort: Optional[str] = None
43+
TargetPort: Optional[int] = None
44+
PublishedPort: Optional[int] = None
4245
Protocol: Optional[str] = None
4346

44-
def normalize(self):
47+
def normalize(self) -> "PublishedPortModel":
4548
url_not_usable = system() == "Windows" and self.URL == "0.0.0.0"
4649
if url_not_usable:
4750
self_dict = asdict(self)
4851
self_dict.update({"URL": "127.0.0.1"})
49-
return PublishedPort(**self_dict)
52+
return PublishedPortModel(**self_dict)
5053
return self
5154

5255

@@ -75,19 +78,19 @@ class ComposeContainer:
7578
Service: Optional[str] = None
7679
State: Optional[str] = None
7780
Health: Optional[str] = None
78-
ExitCode: Optional[str] = None
79-
Publishers: list[PublishedPort] = field(default_factory=list)
81+
ExitCode: Optional[int] = None
82+
Publishers: list[PublishedPortModel] = field(default_factory=list)
8083

81-
def __post_init__(self):
84+
def __post_init__(self) -> None:
8285
if self.Publishers:
83-
self.Publishers = [_ignore_properties(PublishedPort, p) for p in self.Publishers]
86+
self.Publishers = [_ignore_properties(PublishedPortModel, p) for p in self.Publishers]
8487

8588
def get_publisher(
8689
self,
8790
by_port: Optional[int] = None,
8891
by_host: Optional[str] = None,
89-
prefer_ip_version: Literal["IPV4", "IPv6"] = "IPv4",
90-
) -> PublishedPort:
92+
prefer_ip_version: Literal["IPv4", "IPv6"] = "IPv4",
93+
) -> PublishedPortModel:
9194
remaining_publishers = self.Publishers
9295

9396
remaining_publishers = [r for r in remaining_publishers if self._matches_protocol(prefer_ip_version, r)]
@@ -109,8 +112,9 @@ def get_publisher(
109112
)
110113

111114
@staticmethod
112-
def _matches_protocol(prefer_ip_version, r):
113-
return (":" in r.URL) is (prefer_ip_version == "IPv6")
115+
def _matches_protocol(prefer_ip_version: str, r: PublishedPortModel) -> bool:
116+
r_url = r.URL
117+
return (r_url is not None and ":" in r_url) is (prefer_ip_version == "IPv6")
114118

115119

116120
@dataclass
@@ -164,7 +168,7 @@ class DockerCompose:
164168
image: "hello-world"
165169
"""
166170

167-
context: Union[str, PathLike]
171+
context: Union[str, PathLike[str]]
168172
compose_file_name: Optional[Union[str, list[str]]] = None
169173
pull: bool = False
170174
build: bool = False
@@ -175,15 +179,17 @@ class DockerCompose:
175179
docker_command_path: Optional[str] = None
176180
profiles: Optional[list[str]] = None
177181

178-
def __post_init__(self):
182+
def __post_init__(self) -> None:
179183
if isinstance(self.compose_file_name, str):
180184
self.compose_file_name = [self.compose_file_name]
181185

182186
def __enter__(self) -> "DockerCompose":
183187
self.start()
184188
return self
185189

186-
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
190+
def __exit__(
191+
self, exc_type: Optional[type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
192+
) -> None:
187193
self.stop(not self.keep_volumes)
188194

189195
def docker_compose_command(self) -> list[str]:
@@ -235,7 +241,7 @@ def start(self) -> None:
235241

236242
self._run_command(cmd=up_cmd)
237243

238-
def stop(self, down=True) -> None:
244+
def stop(self, down: bool = True) -> None:
239245
"""
240246
Stops the docker compose environment.
241247
"""
@@ -295,7 +301,7 @@ def get_config(
295301
cmd_output = self._run_command(cmd=config_cmd).stdout
296302
return cast(dict[str, Any], loads(cmd_output)) # noqa: TC006
297303

298-
def get_containers(self, include_all=False) -> list[ComposeContainer]:
304+
def get_containers(self, include_all: bool = False) -> list[ComposeContainer]:
299305
"""
300306
Fetch information about running containers via `docker compose ps --format json`.
301307
Available only in V2 of compose.
@@ -370,17 +376,18 @@ def exec_in_container(
370376
"""
371377
if not service_name:
372378
service_name = self.get_container().Service
373-
exec_cmd = [*self.compose_command_property, "exec", "-T", service_name, *command]
379+
assert service_name
380+
exec_cmd: list[str] = [*self.compose_command_property, "exec", "-T", service_name, *command]
374381
result = self._run_command(cmd=exec_cmd)
375382

376-
return (result.stdout.decode("utf-8"), result.stderr.decode("utf-8"), result.returncode)
383+
return result.stdout.decode("utf-8"), result.stderr.decode("utf-8"), result.returncode
377384

378385
def _run_command(
379386
self,
380387
cmd: Union[str, list[str]],
381388
context: Optional[str] = None,
382389
) -> CompletedProcess[bytes]:
383-
context = context or self.context
390+
context = context or str(self.context)
384391
return subprocess_run(
385392
cmd,
386393
capture_output=True,
@@ -392,7 +399,7 @@ def get_service_port(
392399
self,
393400
service_name: Optional[str] = None,
394401
port: Optional[int] = None,
395-
):
402+
) -> Optional[int]:
396403
"""
397404
Returns the mapped port for one of the services.
398405
@@ -408,13 +415,14 @@ def get_service_port(
408415
str:
409416
The mapped port on the host
410417
"""
411-
return self.get_container(service_name).get_publisher(by_port=port).normalize().PublishedPort
418+
normalize: PublishedPortModel = self.get_container(service_name).get_publisher(by_port=port).normalize()
419+
return normalize.PublishedPort
412420

413421
def get_service_host(
414422
self,
415423
service_name: Optional[str] = None,
416424
port: Optional[int] = None,
417-
):
425+
) -> Optional[str]:
418426
"""
419427
Returns the host for one of the services.
420428
@@ -430,13 +438,17 @@ def get_service_host(
430438
str:
431439
The hostname for the service
432440
"""
433-
return self.get_container(service_name).get_publisher(by_port=port).normalize().URL
441+
container: ComposeContainer = self.get_container(service_name)
442+
publisher: PublishedPortModel = container.get_publisher(by_port=port)
443+
normalize: PublishedPortModel = publisher.normalize()
444+
url: Optional[str] = normalize.URL
445+
return url
434446

435447
def get_service_host_and_port(
436448
self,
437449
service_name: Optional[str] = None,
438450
port: Optional[int] = None,
439-
):
451+
) -> tuple[Optional[str], Optional[int]]:
440452
publisher = self.get_container(service_name).get_publisher(by_port=port).normalize()
441453
return publisher.URL, publisher.PublishedPort
442454

core/testcontainers/core/config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from os import environ
88
from os.path import exists
99
from pathlib import Path
10-
from typing import Final, Optional, Union
10+
from typing import Final, Optional, Union, cast
1111

1212
import docker
1313

@@ -39,6 +39,7 @@ def get_docker_socket() -> str:
3939
try:
4040
client = docker.from_env()
4141
socket_path = client.api.get_adapter(client.api.base_url).socket_path
42+
socket_path = cast("str", socket_path)
4243
# return the normalized path as string
4344
return str(Path(socket_path).absolute())
4445
except Exception:
@@ -148,6 +149,7 @@ def ryuk_docker_socket(self, value: str) -> None:
148149

149150
__all__ = [
150151
# Public API of this module:
152+
"ConnectionMode",
151153
"testcontainers_config",
152154
]
153155

0 commit comments

Comments
 (0)