@@ -83,11 +83,14 @@ def launch(
83
83
84
84
# launch command
85
85
86
- env_export_string = " " .join (
87
- f'{ k } ="{ v } "' for k , v in os .environ .items () if any (fnmatch .fnmatch (k , e ) for e in env_vars )
88
- )
89
- if env_export_string != "" :
90
- env_export_string = f"export { env_export_string } && "
86
+ env_export_string = ""
87
+ env_exports = []
88
+ for k , v in os .environ .items ():
89
+ for e in env_vars :
90
+ if any (fnmatch .fnmatch (k , e )):
91
+ env_exports .append (f"{ k } ={ v } " )
92
+ if len (env_exports ) > 0 :
93
+ env_export_string = f"export { ' ' .join (env_exports )} && "
91
94
92
95
env_file_string = f"source { env_file } && " if env_file is not None else ""
93
96
@@ -108,7 +111,7 @@ def launch(
108
111
109
112
log_dir = Path (log_dir )
110
113
log_dir .mkdir (parents = True , exist_ok = True )
111
- timestamp = datetime .datetime .now ().strftime ( "%y-%m-%d-%H%M%S " )
114
+ timestamp = datetime .datetime .now ().isoformat ( timespec = "seconds " )
112
115
agent_log_files = [log_dir / f"{ timestamp } _{ hostname } .log" for hostname in hostnames ]
113
116
114
117
# start process to read from agent 0 log
@@ -136,11 +139,15 @@ def launch(
136
139
137
140
# build and sync payloads between launcher and agents
138
141
139
- cumulative_workers = [0 ] + list (itertools .accumulate (workers_per_host ))
140
- worker_world_size = cumulative_workers [- 1 ]
141
- worker_global_ranks = [ # list of worker ranks per host
142
- list (range (cumulative_workers [n ], cumulative_workers [n + 1 ])) for n in range (num_hosts )
143
- ]
142
+ _cumulative_workers = [0 ] + list (itertools .accumulate (workers_per_host ))
143
+
144
+ worker_world_size = _cumulative_workers [- 1 ]
145
+
146
+ worker_global_ranks = [] # list of worker ranks per host
147
+ for n in range (num_hosts ):
148
+ host_ranks = range (_cumulative_workers [n ], _cumulative_workers [n + 1 ])
149
+ worker_global_ranks .append (list (host_ranks ))
150
+
144
151
worker_log_files = [
145
152
[
146
153
log_dir / f"{ timestamp } _{ hostname } _{ local_rank } .log"
@@ -183,7 +190,7 @@ def launch(
183
190
e += f"{ v .message ['extraInfo' ]['py_callstack' ]} \n \n "
184
191
raise RuntimeError (e )
185
192
except :
186
- # kill all agents
193
+ # cleanup: SIGTERM all agents
187
194
for agent_pid , agent_hostname in zip (agent_pids , hostnames ):
188
195
execute_command (
189
196
command = f"kill { agent_pid } " ,
0 commit comments