Skip to content

Commit b5241af

Browse files
committed
misc edits
1 parent b3a27fb commit b5241af

File tree

2 files changed

+20
-13
lines changed

2 files changed

+20
-13
lines changed

examples/submitit_train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def main():
3333
output_dir = "output",
3434
do_train = True,
3535
per_device_train_batch_size = 16,
36-
max_steps = 100,
36+
max_steps = 20,
3737
)
3838

3939
trainer = Trainer(

src/torchrunx/launcher.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,14 @@ def launch(
8383

8484
# launch command
8585

86-
env_export_string = " ".join(
87-
f'{k}="{v}"' for k, v in os.environ.items() if any(fnmatch.fnmatch(k, e) for e in env_vars)
88-
)
89-
if env_export_string != "":
90-
env_export_string = f"export {env_export_string} && "
86+
env_export_string = ""
87+
env_exports = []
88+
for k, v in os.environ.items():
89+
for e in env_vars:
90+
if any(fnmatch.fnmatch(k, e)):
91+
env_exports.append(f"{k}={v}")
92+
if len(env_exports) > 0:
93+
env_export_string = f"export {' '.join(env_exports)} && "
9194

9295
env_file_string = f"source {env_file} && " if env_file is not None else ""
9396

@@ -108,7 +111,7 @@ def launch(
108111

109112
log_dir = Path(log_dir)
110113
log_dir.mkdir(parents=True, exist_ok=True)
111-
timestamp = datetime.datetime.now().strftime("%y-%m-%d-%H%M%S")
114+
timestamp = datetime.datetime.now().isoformat(timespec="seconds")
112115
agent_log_files = [log_dir / f"{timestamp}_{hostname}.log" for hostname in hostnames]
113116

114117
# start process to read from agent 0 log
@@ -136,11 +139,15 @@ def launch(
136139

137140
# build and sync payloads between launcher and agents
138141

139-
cumulative_workers = [0] + list(itertools.accumulate(workers_per_host))
140-
worker_world_size = cumulative_workers[-1]
141-
worker_global_ranks = [ # list of worker ranks per host
142-
list(range(cumulative_workers[n], cumulative_workers[n + 1])) for n in range(num_hosts)
143-
]
142+
_cumulative_workers = [0] + list(itertools.accumulate(workers_per_host))
143+
144+
worker_world_size = _cumulative_workers[-1]
145+
146+
worker_global_ranks = [] # list of worker ranks per host
147+
for n in range(num_hosts):
148+
host_ranks = range(_cumulative_workers[n], _cumulative_workers[n + 1])
149+
worker_global_ranks.append(list(host_ranks))
150+
144151
worker_log_files = [
145152
[
146153
log_dir / f"{timestamp}_{hostname}_{local_rank}.log"
@@ -183,7 +190,7 @@ def launch(
183190
e += f"{v.message['extraInfo']['py_callstack']}\n\n"
184191
raise RuntimeError(e)
185192
except:
186-
# kill all agents
193+
# cleanup: SIGTERM all agents
187194
for agent_pid, agent_hostname in zip(agent_pids, hostnames):
188195
execute_command(
189196
command=f"kill {agent_pid}",

0 commit comments

Comments
 (0)