Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions train_gpt_mlx.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,7 @@ def eval_val(
base_bytes_lut: np.ndarray,
has_leading_space_lut: np.ndarray,
is_boundary_token_lut: np.ndarray,
log_fn: Callable[[str], None] | None = None,
) -> tuple[float, float]:
# Validation computes two metrics:
# - val_loss: token cross-entropy (natural log)
Expand All @@ -772,10 +773,11 @@ def eval_val(
)
val_batch_seqs = val_batch_tokens // args.train_seq_len
total_seqs = (val_tokens.size - 1) // args.train_seq_len
total_loss = mx.array(0.0, dtype=mx.float32)
total_batches = max((total_seqs + val_batch_seqs - 1) // val_batch_seqs, 1)
total_loss_sum = 0.0
total_tokens = 0.0
total_bytes = 0.0
for batch_seq_start in range(0, total_seqs, val_batch_seqs):
for batch_idx, batch_seq_start in enumerate(range(0, total_seqs, val_batch_seqs), start=1):
batch_seq_end = min(batch_seq_start + val_batch_seqs, total_seqs)
raw_start = batch_seq_start * args.train_seq_len
raw_end = batch_seq_end * args.train_seq_len + 1
Expand All @@ -785,7 +787,9 @@ def eval_val(
x = mx.array(x_np, dtype=mx.int32)
y = mx.array(y_np, dtype=mx.int32)
chunk_token_count = float(y.size)
total_loss = total_loss + compiled_loss(x, y).astype(mx.float32) * chunk_token_count
batch_loss = compiled_loss(x, y).astype(mx.float32)
mx.eval(batch_loss)
total_loss_sum += float(batch_loss.item()) * chunk_token_count
prev_ids = x_np.reshape(-1)
tgt_ids = y_np.reshape(-1)
bytes_np = base_bytes_lut[tgt_ids].astype(np.int16, copy=True)
Expand All @@ -794,9 +798,11 @@ def eval_val(
).astype(np.int16, copy=False)
total_tokens += chunk_token_count
total_bytes += float(bytes_np.astype(np.float64).sum())
total_loss = total_loss / total_tokens
mx.eval(total_loss)
val_loss = float(total_loss.item())
if log_fn is not None and total_batches > 1 and (
batch_idx == 1 or batch_idx == total_batches or batch_idx % 25 == 0
):
log_fn(f"val_progress:{batch_idx}/{total_batches}")
val_loss = total_loss_sum / total_tokens
bits_per_token = val_loss / math.log(2.0)
val_bpb = bits_per_token * (total_tokens / total_bytes)
return val_loss, val_bpb
Expand Down Expand Up @@ -999,6 +1005,7 @@ def log(msg: str, console: bool = True) -> None:
base_bytes_lut,
has_leading_space_lut,
is_boundary_token_lut,
log_fn=log,
)
train_time_ms += 1000.0 * (time.perf_counter() - t0)
if step % 25 == 0 or last_step:
Expand Down Expand Up @@ -1078,6 +1085,7 @@ def log(msg: str, console: bool = True) -> None:
base_bytes_lut,
has_leading_space_lut,
is_boundary_token_lut,
log_fn=log,
)
q_eval_ms = 1000.0 * (time.perf_counter() - q_t0)
log(f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} eval_time:{q_eval_ms:.0f}ms")
Expand Down