16
16
from dataclasses import dataclass
17
17
from functools import partial , reduce
18
18
from logging import Handler
19
- from multiprocessing import Process
19
+ from multiprocessing import Event , Process
20
20
from operator import add
21
21
from pathlib import Path
22
22
from typing import Any , Callable , Literal
34
34
ExceptionFromWorker ,
35
35
WorkerFailedError ,
36
36
)
37
- from .utils .logging import LogRecordSocketReceiver , default_handlers
37
+ from .utils .logging import LoggingServerArgs , start_logging_server
38
38
39
39
40
40
@dataclass
@@ -76,27 +76,32 @@ def run( # noqa: C901, PLR0912
76
76
77
77
launcher_hostname = socket .getfqdn ()
78
78
launcher_port = get_open_port ()
79
+ logging_port = get_open_port ()
79
80
world_size = len (hostnames ) + 1
80
81
81
- log_receiver = None
82
+ stop_logging_event = None
82
83
log_process = None
83
84
launcher_agent_group = None
84
85
agent_payloads = None
85
86
86
87
try :
87
88
# Start logging server (receives LogRecords from agents/workers)
88
89
89
- log_receiver = _build_logging_server (
90
+ logging_server_args = LoggingServerArgs (
90
91
log_handlers = log_handlers ,
91
- launcher_hostname = launcher_hostname ,
92
+ logging_hostname = launcher_hostname ,
93
+ logging_port = logging_port ,
92
94
hostnames = hostnames ,
93
95
workers_per_host = workers_per_host ,
94
96
log_dir = Path (os .environ .get ("TORCHRUNX_LOG_DIR" , "torchrunx_logs" )),
95
97
log_level = logging ._nameToLevel [os .environ .get ("TORCHRUNX_LOG_LEVEL" , "INFO" )], # noqa: SLF001
96
98
)
97
99
100
+ stop_logging_event = Event ()
101
+
98
102
log_process = Process (
99
- target = log_receiver .serve_forever ,
103
+ target = start_logging_server ,
104
+ args = (logging_server_args .serialize (), stop_logging_event ),
100
105
daemon = True ,
101
106
)
102
107
@@ -109,7 +114,7 @@ def run( # noqa: C901, PLR0912
109
114
command = _build_launch_command (
110
115
launcher_hostname = launcher_hostname ,
111
116
launcher_port = launcher_port ,
112
- logger_port = log_receiver . port ,
117
+ logger_port = logging_port ,
113
118
world_size = world_size ,
114
119
rank = i + 1 ,
115
120
env_vars = (self .default_env_vars + self .extra_env_vars ),
@@ -166,11 +171,10 @@ def run( # noqa: C901, PLR0912
166
171
if all (s .state == "done" for s in agent_statuses ):
167
172
break
168
173
finally :
169
- if log_receiver is not None :
170
- log_receiver .shutdown ()
171
- if log_process is not None :
172
- log_receiver .server_close ()
173
- log_process .kill ()
174
+ if stop_logging_event is not None :
175
+ stop_logging_event .set ()
176
+ if log_process is not None :
177
+ log_process .kill ()
174
178
175
179
if launcher_agent_group is not None :
176
180
launcher_agent_group .shutdown ()
@@ -307,31 +311,6 @@ def _resolve_workers_per_host(
307
311
return workers_per_host
308
312
309
313
310
- def _build_logging_server (
311
- log_handlers : list [Handler ] | Literal ["auto" ] | None ,
312
- launcher_hostname : str ,
313
- hostnames : list [str ],
314
- workers_per_host : list [int ],
315
- log_dir : str | os .PathLike ,
316
- log_level : int ,
317
- ) -> LogRecordSocketReceiver :
318
- if log_handlers is None :
319
- log_handlers = []
320
- elif log_handlers == "auto" :
321
- log_handlers = default_handlers (
322
- hostnames = hostnames ,
323
- workers_per_host = workers_per_host ,
324
- log_dir = log_dir ,
325
- log_level = log_level ,
326
- )
327
-
328
- return LogRecordSocketReceiver (
329
- host = launcher_hostname ,
330
- port = get_open_port (),
331
- handlers = log_handlers ,
332
- )
333
-
334
-
335
314
def _build_launch_command (
336
315
launcher_hostname : str ,
337
316
launcher_port : int ,
0 commit comments