Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions train_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,11 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
raise RuntimeError("lm_head is required when tie_embeddings=False")
logits_proj = self.lm_head(x)
logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
# BPB-weighted loss: during training, weight each target token's loss by its byte count to align with the eval metric
if self.training and hasattr(self, '_byte_weights'):
per_token_loss = F.cross_entropy(logits.float(), targets, reduction="none")
w = self._byte_weights[targets]
return (per_token_loss * w).sum() / w.sum()
return F.cross_entropy(logits.float(), targets, reduction="mean")


Expand Down Expand Up @@ -836,6 +841,9 @@ def log0(msg: str, console: bool = True) -> None:
rope_base=args.rope_base,
qk_gain_init=args.qk_gain_init,
).to(device).bfloat16()
# BPB-weighted loss: register per-token byte counts (clamped to a minimum of 1) as a buffer
base_model.register_buffer('_byte_weights', base_bytes_lut.float().clamp(min=1.0))
log0(f"bpb_weighted_loss:enabled mean_bytes_per_token:{base_bytes_lut.float().clamp(min=1.0).mean():.2f}")
for module in base_model.modules():
if isinstance(module, CastedLinear):
module.float()
Expand Down Expand Up @@ -1065,15 +1073,17 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
# Save the raw state (useful for debugging/loading in PyTorch directly), then always produce
# the compressed int8+zlib artifact and validate the round-tripped weights.

# Exclude the training-only `_byte_weights` buffer from the saved state
save_sd = {k: v for k, v in base_model.state_dict().items() if k != '_byte_weights'}
if master_process:
torch.save(base_model.state_dict(), "final_model.pt")
torch.save(save_sd, "final_model.pt")
model_bytes = os.path.getsize("final_model.pt")
code_bytes = len(code.encode("utf-8"))
log0(f"Serialized model: {model_bytes} bytes")
log0(f"Code size: {code_bytes} bytes")
log0(f"Total submission size: {model_bytes + code_bytes} bytes")

quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict())
quant_obj, quant_stats = quantize_state_dict_int8(save_sd)
quant_buf = io.BytesIO()
torch.save(quant_obj, quant_buf)
quant_raw = quant_buf.getvalue()
Expand All @@ -1096,7 +1106,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
with open("final_model.int8.ptz", "rb") as f:
quant_blob_disk = f.read()
quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu")
base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True)
base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=False)
torch.cuda.synchronize()
t_qeval = time.perf_counter()
q_val_loss, q_val_bpb = eval_val(
Expand Down