1
- from dataclasses import asdict , dataclass , field , fields
1
+ from dataclasses import asdict , dataclass , field , fields , is_dataclass
2
2
from functools import cached_property
3
3
from json import loads
4
4
from logging import warning
7
7
from re import split
8
8
from subprocess import CompletedProcess
9
9
from subprocess import run as subprocess_run
10
+ from types import TracebackType
10
11
from typing import Any , Callable , Literal , Optional , TypeVar , Union , cast
11
12
from urllib .error import HTTPError , URLError
12
13
from urllib .request import urlopen
18
19
_WARNINGS = {"DOCKER_COMPOSE_GET_CONFIG" : "get_config is experimental, see testcontainers/testcontainers-python#669" }
19
20
20
21
21
- def _ignore_properties (cls : type [_IPT ], dict_ : any ) -> _IPT :
22
+ def _ignore_properties (cls : type [_IPT ], dict_ : Any ) -> _IPT :
22
23
"""omits extra fields like @JsonIgnoreProperties(ignoreUnknown = true)
23
24
24
25
https://gist.github.com/alexanderankin/2a4549ac03554a31bef6eaaf2eaf7fd5"""
25
26
if isinstance (dict_ , cls ):
26
27
return dict_
28
+ if not is_dataclass (cls ):
29
+ raise TypeError (f"Expected a dataclass type, got { cls } " )
27
30
class_fields = {f .name for f in fields (cls )}
28
31
filtered = {k : v for k , v in dict_ .items () if k in class_fields }
29
- return cls (** filtered )
32
+ return cast ( "_IPT" , cls (** filtered ) )
30
33
31
34
32
35
@dataclass
33
- class PublishedPort :
36
+ class PublishedPortModel :
34
37
"""
35
38
Class that represents the response we get from compose when inquiring status
36
39
via `DockerCompose.get_running_containers()`.
37
40
"""
38
41
39
42
URL : Optional [str ] = None
40
- TargetPort : Optional [str ] = None
41
- PublishedPort : Optional [str ] = None
43
+ TargetPort : Optional [int ] = None
44
+ PublishedPort : Optional [int ] = None
42
45
Protocol : Optional [str ] = None
43
46
44
- def normalize (self ):
47
+ def normalize (self ) -> "PublishedPortModel" :
45
48
url_not_usable = system () == "Windows" and self .URL == "0.0.0.0"
46
49
if url_not_usable :
47
50
self_dict = asdict (self )
48
51
self_dict .update ({"URL" : "127.0.0.1" })
49
- return PublishedPort (** self_dict )
52
+ return PublishedPortModel (** self_dict )
50
53
return self
51
54
52
55
@@ -75,19 +78,19 @@ class ComposeContainer:
75
78
Service : Optional [str ] = None
76
79
State : Optional [str ] = None
77
80
Health : Optional [str ] = None
78
- ExitCode : Optional [str ] = None
79
- Publishers : list [PublishedPort ] = field (default_factory = list )
81
+ ExitCode : Optional [int ] = None
82
+ Publishers : list [PublishedPortModel ] = field (default_factory = list )
80
83
81
- def __post_init__ (self ):
84
+ def __post_init__ (self ) -> None :
82
85
if self .Publishers :
83
- self .Publishers = [_ignore_properties (PublishedPort , p ) for p in self .Publishers ]
86
+ self .Publishers = [_ignore_properties (PublishedPortModel , p ) for p in self .Publishers ]
84
87
85
88
def get_publisher (
86
89
self ,
87
90
by_port : Optional [int ] = None ,
88
91
by_host : Optional [str ] = None ,
89
- prefer_ip_version : Literal ["IPV4 " , "IPv6" ] = "IPv4" ,
90
- ) -> PublishedPort :
92
+ prefer_ip_version : Literal ["IPv4 " , "IPv6" ] = "IPv4" ,
93
+ ) -> PublishedPortModel :
91
94
remaining_publishers = self .Publishers
92
95
93
96
remaining_publishers = [r for r in remaining_publishers if self ._matches_protocol (prefer_ip_version , r )]
@@ -109,8 +112,9 @@ def get_publisher(
109
112
)
110
113
111
114
@staticmethod
112
- def _matches_protocol (prefer_ip_version , r ):
113
- return (":" in r .URL ) is (prefer_ip_version == "IPv6" )
115
+ def _matches_protocol (prefer_ip_version : str , r : PublishedPortModel ) -> bool :
116
+ r_url = r .URL
117
+ return (r_url is not None and ":" in r_url ) is (prefer_ip_version == "IPv6" )
114
118
115
119
116
120
@dataclass
@@ -164,7 +168,7 @@ class DockerCompose:
164
168
image: "hello-world"
165
169
"""
166
170
167
- context : Union [str , PathLike ]
171
+ context : Union [str , PathLike [ str ] ]
168
172
compose_file_name : Optional [Union [str , list [str ]]] = None
169
173
pull : bool = False
170
174
build : bool = False
@@ -175,15 +179,17 @@ class DockerCompose:
175
179
docker_command_path : Optional [str ] = None
176
180
profiles : Optional [list [str ]] = None
177
181
178
- def __post_init__ (self ):
182
+ def __post_init__ (self ) -> None :
179
183
if isinstance (self .compose_file_name , str ):
180
184
self .compose_file_name = [self .compose_file_name ]
181
185
182
186
def __enter__ (self ) -> "DockerCompose" :
183
187
self .start ()
184
188
return self
185
189
186
- def __exit__ (self , exc_type , exc_val , exc_tb ) -> None :
190
+ def __exit__ (
191
+ self , exc_type : Optional [type [BaseException ]], exc_val : Optional [BaseException ], exc_tb : Optional [TracebackType ]
192
+ ) -> None :
187
193
self .stop (not self .keep_volumes )
188
194
189
195
def docker_compose_command (self ) -> list [str ]:
@@ -235,7 +241,7 @@ def start(self) -> None:
235
241
236
242
self ._run_command (cmd = up_cmd )
237
243
238
- def stop (self , down = True ) -> None :
244
+ def stop (self , down : bool = True ) -> None :
239
245
"""
240
246
Stops the docker compose environment.
241
247
"""
@@ -295,7 +301,7 @@ def get_config(
295
301
cmd_output = self ._run_command (cmd = config_cmd ).stdout
296
302
return cast (dict [str , Any ], loads (cmd_output )) # noqa: TC006
297
303
298
- def get_containers (self , include_all = False ) -> list [ComposeContainer ]:
304
+ def get_containers (self , include_all : bool = False ) -> list [ComposeContainer ]:
299
305
"""
300
306
Fetch information about running containers via `docker compose ps --format json`.
301
307
Available only in V2 of compose.
@@ -370,17 +376,18 @@ def exec_in_container(
370
376
"""
371
377
if not service_name :
372
378
service_name = self .get_container ().Service
373
- exec_cmd = [* self .compose_command_property , "exec" , "-T" , service_name , * command ]
379
+ assert service_name
380
+ exec_cmd : list [str ] = [* self .compose_command_property , "exec" , "-T" , service_name , * command ]
374
381
result = self ._run_command (cmd = exec_cmd )
375
382
376
- return ( result .stdout .decode ("utf-8" ), result .stderr .decode ("utf-8" ), result .returncode )
383
+ return result .stdout .decode ("utf-8" ), result .stderr .decode ("utf-8" ), result .returncode
377
384
378
385
def _run_command (
379
386
self ,
380
387
cmd : Union [str , list [str ]],
381
388
context : Optional [str ] = None ,
382
389
) -> CompletedProcess [bytes ]:
383
- context = context or self .context
390
+ context = context or str ( self .context )
384
391
return subprocess_run (
385
392
cmd ,
386
393
capture_output = True ,
@@ -392,7 +399,7 @@ def get_service_port(
392
399
self ,
393
400
service_name : Optional [str ] = None ,
394
401
port : Optional [int ] = None ,
395
- ):
402
+ ) -> Optional [ int ] :
396
403
"""
397
404
Returns the mapped port for one of the services.
398
405
@@ -408,13 +415,14 @@ def get_service_port(
408
415
str:
409
416
The mapped port on the host
410
417
"""
411
- return self .get_container (service_name ).get_publisher (by_port = port ).normalize ().PublishedPort
418
+ normalize : PublishedPortModel = self .get_container (service_name ).get_publisher (by_port = port ).normalize ()
419
+ return normalize .PublishedPort
412
420
413
421
def get_service_host (
414
422
self ,
415
423
service_name : Optional [str ] = None ,
416
424
port : Optional [int ] = None ,
417
- ):
425
+ ) -> Optional [ str ] :
418
426
"""
419
427
Returns the host for one of the services.
420
428
@@ -430,13 +438,17 @@ def get_service_host(
430
438
str:
431
439
The hostname for the service
432
440
"""
433
- return self .get_container (service_name ).get_publisher (by_port = port ).normalize ().URL
441
+ container : ComposeContainer = self .get_container (service_name )
442
+ publisher : PublishedPortModel = container .get_publisher (by_port = port )
443
+ normalize : PublishedPortModel = publisher .normalize ()
444
+ url : Optional [str ] = normalize .URL
445
+ return url
434
446
435
447
def get_service_host_and_port (
436
448
self ,
437
449
service_name : Optional [str ] = None ,
438
450
port : Optional [int ] = None ,
439
- ):
451
+ ) -> tuple [ Optional [ str ], Optional [ int ]] :
440
452
publisher = self .get_container (service_name ).get_publisher (by_port = port ).normalize ()
441
453
return publisher .URL , publisher .PublishedPort
442
454
0 commit comments