Skip to content

Commit 6164dbf

Browse files
authored
Merge pull request #134 from instructlab/logging-updates
Logging updates
2 parents 2b744af + 07c7f63 commit 6164dbf

File tree

2 files changed

+19
-13
lines changed

2 files changed

+19
-13
lines changed

src/instructlab/training/main_ds.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,8 @@ def main(args):
476476
import yaml
477477

478478
metric_logger = AsyncStructuredLogger(
479-
args.output_dir + "/training_params_and_metrics.jsonl"
479+
args.output_dir
480+
+ f"/training_params_and_metrics_global{os.environ['RANK']}.jsonl"
480481
)
481482
if os.environ["LOCAL_RANK"] == "0":
482483
print(f"\033[38;5;120m{yaml.dump(vars(args), sort_keys=False)}\033[0m")
@@ -658,7 +659,10 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
658659
print(f"\033[92mRunning command: {' '.join(command)}\033[0m")
659660
process = None
660661
try:
661-
process = StreamablePopen(command)
662+
process = StreamablePopen(
663+
f"{train_args.ckpt_output_dir}/full_logs_global{torch_args.node_rank}.log",
664+
command,
665+
)
662666

663667
except KeyboardInterrupt:
664668
print("Process interrupted by user")

src/instructlab/training/utils.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -88,21 +88,23 @@ class StreamablePopen(subprocess.Popen):
8888
Provides a way of reading stdout and stderr line by line.
8989
"""
9090

91-
def __init__(self, *args, **kwargs):
91+
def __init__(self, output_file, *args, **kwargs):
9292
# remove the stderr and stdout from kwargs
9393
kwargs.pop("stderr", None)
9494
kwargs.pop("stdout", None)
9595

96-
super().__init__(*args, **kwargs)
97-
while True:
98-
if self.stdout:
99-
output = self.stdout.readline().strip()
100-
print(output)
101-
if self.stderr:
102-
error = self.stderr.readline().strip()
103-
print(error, file=sys.stderr)
104-
if self.poll() is not None:
105-
break
96+
super().__init__(
97+
*args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, **kwargs
98+
)
99+
with open(output_file, "wb") as full_log_file:
100+
while True:
101+
byte = self.stdout.read(1)
102+
if byte:
103+
sys.stdout.buffer.write(byte)
104+
sys.stdout.flush()
105+
full_log_file.write(byte)
106+
else:
107+
break
106108

107109

108110
def make_collate_fn(pad_token_id, is_granite=False, max_batch_len=60000):

0 commit comments

Comments
 (0)