From 28d41a448c462c2d8092e769131c7b72306d917f Mon Sep 17 00:00:00 2001
From: Vertigo
Date: Wed, 18 Mar 2026 23:17:11 +0300
Subject: [PATCH] feat: recursive weight sharing for 16MB limit

---
 train_gpt.py | 69 ++++++++++++++++++++++++----------------------------
 1 file changed, 32 insertions(+), 37 deletions(-)

diff --git a/train_gpt.py b/train_gpt.py
index 0deb0565f5..7c57c3886a 100644
--- a/train_gpt.py
+++ b/train_gpt.py
@@ -647,45 +647,43 @@ def forward(self, x: Tensor, x0: Tensor) -> Tensor:
 class GPT(nn.Module):
     def __init__(
-        self,
-        vocab_size: int,
-        num_layers: int,
-        model_dim: int,
-        num_heads: int,
-        num_kv_heads: int,
-        mlp_mult: int,
-        tie_embeddings: bool,
-        tied_embed_init_std: float,
-        logit_softcap: float,
-        rope_base: float,
-        qk_gain_init: float,
+        self,
+        vocab_size: int,
+        num_layers: int,
+        model_dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        mlp_mult: int,
+        tie_embeddings: bool,
+        tied_embed_init_std: float,
+        logit_softcap: float,
+        rope_base: float,
+        qk_gain_init: float,
     ):
         super().__init__()
         if logit_softcap <= 0.0:
             raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
+
         self.tie_embeddings = tie_embeddings
         self.tied_embed_init_std = tied_embed_init_std
         self.logit_softcap = logit_softcap
+        self.num_layers = num_layers  # number of recursive steps
+
         self.tok_emb = nn.Embedding(vocab_size, model_dim)
-        self.num_encoder_layers = num_layers // 2
-        self.num_decoder_layers = num_layers - self.num_encoder_layers
-        self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
-        self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
-        self.blocks = nn.ModuleList(
-            [
-                Block(
-                    model_dim,
-                    num_heads,
-                    num_kv_heads,
-                    mlp_mult,
-                    rope_base,
-                    qk_gain_init,
-                )
-                for i in range(num_layers)
-            ]
+
+        # Instead of a list of blocks, create ONE shared block for all layers
+        self.shared_block = Block(
+            model_dim,
+            num_heads,
+            num_kv_heads,
+            mlp_mult,
+            rope_base,
+            qk_gain_init,
         )
+
         self.final_norm = RMSNorm()
         self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
+
         if self.lm_head is not None:
             self.lm_head._zero_init = True
         self._init_weights()
@@ -701,25 +699,22 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
         x = self.tok_emb(input_ids)
         x = F.rms_norm(x, (x.size(-1),))
         x0 = x
-        skips: list[Tensor] = []
-        # First half stores skips; second half reuses them in reverse order.
-        for i in range(self.num_encoder_layers):
-            x = self.blocks[i](x, x0)
-            skips.append(x)
-        for i in range(self.num_decoder_layers):
-            if skips:
-                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
-            x = self.blocks[self.num_encoder_layers + i](x, x0)
+        # Recursively pass through the same block
+        # This saves a lot of space in the final 16MB file
+        for _ in range(self.num_layers):
+            x = self.shared_block(x, x0)
         x = self.final_norm(x).reshape(-1, x.size(-1))
         targets = target_ids.reshape(-1)
+
         if self.tie_embeddings:
             logits_proj = F.linear(x, self.tok_emb.weight)
         else:
             if self.lm_head is None:
                 raise RuntimeError("lm_head is required when tie_embeddings=False")
             logits_proj = self.lm_head(x)
+
         logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
         return F.cross_entropy(logits.float(), targets, reduction="mean")
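
Reviewer note (not part of the patch above): two things are worth flagging. First, this diff is two behavior changes in one: besides sharing weights, it also deletes the encoder/decoder split and the learned skip_weights connections. Second, the 16MB claim follows from parameter counting alone: with one shared Block, checkpoint size stops growing with depth, and only compute grows. Below is a minimal, self-contained sketch of that trade-off. TinyBlock, PerLayerModel, and SharedModel are hypothetical stand-ins (attention, RoPE, and the x0 residual input are omitted); they are not the real Block from train_gpt.py, only an illustration of the size difference.

import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    """Stand-in for the real Block: a single MLP sublayer with a residual add."""
    def __init__(self, model_dim: int, mlp_mult: int):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(model_dim, mlp_mult * model_dim, bias=False),
            nn.ReLU(),
            nn.Linear(mlp_mult * model_dim, model_dim, bias=False),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.mlp(x)

class PerLayerModel(nn.Module):
    """num_layers independent blocks: parameter count grows linearly with depth."""
    def __init__(self, num_layers: int, model_dim: int, mlp_mult: int):
        super().__init__()
        self.blocks = nn.ModuleList(
            TinyBlock(model_dim, mlp_mult) for _ in range(num_layers)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            x = block(x)
        return x

class SharedModel(nn.Module):
    """One block applied num_layers times: depth costs compute, not parameters."""
    def __init__(self, num_layers: int, model_dim: int, mlp_mult: int):
        super().__init__()
        self.num_layers = num_layers
        self.shared_block = TinyBlock(model_dim, mlp_mult)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for _ in range(self.num_layers):
            x = self.shared_block(x)
        return x

def param_bytes(model: nn.Module) -> int:
    # Size of the weights alone, i.e. a lower bound on the checkpoint size.
    return sum(p.numel() * p.element_size() for p in model.parameters())

per_layer = PerLayerModel(num_layers=12, model_dim=512, mlp_mult=4)
shared = SharedModel(num_layers=12, model_dim=512, mlp_mult=4)
print(f"per-layer: {param_bytes(per_layer) / 2**20:.1f} MiB")  # 96.0 MiB at fp32
print(f"shared:    {param_bytes(shared) / 2**20:.1f} MiB")     # 8.0 MiB at fp32

The usual caveat with this ALBERT-style sharing applies: at the same depth, one shared block has far less capacity than independent layers, so the size saving typically costs some loss unless depth or width is increased to compensate.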