train_gpt.py (69 changes: 32 additions & 37 deletions)

@@ -647,45 +647,43 @@ def forward(self, x: Tensor, x0: Tensor) -> Tensor:

 class GPT(nn.Module):
     def __init__(
-        self,
-        vocab_size: int,
-        num_layers: int,
-        model_dim: int,
-        num_heads: int,
-        num_kv_heads: int,
-        mlp_mult: int,
-        tie_embeddings: bool,
-        tied_embed_init_std: float,
-        logit_softcap: float,
-        rope_base: float,
-        qk_gain_init: float,
+        self,
+        vocab_size: int,
+        num_layers: int,
+        model_dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        mlp_mult: int,
+        tie_embeddings: bool,
+        tied_embed_init_std: float,
+        logit_softcap: float,
+        rope_base: float,
+        qk_gain_init: float,
     ):
         super().__init__()
         if logit_softcap <= 0.0:
             raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")

         self.tie_embeddings = tie_embeddings
         self.tied_embed_init_std = tied_embed_init_std
         self.logit_softcap = logit_softcap
+        self.num_layers = num_layers  # number of recursive steps

         self.tok_emb = nn.Embedding(vocab_size, model_dim)
-        self.num_encoder_layers = num_layers // 2
-        self.num_decoder_layers = num_layers - self.num_encoder_layers
-        self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
-        self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
-        self.blocks = nn.ModuleList(
-            [
-                Block(
-                    model_dim,
-                    num_heads,
-                    num_kv_heads,
-                    mlp_mult,
-                    rope_base,
-                    qk_gain_init,
-                )
-                for i in range(num_layers)
-            ]
-        )
+        # Instead of a list of blocks, create ONE shared block for all layers.
+        self.shared_block = Block(
+            model_dim,
+            num_heads,
+            num_kv_heads,
+            mlp_mult,
+            rope_base,
+            qk_gain_init,
+        )

         self.final_norm = RMSNorm()
         self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)

         if self.lm_head is not None:
             self.lm_head._zero_init = True
         self._init_weights()
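
For context on the size claim in the comments: with `nn.ModuleList`, each of the `num_layers` blocks carries its own weights, whereas the shared block stores a single set and reuses it on every step, shrinking block parameters by roughly a factor of `num_layers`. A minimal sketch of the difference, using a hypothetical stand-in for `Block` that only models the large linear projections:

```python
import torch.nn as nn

def make_block(dim: int, mlp_mult: int = 4) -> nn.Module:
    # Hypothetical stand-in for Block: just the big linear projections.
    return nn.Sequential(
        nn.Linear(dim, 3 * dim, bias=False),         # fused QKV projection
        nn.Linear(dim, dim, bias=False),             # attention output
        nn.Linear(dim, mlp_mult * dim, bias=False),  # MLP up-projection
        nn.Linear(mlp_mult * dim, dim, bias=False),  # MLP down-projection
    )

def num_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

dim, num_layers = 768, 12
per_layer = nn.ModuleList([make_block(dim) for _ in range(num_layers)])
shared = make_block(dim)
print(num_params(per_layer))  # 12x the shared count
print(num_params(shared))     # one set of block weights, reused recursively
```

The trade-off is capacity: the recursive model applies the same transformation `num_layers` times instead of learning a distinct one per depth, so the saving in stored weights is not free.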
@@ -701,25 +699,22 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
         x = self.tok_emb(input_ids)
         x = F.rms_norm(x, (x.size(-1),))
         x0 = x
-        skips: list[Tensor] = []

-        # First half stores skips; second half reuses them in reverse order.
-        for i in range(self.num_encoder_layers):
-            x = self.blocks[i](x, x0)
-            skips.append(x)
-        for i in range(self.num_decoder_layers):
-            if skips:
-                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
-            x = self.blocks[self.num_encoder_layers + i](x, x0)
+        # Recursively pass x through one and the same block.
+        # This saves a lot of space in the final 16 MB file.
+        for _ in range(self.num_layers):
+            x = self.shared_block(x, x0)

         x = self.final_norm(x).reshape(-1, x.size(-1))
         targets = target_ids.reshape(-1)

         if self.tie_embeddings:
             logits_proj = F.linear(x, self.tok_emb.weight)
         else:
+            if self.lm_head is None:
+                raise RuntimeError("lm_head is required when tie_embeddings=False")
             logits_proj = self.lm_head(x)

         logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
         return F.cross_entropy(logits.float(), targets, reduction="mean")
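
Unchanged by this PR, but worth noting since it closes the new forward pass: the softcap maps raw logits into the open interval (-logit_softcap, logit_softcap) through a scaled tanh, so extreme logits cannot dominate the cross-entropy. A quick sketch with illustrative values:

```python
import torch

logit_softcap = 15.0  # illustrative; the real value comes from the model config
logits_proj = torch.tensor([-200.0, -1.0, 0.0, 1.0, 200.0])
logits = logit_softcap * torch.tanh(logits_proj / logit_softcap)
print(logits)  # all values now lie strictly inside (-15, 15);
               # small logits pass through almost unchanged, since tanh(t) is close to t for small t
```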
