Skip to content

Commit 2f6e00e

Browse files
authored
fix (#79)
command launch fixed
1 parent 2b34ffb commit 2f6e00e

5 files changed

Lines changed: 168 additions & 24 deletions

File tree

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
[build-system]
2+
requires = ["setuptools>=61", "wheel"]
3+
build-backend = "setuptools.build_meta"
4+
15
[project]
26
name = "traceml-ai"
3-
version = "0.2.9"
7+
version = "0.2.10"
48

59
description = "TraceML: Lightweight training runtime health monitor."
610
authors = [{ name = "Abhinav Srivastav", email = "abhinav@traceopt.ai" }]

src/traceml/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
trace_step,
77
)
88

9-
__version__ = "0.2.9"
9+
__version__ = "0.2.10"
1010

1111
__all__ = [
1212
"__version__",

src/traceml/cli.py

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import msgspec
1616

17+
from traceml.runtime.launch_context import LaunchContext
1718
from traceml.runtime.session import get_session_id
1819
from traceml.utils.ast_analysis import analyze_script, build_code_manifest
1920

@@ -200,6 +201,7 @@ def write_run_manifest(
200201
nproc_per_node: int,
201202
history_enabled: bool,
202203
status: str,
204+
launch_cwd: str,
203205
aggregator_dir: Optional[Path] = None,
204206
db_path: Optional[Path] = None,
205207
extra: Optional[Dict[str, Any]] = None,
@@ -233,6 +235,7 @@ def write_run_manifest(
233235
"tcp_port": int(tcp_port),
234236
"nproc_per_node": int(nproc_per_node),
235237
"history_enabled": bool(history_enabled),
238+
"launch_cwd": str(Path(launch_cwd).resolve()),
236239
},
237240
"paths": {
238241
"session_root": str(session_root),
@@ -389,10 +392,14 @@ def _handler(signum: int, _frame: Any) -> None:
389392
signal.signal(signal.SIGTERM, _handler)
390393

391394

392-
def start_aggregator_process(env: Dict[str, str]) -> subprocess.Popen:
395+
def start_aggregator_process(
396+
env: Dict[str, str], cwd: str
397+
) -> subprocess.Popen:
393398
"""Start the TraceML aggregator as a separate process.
394399
395-
Stdout/stderr are inherited so Rich output and tracebacks remain visible.
400+
The subprocess cwd is set explicitly so all child processes inherit a
401+
deterministic working directory rather than depending on ambient shell
402+
state.
396403
"""
397404
aggregator_path = (
398405
Path(__file__).parent / "aggregator" / "aggregator_main.py"
@@ -404,18 +411,29 @@ def start_aggregator_process(env: Dict[str, str]) -> subprocess.Popen:
404411

405412
cmd = [sys.executable, str(aggregator_path)]
406413
print("[TraceML] Launching TraceML aggregator:", " ".join(cmd))
407-
return subprocess.Popen(cmd, env=env, start_new_session=True)
414+
return subprocess.Popen(
415+
cmd,
416+
env=env,
417+
cwd=cwd,
418+
start_new_session=True,
419+
)
408420

409421

410422
def start_training_process(
411-
train_cmd: list[str], env: Dict[str, str]
423+
train_cmd: list[str], env: Dict[str, str], cwd: str
412424
) -> subprocess.Popen:
413425
"""Start the training process in a new process group.
414426
415-
Stdout/stderr are inherited so user logs and tracebacks remain visible.
427+
The subprocess cwd is set explicitly so worker processes see the same
428+
working directory the user launched TraceML from.
416429
"""
417430
print("[TraceML] Launching TraceML executor:", " ".join(train_cmd))
418-
return subprocess.Popen(train_cmd, env=env, start_new_session=True)
431+
return subprocess.Popen(
432+
train_cmd,
433+
env=env,
434+
cwd=cwd,
435+
start_new_session=True,
436+
)
419437

420438

421439
def launch_process(script_path: str, args: argparse.Namespace) -> None:
@@ -430,6 +448,7 @@ def launch_process(script_path: str, args: argparse.Namespace) -> None:
430448
5. Keep training as the primary process; aggregator may fail open
431449
6. On shutdown, terminate child process groups and update the manifest
432450
"""
451+
433452
env = os.environ.copy()
434453
env["PYTHONUNBUFFERED"] = "1"
435454

@@ -452,6 +471,10 @@ def launch_process(script_path: str, args: argparse.Namespace) -> None:
452471
env["TRACEML_NPROC_PER_NODE"] = str(args.nproc_per_node)
453472
env["TRACEML_HISTORY_ENABLED"] = "0" if args.no_history else "1"
454473

474+
launch_context = LaunchContext.capture()
475+
env.update(launch_context.to_env())
476+
execution_cwd = launch_context.launch_cwd
477+
455478
session_id = env["TRACEML_SESSION_ID"]
456479
session_root = Path(args.logs_dir).resolve() / session_id
457480
aggregator_dir = session_root / "aggregator"
@@ -474,6 +497,7 @@ def launch_process(script_path: str, args: argparse.Namespace) -> None:
474497
nproc_per_node=args.nproc_per_node,
475498
history_enabled=not args.no_history,
476499
status="starting",
500+
launch_cwd=execution_cwd,
477501
aggregator_dir=aggregator_dir,
478502
db_path=db_path,
479503
extra=(
@@ -495,7 +519,11 @@ def launch_process(script_path: str, args: argparse.Namespace) -> None:
495519
str(script_path),
496520
*script_args,
497521
]
498-
train_proc = start_training_process(train_cmd=train_cmd, env=env)
522+
train_proc = start_training_process(
523+
train_cmd=train_cmd,
524+
env=env,
525+
cwd=execution_cwd,
526+
)
499527
install_shutdown_handlers(
500528
lambda: (train_proc, None), manifest_path=manifest_path
501529
)
@@ -530,7 +558,7 @@ def launch_process(script_path: str, args: argparse.Namespace) -> None:
530558
f"(ui={args.mode}, profile={env['TRACEML_PROFILE']})"
531559
)
532560
try:
533-
agg_proc = start_aggregator_process(env=env)
561+
agg_proc = start_aggregator_process(env=env, cwd=execution_cwd)
534562
except FileNotFoundError as exc:
535563
print(f"[TraceML] ERROR: {exc}", file=sys.stderr)
536564
update_run_manifest(manifest_path, status="failed")
@@ -558,7 +586,11 @@ def launch_process(script_path: str, args: argparse.Namespace) -> None:
558586
print("[TraceML] Aggregator ready.")
559587
update_run_manifest(manifest_path, status="running")
560588

561-
train_proc = start_training_process(train_cmd=train_cmd, env=env)
589+
train_proc = start_training_process(
590+
train_cmd=train_cmd,
591+
env=env,
592+
cwd=execution_cwd,
593+
)
562594

563595
while True:
564596
train_rc = train_proc.poll()

src/traceml/runtime/executor.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@
3030
from pathlib import Path
3131
from typing import Any, Dict, Optional, Union
3232

33+
from traceml.runtime.launch_context import (
34+
LaunchContext,
35+
script_execution_context,
36+
)
3337
from traceml.runtime.runtime import TraceMLRuntime
3438
from traceml.runtime.settings import TraceMLSettings, TraceMLTCPSettings
3539
from traceml.utils.shared_utils import EXECUTION_LAYER
@@ -344,21 +348,21 @@ def stop_runtime(
344348

345349
def run_user_script(script_path: str, script_args: list[str]) -> None:
346350
"""
347-
Execute the user script in-process using runpy.
351+
Execute the user script in-process using ``runpy``.
348352
349-
This intentionally does not spawn another subprocess so that:
350-
- hooks attach to the real Python objects
351-
- stack traces remain meaningful
352-
- execution context matches the user's script as closely as possible
353-
354-
Global sys.argv is restored afterward to avoid leaking state.
353+
TraceML preserves the user's original launch cwd while also exposing the
354+
script directory as ``sys.path[0]`` so local imports and relative file
355+
access behave like a normal ``python`` or ``torchrun`` script launch.
355356
"""
356-
old_argv = sys.argv[:]
357-
try:
358-
sys.argv = [script_path, *script_args]
359-
runpy.run_path(script_path, run_name="__main__")
360-
finally:
361-
sys.argv = old_argv
357+
launch_context = LaunchContext.from_env()
358+
resolved_script_path = str(Path(script_path).resolve())
359+
360+
with script_execution_context(
361+
script_path=resolved_script_path,
362+
script_args=script_args,
363+
launch_context=launch_context,
364+
):
365+
runpy.run_path(resolved_script_path, run_name="__main__")
362366

363367

364368
def report_crash(cfg: Dict[str, Any], error: BaseException) -> None:
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
"""
2+
Helpers for preserving the user's original launch context.
3+
4+
TraceML should instrument user code without changing how that code resolves:
5+
6+
- current working directory
7+
- sibling imports
8+
- script argv
9+
10+
This module centralizes that behavior so the CLI launcher and runtime
11+
executor stay small and consistent.
12+
"""
13+
14+
from __future__ import annotations
15+
16+
import os
17+
import sys
18+
from contextlib import contextmanager
19+
from dataclasses import dataclass
20+
from pathlib import Path
21+
from typing import Dict, Iterator
22+
23+
TRACEML_LAUNCH_CWD_ENV = "TRACEML_LAUNCH_CWD"
24+
25+
26+
@dataclass(frozen=True)
27+
class LaunchContext:
28+
"""Execution context captured from the user's original CLI invocation.
29+
30+
Attributes
31+
----------
32+
launch_cwd:
33+
The directory where the user invoked ``traceml``.
34+
"""
35+
36+
launch_cwd: str
37+
38+
@classmethod
39+
def capture(cls) -> "LaunchContext":
40+
"""Capture the current user-facing launch context.
41+
42+
This should be called in the main CLI process before spawning any
43+
subprocesses so child processes can faithfully reproduce the user's
44+
original execution environment.
45+
"""
46+
return cls(launch_cwd=str(Path.cwd().resolve()))
47+
48+
@classmethod
49+
def from_env(cls) -> "LaunchContext":
50+
"""Load a previously captured launch context from environment vars.
51+
52+
Falls back to the current process cwd when no explicit launch cwd has
53+
been provided. This keeps behavior safe for tests and direct executor
54+
invocation.
55+
"""
56+
raw = os.environ.get(TRACEML_LAUNCH_CWD_ENV, "").strip()
57+
if raw:
58+
return cls(launch_cwd=str(Path(raw).resolve()))
59+
return cls.capture()
60+
61+
def to_env(self) -> Dict[str, str]:
62+
"""Serialize this launch context into environment variables."""
63+
return {TRACEML_LAUNCH_CWD_ENV: self.launch_cwd}
64+
65+
66+
@contextmanager
67+
def script_execution_context(
68+
*,
69+
script_path: str,
70+
script_args: list[str],
71+
launch_context: LaunchContext,
72+
) -> Iterator[None]:
73+
"""Temporarily apply Python script execution semantics.
74+
75+
Behavior
76+
--------
77+
- ``sys.argv`` looks like a direct script launch
78+
- ``sys.path[0]`` points at the script directory
79+
- process cwd matches the user's original launch cwd
80+
81+
This combination most closely matches normal ``python train.py`` or
82+
``torchrun train.py`` behavior from the user's project directory.
83+
"""
84+
resolved_script_path = str(Path(script_path).resolve())
85+
script_dir = str(Path(resolved_script_path).parent)
86+
87+
old_argv = sys.argv[:]
88+
old_path = sys.path[:]
89+
old_cwd = os.getcwd()
90+
91+
try:
92+
sys.argv = [resolved_script_path, *script_args]
93+
94+
if sys.path:
95+
sys.path[0] = script_dir
96+
else:
97+
sys.path = [script_dir]
98+
99+
os.chdir(launch_context.launch_cwd)
100+
yield
101+
finally:
102+
os.chdir(old_cwd)
103+
sys.argv = old_argv
104+
sys.path = old_path

0 commit comments

Comments
 (0)