
Commit b9110d6

Merge pull request #49 from apoorvkh/logging
basic log streaming

2 parents: f081a00 + 7df0948

16 files changed: +729 -483 lines

.gitignore

Lines changed: 3 additions & 5 deletions
@@ -1,9 +1,7 @@
+torchrunx_logs/
 .pixi/
-logs/
-test_logs/
-_build/
-out/
-output/
+.ruff_cache/
+.vscode/
 
 # Byte-compiled / optimized / DLL files
 __pycache__/

docs/source/advanced.rst

Lines changed: 12 additions & 4 deletions
@@ -16,21 +16,29 @@ In addition to ``torchrunx.launch``, we provide the ``torchrunx.Launcher`` datac
     launcher.run(distributed_function, {})
 
 .. autoclass:: torchrunx.Launcher
-
-.. autofunction:: torchrunx.Launcher.run
+   :members:
+.. .. autofunction:: torchrunx.Launcher.run
 
 Logging
 -------
 
-All logs are generated in the folder provided as the ``logs`` argument to :mod:`torchrunx.launch`. Each worker agent generates a log, named based on the current date and time, followed by the agent hostname. Each worker also has a log, named identically to their agent's log file except for the addition of the worker's local rank at the end of the name. Each agent includes the output from local worker 0 in its log. The launcher renders agent 0's log to ``stdout`` in real time.
+Logs are generated at the worker and agent level, and are specified to :mod:`torchrunx.launch` via the ``log_spec`` argument. By default, a :mod:`torchrunx.DefaultLogSpec` is instantiated, causing logs at the worker and agent levels to be written to files under ``'./logs'``, while the rank 0 worker's output streams are streamed to the launcher's ``stdout``. Logs are prefixed with a timestamp by default. Agent logs have the format ``{timestamp}-{agent hostname}.log`` and worker logs have the format ``{timestamp}-{agent hostname}[{worker local rank}].log``.
+
+Custom logging classes can be subclassed from the :mod:`torchrunx.LogSpec` class. To be passed to :mod:`torchrunx.launch`, any subclass must have a ``get_map`` method returning a dictionary that maps logger names to lists of :mod:`logging.Handler` objects. The logger names have the format ``{agent hostname}`` for agents and ``{agent hostname}[{worker local rank}]`` for workers. The :mod:`torchrunx.DefaultLogSpec` maps each logger to a :mod:`logging.FileHandler` object pointing to the files mentioned in the previous paragraph. It additionally maps the global rank 0 worker to a :mod:`logging.StreamHandler`, which writes logs to the launcher's ``stdout`` stream.
+
+.. autoclass:: torchrunx.LogSpec
+   :members:
+
+.. autoclass:: torchrunx.DefaultLogSpec
+   :members:
 
 ..
    TODO: example log structure
 
 Worker environment
 ------------------
 
-The :mod:`torchrunx.launch` ``env_vars`` argument allows the user to specify which evnironmental variables should be copied to the agents from the launcher environment. By default, it attempts to copy variables related to Python and important packages/technologies that **torchrunx** uses such as PyTorch, NCCL, CUDA, and more. Strings provided are matched with the names of environmental variables using ``fnmatch`` - standard UNIX filename pattern matching. The variables are inserted into the agent environments, and then copied to workers' environments when they are spawned.
+The :mod:`torchrunx.launch` ``env_vars`` argument allows the user to specify which environmental variables should be copied to the agents from the launcher environment. By default, it attempts to copy variables related to Python and important packages/technologies that **torchrunx** uses such as PyTorch, NCCL, CUDA, and more. Strings provided are matched with the names of environmental variables using ``fnmatch`` - standard UNIX filename pattern matching. The variables are inserted into the agent environments, and then copied to workers' environments when they are spawned.
 
 :mod:`torchrunx.launch` also accepts the ``env_file`` argument, which is designed to expose more advanced environmental configuration to the user. When a file is provided as this argument, the launcher will source the file on each node before executing the agent. This allows for custom bash scripts to be provided in the environmental variables, and allows for node-specific environmental variables to be set.
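The ``get_map`` contract described in the logging docs above can be sketched in isolation. This is an illustrative shape only: in real use the class would subclass ``torchrunx.LogSpec`` (not imported here), the constructor arguments are assumptions, and ``logging.NullHandler`` placeholders stand in for the ``FileHandler``/``StreamHandler`` objects that ``DefaultLogSpec`` would return.

```python
import logging

class QuietLogSpec:
    """Sketch of a custom log spec: maps every agent/worker logger
    to a NullHandler (i.e. discards all logs)."""

    def __init__(self, hostnames: list, workers_per_host: int):
        self.hostnames = hostnames
        self.workers_per_host = workers_per_host

    def get_map(self) -> dict:
        handler_map = {}
        for host in self.hostnames:
            # agent logger, named "{agent hostname}"
            handler_map[host] = [logging.NullHandler()]
            for rank in range(self.workers_per_host):
                # worker logger, named "{agent hostname}[{worker local rank}]"
                handler_map[f"{host}[{rank}]"] = [logging.NullHandler()]
        return handler_map

spec = QuietLogSpec(["node1"], workers_per_host=2)
print(sorted(spec.get_map().keys()))  # ['node1', 'node1[0]', 'node1[1]']
```

The key point is the naming scheme: one logger per agent, plus one per local worker, so handlers can be attached at either granularity.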

pixi.lock

Lines changed: 341 additions & 296 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ include = ["pyproject.toml", "src/**/*.py", "tests/**/*.py"]
 line-length = 100
 src = ["src", "tests"]
 [tool.ruff.lint]
-select = ["E", "F", "I"]
+extend-select = ["I"]
 
 [tool.pyright]
 include = ["src", "tests"]

src/torchrunx/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -1,4 +1,6 @@
-from .environment import auto_hosts, auto_workers, slurm_hosts, slurm_workers
 from .launcher import Launcher, launch
 
-__all__ = ["Launcher", "launch", "slurm_hosts", "slurm_workers", "auto_hosts", "auto_workers"]
+__all__ = [
+    "Launcher",
+    "launch",
+]
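For context on this change: trimming ``__all__`` narrows what ``from torchrunx import *`` exposes, even if other names remain importable directly. A quick standalone illustration, using a synthetic module (not torchrunx itself):

```python
import sys
import types

# Build a throwaway module that mimics the edited __init__.py.
mod = types.ModuleType("fake_torchrunx")
mod.Launcher = type("Launcher", (), {})
mod.launch = lambda: None
mod.slurm_hosts = lambda: []          # still importable explicitly...
mod.__all__ = ["Launcher", "launch"]  # ...but excluded from star-imports
sys.modules["fake_torchrunx"] = mod

ns = {}
exec("from fake_torchrunx import *", ns)
public = sorted(k for k in ns if not k.startswith("__"))
print(public)  # ['Launcher', 'launch']
```

So after this commit, ``slurm_hosts`` et al. are no longer part of the package's star-import surface (and the top-level re-export was removed entirely).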

src/torchrunx/__main__.py

Lines changed: 6 additions & 1 deletion
@@ -7,6 +7,7 @@
 parser = ArgumentParser()
 parser.add_argument("--launcher-hostname", type=str)
 parser.add_argument("--launcher-port", type=int)
+parser.add_argument("--logger-port", type=int)
 parser.add_argument("--world-size", type=int)
 parser.add_argument("--rank", type=int)
 args = parser.parse_args()
@@ -18,4 +19,8 @@
     rank=args.rank,
 )
 
-main(launcher_agent_group)
+main(
+    launcher_agent_group,
+    logger_hostname=args.launcher_hostname,
+    logger_port=args.logger_port,
+)
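The new ``--logger-port`` flag can be exercised in isolation. A minimal sketch mirroring the parser above (the argument values here are made up for illustration):

```python
from argparse import ArgumentParser

# Same argument definitions as in __main__.py above.
parser = ArgumentParser()
parser.add_argument("--launcher-hostname", type=str)
parser.add_argument("--launcher-port", type=int)
parser.add_argument("--logger-port", type=int)
parser.add_argument("--world-size", type=int)
parser.add_argument("--rank", type=int)

# argparse converts "--logger-port" into the attribute "logger_port".
args = parser.parse_args(
    ["--launcher-hostname", "node1", "--launcher-port", "5000",
     "--logger-port", "5001", "--world-size", "3", "--rank", "1"]
)
print(args.logger_port, args.launcher_hostname)  # 5001 node1
```

Note that the commit reuses ``--launcher-hostname`` for ``logger_hostname`` in the ``main(...)`` call, i.e. the log-record receiver runs on the launcher host.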

src/torchrunx/agent.py

Lines changed: 82 additions & 65 deletions
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import datetime
+import logging
 import os
 import socket
 import sys
@@ -14,6 +15,7 @@
 from torch.distributed.elastic.multiprocessing import start_processes
 from typing_extensions import Self
 
+from .logging_utils import log_records_to_socket, redirect_stdio_to_logger
 from .utils import (
     AgentPayload,
     AgentStatus,
@@ -26,14 +28,16 @@
 @dataclass
 class WorkerArgs:
     function: Callable
-    master_hostname: str
-    master_port: int
+    logger_hostname: str
+    logger_port: int
+    main_agent_hostname: str
+    main_agent_port: int
     backend: Literal["mpi", "gloo", "nccl", "ucc", None]
     rank: int
     local_rank: int
     local_world_size: int
     world_size: int
-    log_file: os.PathLike
+    hostname: str
     timeout: int
 
     def to_bytes(self) -> bytes:
@@ -44,114 +48,125 @@ def from_bytes(cls, serialized: bytes) -> Self:
         return cloudpickle.loads(serialized)
 
 
-class WorkerTee(object):
-    def __init__(self, name: os.PathLike | str, mode: str):
-        self.file = open(name, mode)
-        self.stdout = sys.stdout
-        sys.stdout = self
+def entrypoint(serialized_worker_args: bytes):
+    worker_args = WorkerArgs.from_bytes(serialized_worker_args)
 
-    def __enter__(self):
-        return self
+    logger = logging.getLogger()
 
-    def __exit__(self, exception_type, exception_value, exception_traceback):
-        self.__del__()
+    log_records_to_socket(
+        logger=logger,
+        hostname=worker_args.hostname,
+        worker_rank=worker_args.local_rank,
+        logger_hostname=worker_args.logger_hostname,
+        logger_port=worker_args.logger_port,
+    )
 
-    def __del__(self):
-        sys.stdout = self.stdout
-        self.file.close()
+    redirect_stdio_to_logger(logger)
 
-    def write(self, data):
-        self.file.write(data)
-        self.stdout.write(data)
+    store = dist.TCPStore(  # pyright: ignore[reportPrivateImportUsage]
+        host_name=worker_args.main_agent_hostname,
+        port=worker_args.main_agent_port,
+        world_size=worker_args.world_size,
+        is_master=(worker_args.rank == 0),
+    )
 
-    def flush(self):
-        self.file.flush()
+    backend = worker_args.backend
+    if backend is None:
+        backend = "nccl" if torch.cuda.is_available() else "gloo"
 
+    logger.debug(f"using backend: {backend}")
 
-def entrypoint(serialized_worker_args: bytes):
-    worker_args = WorkerArgs.from_bytes(serialized_worker_args)
+    dist.init_process_group(
+        backend=backend,
+        world_size=worker_args.world_size,
+        rank=worker_args.rank,
+        store=store,
+        timeout=datetime.timedelta(seconds=worker_args.timeout),
+    )
+
+    os.environ["RANK"] = str(worker_args.rank)
+    os.environ["LOCAL_RANK"] = str(worker_args.local_rank)
+    os.environ["LOCAL_WORLD_SIZE"] = str(worker_args.local_world_size)
+    os.environ["WORLD_SIZE"] = str(worker_args.world_size)
+    os.environ["MASTER_ADDR"] = worker_args.main_agent_hostname
+    os.environ["MASTER_PORT"] = str(worker_args.main_agent_port)
+
+    logger.debug(f"executing function: {worker_args.function}")
 
-    with WorkerTee(worker_args.log_file, "w"):
-        store = dist.TCPStore(  # pyright: ignore[reportPrivateImportUsage]
-            host_name=worker_args.master_hostname,
-            port=worker_args.master_port,
-            world_size=worker_args.world_size,
-            is_master=(worker_args.rank == 0),
-        )
-
-        backend = worker_args.backend
-        if backend is None:
-            backend = "nccl" if torch.cuda.is_available() else "gloo"
-        dist.init_process_group(
-            backend=backend,
-            world_size=worker_args.world_size,
-            rank=worker_args.rank,
-            store=store,
-            timeout=datetime.timedelta(seconds=worker_args.timeout),
-        )
-
-        os.environ["RANK"] = str(worker_args.rank)
-        os.environ["LOCAL_RANK"] = str(worker_args.local_rank)
-        os.environ["LOCAL_WORLD_SIZE"] = str(worker_args.local_world_size)
-        os.environ["WORLD_SIZE"] = str(worker_args.world_size)
-        os.environ["MASTER_ADDR"] = worker_args.master_hostname
-        os.environ["MASTER_PORT"] = str(worker_args.master_port)
-
-        return worker_args.function()
-
-
-def main(launcher_agent_group: LauncherAgentGroup):
+    r = worker_args.function()
+
+    # flush streams
+    sys.stdout.flush()
+    sys.stderr.flush()
+
+    return r
+
+
+def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_port: int):
     agent_rank = launcher_agent_group.rank - 1
 
     payload = AgentPayload(
         hostname=socket.getfqdn(),
         port=get_open_port(),
         process_id=os.getpid(),
     )
-    # DefaultLogsSpecs(log_dir=None, tee=Std.ALL, local_ranks_filter={0}),
+
     all_payloads = launcher_agent_group.sync_payloads(payload=payload)
     launcher_payload: LauncherPayload = all_payloads[0]  # pyright: ignore[reportAssignmentType]
     main_agent_payload: AgentPayload = all_payloads[1]  # pyright: ignore[reportAssignmentType]
 
     hostname = launcher_payload.hostnames[agent_rank]
     worker_world_size = launcher_payload.worker_world_size
     worker_global_ranks = launcher_payload.worker_global_ranks[agent_rank]
-    worker_log_files = launcher_payload.worker_log_files[agent_rank]
     num_workers = len(worker_global_ranks)
 
+    logger = logging.getLogger()
+
+    log_records_to_socket(
+        logger=logger,
+        hostname=hostname,
+        worker_rank=None,
+        logger_hostname=logger_hostname,
+        logger_port=logger_port,
+    )
+
+    redirect_stdio_to_logger(logger)
+
     if torch.__version__ >= "2.3":
-        # DefaultLogsSpecs only exists in torch >= 2.3
         from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs
 
-        log_arg = DefaultLogsSpecs(log_dir=tempfile.mkdtemp())
+        log_kwargs = {"logs_specs": DefaultLogsSpecs(log_dir=tempfile.mkdtemp())}
     else:
-        log_arg = tempfile.mkdtemp()
+        log_kwargs = {"log_dir": tempfile.mkdtemp()}
 
     # spawn workers
 
     ctx = start_processes(
-        f"{hostname}_",
-        entrypoint,
-        {
+        name=f"{hostname}_",
+        entrypoint=entrypoint,
+        args={
             i: (
                 WorkerArgs(
                     function=launcher_payload.fn,
-                    master_hostname=main_agent_payload.hostname,
-                    master_port=main_agent_payload.port,
+                    logger_hostname=logger_hostname,
+                    logger_port=logger_port,
+                    main_agent_hostname=main_agent_payload.hostname,
+                    main_agent_port=main_agent_payload.port,
                     backend=launcher_payload.backend,
                     rank=worker_global_ranks[i],
                     local_rank=i,
                     local_world_size=num_workers,
                     world_size=worker_world_size,
-                    log_file=worker_log_files[i],
+                    hostname=launcher_payload.hostnames[agent_rank],
                     timeout=launcher_payload.timeout,
                 ).to_bytes(),
             )
             for i in range(num_workers)
         },
-        {i: {} for i in range(num_workers)},
-        log_arg,  # type: ignore
+        envs={i: {} for i in range(num_workers)},
+        **log_kwargs,  # pyright: ignore [reportArgumentType]
     )
+    logger.info("starting processes")
 
     try:
         status = AgentStatus()
@@ -172,3 +187,5 @@ def main(launcher_agent_group: LauncherAgentGroup):
         raise
     finally:
         ctx.close()
+        sys.stdout.flush()
+        sys.stderr.flush()
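This diff replaces the old ``WorkerTee`` (which duplicated stdout into a file) with ``redirect_stdio_to_logger`` from ``logging_utils``, whose source is not shown in this commit view. As a hedged sketch of how such a redirect can work — forwarding each printed line into a ``logging.Logger`` — using only the standard library (the real torchrunx helper may differ):

```python
import io
import logging
import sys

class ListHandler(logging.Handler):
    """Collects log records in a list (stands in for a socket/file handler)."""
    def __init__(self, store):
        super().__init__()
        self.store = store

    def emit(self, record):
        self.store.append(record)

class LoggerWriter(io.TextIOBase):
    """File-like object that forwards each non-empty written line to a logger."""
    def __init__(self, logger, level):
        super().__init__()
        self.logger = logger
        self.level = level

    def write(self, text):
        for line in text.rstrip().splitlines():
            self.logger.log(self.level, line)
        return len(text)

records = []
logger = logging.getLogger("redirect-demo")
logger.setLevel(logging.INFO)
logger.addHandler(ListHandler(records))
logger.propagate = False  # keep demo records out of the root logger

old_stdout = sys.stdout
sys.stdout = LoggerWriter(logger, logging.INFO)  # the "redirect"
print("hello from a worker")                     # captured as a log record
sys.stdout = old_stdout                          # restore

print(records[0].getMessage())  # hello from a worker
```

Once worker prints become log records, routing them over a socket to the launcher (as ``log_records_to_socket`` suggests) is just a matter of attaching a socket-backed handler instead of the list handler used here.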

src/torchrunx/environment.py

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ def slurm_workers() -> int:
         # TODO: is it possible to allocate uneven GPUs across nodes?
         return len(os.environ["SLURM_JOB_GPUS"].split(","))
     elif "SLURM_GPUS_PER_NODE" in os.environ:
-        return int(os.environ['SLURM_GPUS_PER_NODE'])
+        return int(os.environ["SLURM_GPUS_PER_NODE"])
     else:
         # TODO: should we assume that we plan to do one worker per CPU?
         return int(os.environ["SLURM_CPUS_ON_NODE"])
