From 0000334ffb9a248e25867c64af1e629c5bcdd27e Mon Sep 17 00:00:00 2001 From: definenoob Date: Thu, 9 Apr 2026 22:33:18 -0400 Subject: [PATCH] init --- train_gpt.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 651beb2b89..ddf81977b1 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -721,6 +721,11 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: raise RuntimeError("lm_head is required when tie_embeddings=False") logits_proj = self.lm_head(x) logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + # BPB-weighted loss: weight tokens by byte count to align with eval metric + if self.training and hasattr(self, '_byte_weights'): + per_token_loss = F.cross_entropy(logits.float(), targets, reduction="none") + w = self._byte_weights[targets] + return (per_token_loss * w).sum() / w.sum() return F.cross_entropy(logits.float(), targets, reduction="mean") @@ -836,6 +841,9 @@ def log0(msg: str, console: bool = True) -> None: rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, ).to(device).bfloat16() + # BPB-weighted loss: register byte count per token as buffer + base_model.register_buffer('_byte_weights', base_bytes_lut.float().clamp(min=1.0)) + log0(f"bpb_weighted_loss:enabled mean_bytes_per_token:{base_bytes_lut.float().clamp(min=1.0).mean():.2f}") for module in base_model.modules(): if isinstance(module, CastedLinear): module.float() @@ -1065,15 +1073,17 @@ def lr_mul(step: int, elapsed_ms: float) -> float: # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce # the compressed int8+zlib artifact and validate the round-tripped weights. + # Filter training-only buffers from saved state + save_sd = {k: v for k, v in base_model.state_dict().items() if k != '_byte_weights'} if master_process: - torch.save(base_model.state_dict(), "final_model.pt") + torch.save(save_sd, "final_model.pt") model_bytes = os.path.getsize("final_model.pt") code_bytes = len(code.encode("utf-8")) log0(f"Serialized model: {model_bytes} bytes") log0(f"Code size: {code_bytes} bytes") log0(f"Total submission size: {model_bytes + code_bytes} bytes") - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_obj, quant_stats = quantize_state_dict_int8(save_sd) quant_buf = io.BytesIO() torch.save(quant_obj, quant_buf) quant_raw = quant_buf.getvalue() @@ -1096,7 +1106,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float: with open("final_model.int8.ptz", "rb") as f: quant_blob_disk = f.read() quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=False) torch.cuda.synchronize() t_qeval = time.perf_counter() q_val_loss, q_val_bpb = eval_val(