diff --git a/train_gpt.py b/train_gpt.py
index 0deb0565f5..9a3393a2de 100644
--- a/train_gpt.py
+++ b/train_gpt.py
@@ -514,10 +514,14 @@ def forward(self, x: Tensor) -> Tensor:
 
 
 def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
-    # Keep small/control parameters in fp32 even when the model body runs in bf16.
+    # Keep the high-leverage tied embedding plus small/control parameters in fp32.
     with torch.no_grad():
         for name, param in module.named_parameters():
-            if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
+            if (
+                name == "tok_emb.weight"
+                or param.ndim < 2
+                or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
+            ) and param.dtype != torch.float32:
                 param.data = param.data.float()
 
 
@@ -698,7 +702,7 @@ def _init_weights(self) -> None:
             nn.init.zeros_(module.weight)
 
     def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
-        x = self.tok_emb(input_ids)
+        x = F.embedding(input_ids, self.tok_emb.weight).to(dtype=torch.bfloat16)
         x = F.rms_norm(x, (x.size(-1),))
         x0 = x
         skips: list[Tensor] = []
@@ -715,7 +719,7 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
         x = self.final_norm(x).reshape(-1, x.size(-1))
         targets = target_ids.reshape(-1)
         if self.tie_embeddings:
-            logits_proj = F.linear(x, self.tok_emb.weight)
+            logits_proj = F.linear(x, self.tok_emb.weight.to(dtype=x.dtype))
         else:
             if self.lm_head is None:
                 raise RuntimeError("lm_head is required when tie_embeddings=False")
diff --git a/train_gpt_mlx.py b/train_gpt_mlx.py
index bf7c7d1b8c..2ecbbcd994 100644
--- a/train_gpt_mlx.py
+++ b/train_gpt_mlx.py
@@ -405,7 +405,7 @@ def __init__(self, vocab_size: int, num_layers: int, dim: int, num_heads: int, n
             b.mlp.proj.weight = mx.zeros_like(b.mlp.proj.weight)
         self.tok_emb.weight = (
             mx.random.normal(self.tok_emb.weight.shape, dtype=mx.float32) * tied_embed_init_std
-        ).astype(COMPUTE_DTYPE)
+        )
 
     def softcap(self, logits: mx.array) -> mx.array:
         c = self.logit_softcap
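
Aside for reviewers: taken together, the PyTorch hunks implement an "fp32 master embedding, bf16 compute" pattern for the tied embedding table, with the weight cast down at its two use sites (input lookup and tied output projection). Below is a minimal, self-contained sketch of that pattern; the TiedEmbeddingDemo class, its sizes, and the stubbed-out model body are made up for illustration, and only the two cast sites mirror the diff.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class TiedEmbeddingDemo(nn.Module):
        def __init__(self, vocab_size: int = 64, dim: int = 16) -> None:
            super().__init__()
            # Master copy of the tied embedding stays in fp32.
            self.tok_emb = nn.Embedding(vocab_size, dim)

        def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
            # Look up rows in fp32, then cast activations to bf16 for the body.
            x = F.embedding(input_ids, self.tok_emb.weight).to(dtype=torch.bfloat16)
            # ... transformer body would run here in bf16 ...
            # Tied output projection: cast the fp32 weight to the activation
            # dtype on the fly, so the matmul runs in bf16 while the master
            # weight (and its optimizer state updates) remain fp32.
            return F.linear(x, self.tok_emb.weight.to(dtype=x.dtype))

    model = TiedEmbeddingDemo()
    logits = model(torch.randint(0, 64, (2, 8)))
    print(logits.dtype)            # torch.bfloat16
    print(model.tok_emb.weight.dtype)  # torch.float32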