Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions train_gpt_mlx.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ class Hyperparameters:
# Chunk each logical MLX microbatch into smaller sub-batches to reduce peak
# memory pressure without changing the effective optimizer batch.
mlx_max_microbatch_tokens: int = int(os.environ.get("MLX_MAX_MICROBATCH_TOKENS", 8_192))
# Force MLX to materialize the graph after every sub-batch, preventing lazy
# graph buildup across accumulation steps. Keeps peak memory low on 16GB machines.
# Disable on 32GB+ unified memory for better throughput (MLX_EAGER_EVAL=0).
mlx_eager_eval: bool = bool(int(os.environ.get("MLX_EAGER_EVAL", "1")))
warmup_steps: int = int(os.environ.get("WARMUP_STEPS", 20))
warmdown_iters: int = int(os.environ.get("WARMDOWN_ITERS", 1200))
max_wallclock_seconds: float = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0))
Expand Down Expand Up @@ -749,6 +753,8 @@ def loss_and_grad_chunked(
scale = float(y.size) / total_tokens
loss_value = loss_value + loss.astype(mx.float32) * scale
grad_accum = accumulate_flat_grads(grad_accum, grads, scale)
if args.mlx_eager_eval:
mx.eval(loss_value, grad_accum) # materialize each chunk to cap peak memory
return loss_value, tree_unflatten(list(grad_accum.items()))


Expand Down Expand Up @@ -1029,6 +1035,8 @@ def log(msg: str, console: bool = True) -> None:
loss, grads = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad)
accum = accumulate_flat_grads(accum, grads, grad_scale)
train_loss = train_loss + loss.astype(mx.float32) * grad_scale
if args.mlx_eager_eval:
mx.eval(train_loss, accum) # materialize each microbatch to cap peak memory

grads = tree_unflatten(list(accum.items()))
grads = clip_grad_tree(grads, args.grad_clip_norm)
Expand Down