Commit 63fcfc2

Add BigramHash + PROTEUS routing + recur_untie_mlp
- BigramHash 3072x112 (from openai#1019 pattern, ~-0.001 BPP)
- SmearGate for temporal smoothing
- PROTEUS 4-way routing: resid_mix_mlp + route params for parallel blocks
- Skip connections apply to both lanes in parallel mode (per openai#1289)
- Untied MLP weights for recurrence layers (separate from block MLPs)
- Fix: skip connections in parallel mode now correctly update both lanes
1 parent 3de320a commit 63fcfc2

1 file changed: train_gpt.py (129 additions & 18 deletions)
@@ -103,6 +103,13 @@ class Hyperparameters():
     # Parallel Residuals (Modification 5)
     parallel_start_layer = int(os.environ.get("PARALLEL_START_LAYER", "7"))

+    # BigramHash (from #1019 pattern)
+    bigram_vocab_size = int(os.environ.get('BIGRAM_VOCAB_SIZE', '3072'))
+    bigram_dim = int(os.environ.get('BIGRAM_DIM', '112'))
+
+    # PROTEUS: Untied MLP for recurrence layers
+    recur_untie_mlp = bool(int(os.environ.get('RECUR_UNTIE_MLP', '1')))
+
     # TTT (Modification 4)
     ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0")))
     ttt_lr = float(os.environ.get("TTT_LR", 0.002))
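
These knobs are read from environment variables when the Hyperparameters class body executes. A minimal sketch of how a run might override them (the values here are illustrative, not recommendations):

import os

# Must be set before train_gpt.py evaluates the Hyperparameters class body.
os.environ["BIGRAM_VOCAB_SIZE"] = "4096"  # number of bigram hash buckets
os.environ["BIGRAM_DIM"] = "128"          # width of the bigram embedding
os.environ["RECUR_UNTIE_MLP"] = "0"       # "0" disables the untied MLP copies
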
@@ -556,6 +563,40 @@ def forward(self, token_ids: Tensor) -> Tensor:
         return h * self.scale.to(dtype=h.dtype)


+class SmearGate(nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+        self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
+    def forward(self, x: Tensor) -> Tensor:
+        g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :]
+        x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
+        return (1 - g) * x + g * x_prev
+
+
+class BigramHashEmbedding(nn.Module):
+    def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int):
+        super().__init__()
+        self.bigram_vocab_size = bigram_vocab_size
+        self.embed = nn.Embedding(bigram_vocab_size, bigram_dim)
+        nn.init.zeros_(self.embed.weight)
+        self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None
+        if self.proj is not None:
+            nn.init.zeros_(self.proj.weight)
+        self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32))
+    def bigram_hash(self, tokens: Tensor) -> Tensor:
+        t = tokens.to(torch.int32)
+        mod = self.bigram_vocab_size - 1
+        out = torch.empty_like(t)
+        out[..., 0] = mod
+        out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod
+        return out.long()
+    def forward(self, token_ids: Tensor) -> Tensor:
+        h = self.embed(self.bigram_hash(token_ids))
+        if self.proj is not None:
+            h = self.proj(h)
+        return h * self.scale.to(dtype=h.dtype)
+
+
 class MLP(nn.Module):
     def __init__(self, dim: int, mlp_mult: int):
         super().__init__()
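
A self-contained sketch of what the two new modules compute, with assumed toy shapes. The hash reserves bucket `mod` for position 0 (which has no predecessor) and folds each (previous, current) token pair into the remaining `bigram_vocab_size - 1` buckets; SmearGate blends every position with its predecessor through a per-channel sigmoid gate.

import torch

def bigram_hash(tokens: torch.Tensor, bigram_vocab_size: int = 3072) -> torch.Tensor:
    # tokens: [batch, seq] integer ids
    t = tokens.to(torch.int32)
    mod = bigram_vocab_size - 1
    out = torch.empty_like(t)
    out[..., 0] = mod                       # sentinel bucket for the first position
    out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod
    return out.long()

tokens = torch.randint(0, 50257, (2, 16))
ids = bigram_hash(tokens)
assert int(ids[0, 0]) == 3071 and int(ids.max()) < 3072

# SmearGate: y_t = (1 - g) * x_t + g * x_{t-1}, with g = sigmoid(gate).
# The gate initializes to zeros, so g starts at 0.5 (an even blend).
x = torch.randn(2, 16, 768)
g = torch.sigmoid(torch.zeros(768))[None, None, :]
x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
y = (1 - g) * x + g * x_prev
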
@@ -571,7 +612,7 @@ def forward(self, x: Tensor) -> Tensor:
 class Block(nn.Module):
     def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int,
                  rope_base: float, qk_gain_init: float, train_seq_len: int,
-                 layer_idx: int = 0, ln_scale: bool = False):
+                 layer_idx: int = 0, ln_scale: bool = False, parallel: bool = False):
         super().__init__()
         self.attn_norm = RMSNorm()
         self.mlp_norm = RMSNorm()
@@ -581,13 +622,18 @@ def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int,
         self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
         self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())
         self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0
+        self.parallel = parallel
+        if parallel:
+            self.resid_mix_mlp = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())
+            self.route = nn.Parameter(torch.tensor([1.0, 1.0, 1.0, 1.0]))

-    def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None) -> Tensor:
+    def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None, mlp_override: nn.Module | None = None) -> Tensor:
+        mlp_fn = mlp_override if mlp_override is not None else self.mlp
         mix = self.resid_mix.to(dtype=x.dtype)
         x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
         attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed)
         x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out
-        x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor)
+        x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * mlp_fn(self.mlp_norm(x_out) * self.ln_scale_factor)
         return x_out

@@ -601,6 +647,8 @@ def __init__(self, h: Hyperparameters):
         self.tied_embed_init_std = h.tied_embed_init_std
         self.logit_softcap = h.logit_softcap
         self.tok_emb = nn.Embedding(h.vocab_size, h.embedding_dim)
+        self.bigram = BigramHashEmbedding(h.bigram_vocab_size, h.bigram_dim, h.model_dim) if h.bigram_vocab_size > 0 else None
+        self.smear = SmearGate(h.model_dim) if h.bigram_vocab_size > 0 else None
         if h.embedding_dim != h.model_dim:
             self.embed_proj = CastedLinear(h.embedding_dim, h.model_dim, bias=False)
             self.head_proj = CastedLinear(h.model_dim, h.embedding_dim, bias=False)
@@ -614,7 +662,8 @@ def __init__(self, h: Hyperparameters):
         self.skip_gates = nn.Parameter(torch.zeros(self.num_skip_weights, h.model_dim, dtype=torch.float32)) if h.skip_gates_enabled else None
         self.blocks = nn.ModuleList([
             Block(h.model_dim, h.num_heads, h.num_kv_heads, h.mlp_mult, h.rope_base,
-                  h.qk_gain_init, h.train_seq_len, layer_idx=i, ln_scale=h.ln_scale)
+                  h.qk_gain_init, h.train_seq_len, layer_idx=i, ln_scale=h.ln_scale,
+                  parallel=(h.parallel_start_layer > 0 and i >= h.parallel_start_layer))
             for i in range(h.num_layers)
         ])
         if h.rope_dims > 0:
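
The parallel flag is derived per layer. For example, with the default PARALLEL_START_LAYER=7 and a hypothetical 12-layer stack, the last five blocks get the extra resid_mix_mlp and route parameters:

parallel_start_layer, num_layers = 7, 12   # num_layers chosen for illustration
flags = [parallel_start_layer > 0 and i >= parallel_start_layer
         for i in range(num_layers)]
assert flags == [False] * 7 + [True] * 5
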
@@ -652,6 +701,17 @@ def __init__(self, h: Hyperparameters):
         else:
             self.lane_merge = None

+        # PROTEUS: Untied MLP weights for recurrence layers
+        self.recur_untie_mlp = h.recur_untie_mlp
+        self._recur_layer_set = set(self.recur_layers)
+        if self.recur_layers and h.recur_untie_mlp:
+            self.recur_mlps = nn.ModuleDict({
+                str(lid): MLP(h.model_dim, h.mlp_mult)
+                for lid in self.recur_layers
+            })
+        else:
+            self.recur_mlps = None
+
         self._init_weights()

     def set_recurrence_active(self, active: bool) -> None:
@@ -697,11 +757,21 @@ def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None
         ve_idx = self.ve_layer_indices.index(layer_idx)
         return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype)

+    def _get_recur_mlp(self, phys_idx: int) -> nn.Module | None:
+        """Return untied MLP for recurrence pass, or None for normal pass."""
+        if self.recur_mlps is not None and str(phys_idx) in self.recur_mlps:
+            return self.recur_mlps[str(phys_idx)]
+        return None
+
     def forward_logits(self, input_ids: Tensor) -> Tensor:
         x = self.tok_emb(input_ids)
+        if self.bigram is not None:
+            x = x + self.bigram(input_ids)
         x = F.rms_norm(x, (x.size(-1),))
         if self.embed_proj is not None:
             x = self.embed_proj(x)
+        if self.smear is not None:
+            x = self.smear(x)
         x0 = x

         virtual_layers = self._get_virtual_layers()
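
The effect of _get_recur_mlp, combined with the visited_layers bookkeeping in the next hunk: the first pass through a physical layer returns None (so Block falls back to its own tied MLP), while a revisit returns the untied copy. A minimal sketch with a stand-in module:

import torch.nn as nn

recur_mlps = nn.ModuleDict({"3": nn.Linear(8, 8)})  # stand-in for MLP(model_dim, mlp_mult)
visited_layers: set[int] = set()

def get_recur_mlp(phys_idx: int):
    if phys_idx in visited_layers and str(phys_idx) in recur_mlps:
        return recur_mlps[str(phys_idx)]
    return None

assert get_recur_mlp(3) is None             # first visit: tied block MLP
visited_layers.add(3)
assert get_recur_mlp(3) is recur_mlps["3"]  # recurrence pass: untied MLP
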
@@ -718,22 +788,36 @@ def forward_logits(self, input_ids: Tensor) -> Tensor:
         lane0 = None  # attention lane
         lane1 = None  # MLP lane

+        # Track which physical layers have been visited (for recurrence detection)
+        visited_layers: set[int] = set()
+
         # Encoder phase
         for vi in range(num_enc):
             phys_idx = virtual_layers[vi]
+            is_recur_pass = phys_idx in visited_layers
             ve = self._get_ve(phys_idx, input_ids, ve_cache)
-            x = self.blocks[phys_idx](x, x0, v_embed=ve)
+            recur_mlp = self._get_recur_mlp(phys_idx) if is_recur_pass else None
+            x = self.blocks[phys_idx](x, x0, v_embed=ve, mlp_override=recur_mlp)
+            visited_layers.add(phys_idx)
             skips.append(x)

         # Decoder phase with U-Net skip connections
         for vi in range(num_dec):
             phys_idx = virtual_layers[num_enc + vi]
+            is_recur_pass = phys_idx in visited_layers
+
             if skips and vi < self.num_skip_weights:
-                scaled_skip = self.skip_weights[vi].to(dtype=x.dtype)[None, None, :] * skips.pop()
-                if self.skip_gates is not None:
+                if is_parallel_mode:
+                    # In parallel mode, add skip to both lanes (per #1289 PROTEUS)
+                    scaled_skip = self.skip_weights[vi].to(dtype=lane0.dtype)[None, None, :] * skips.pop()
+                    lane0 = lane0 + scaled_skip
+                    lane1 = lane1 + scaled_skip
+                elif self.skip_gates is not None:
+                    scaled_skip = self.skip_weights[vi].to(dtype=x.dtype)[None, None, :] * skips.pop()
                     g = torch.sigmoid(self.skip_gates[vi].to(dtype=x.dtype))[None, None, :]
                     x = torch.lerp(scaled_skip, x, g)
                 else:
+                    scaled_skip = self.skip_weights[vi].to(dtype=x.dtype)[None, None, :] * skips.pop()
                     x = x + scaled_skip

            # Check if we should enter parallel mode
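
For reference, torch.lerp(start, end, weight) computes start + weight * (end - start), so the gated branch is a per-channel blend between the scaled skip and the current residual, while the new parallel-mode branch adds the same scaled skip to both lanes. A small equivalence check under assumed toy shapes:

import torch

x = torch.randn(2, 4, 8)
scaled_skip = torch.randn(2, 4, 8)
g = torch.sigmoid(torch.zeros(8))[None, None, :]
assert torch.allclose(torch.lerp(scaled_skip, x, g), (1 - g) * scaled_skip + g * x)
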
@@ -745,20 +829,29 @@ def forward_logits(self, input_ids: Tensor) -> Tensor:
            if is_parallel_mode:
                block = self.blocks[phys_idx]
                ve = self._get_ve(phys_idx, input_ids, ve_cache)
+               recur_mlp = self._get_recur_mlp(phys_idx) if is_recur_pass else None
+               mlp_fn = recur_mlp if recur_mlp is not None else block.mlp

-               # Attention operates on lane0
-               mix = block.resid_mix.to(dtype=lane0.dtype)
-               attn_in = mix[0][None, None, :] * lane0 + mix[1][None, None, :] * x0
+               # PROTEUS routing: separate residual mixing for each lane
+               mix_attn = block.resid_mix.to(dtype=lane0.dtype)
+               attn_in = mix_attn[0][None, None, :] * lane0 + mix_attn[1][None, None, :] * x0
                attn_out = block.attn(block.attn_norm(attn_in) * block.ln_scale_factor, v_embed=ve)
-               lane0 = attn_in + block.attn_scale.to(dtype=attn_in.dtype)[None, None, :] * attn_out
+               attn_delta = block.attn_scale.to(dtype=attn_in.dtype)[None, None, :] * attn_out
+
+               mix_mlp = block.resid_mix_mlp.to(dtype=lane1.dtype)
+               mlp_in = mix_mlp[0][None, None, :] * lane1 + mix_mlp[1][None, None, :] * x0
+               mlp_delta = block.mlp_scale.to(dtype=lane1.dtype)[None, None, :] * mlp_fn(block.mlp_norm(mlp_in) * block.ln_scale_factor)

-               # MLP operates on lane1
-               mlp_in = block.mlp_norm(lane1) * block.ln_scale_factor
-               mlp_out = block.mlp(mlp_in)
-               lane1 = lane1 + block.mlp_scale.to(dtype=lane1.dtype)[None, None, :] * mlp_out
+               # 4-way routing
+               r = block.route.to(dtype=lane0.dtype)
+               lane0 = lane0 + r[0] * attn_delta + r[2] * mlp_delta
+               lane1 = lane1 + r[1] * attn_delta + r[3] * mlp_delta
            else:
                ve = self._get_ve(phys_idx, input_ids, ve_cache)
-               x = self.blocks[phys_idx](x, x0, v_embed=ve)
+               recur_mlp = self._get_recur_mlp(phys_idx) if is_recur_pass else None
+               x = self.blocks[phys_idx](x, x0, v_embed=ve, mlp_override=recur_mlp)
+
+           visited_layers.add(phys_idx)

        # Merge parallel lanes if active
        if is_parallel_mode:
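
The routing math in isolation: each lane receives a learned scalar mixture of the attention delta and the MLP delta, and the [1, 1, 1, 1] initialization sends both deltas to both lanes (toy tensors, shapes assumed):

import torch

B, T, D = 2, 4, 8
lane0, lane1 = torch.randn(B, T, D), torch.randn(B, T, D)
attn_delta, mlp_delta = torch.randn(B, T, D), torch.randn(B, T, D)
r = torch.tensor([1.0, 1.0, 1.0, 1.0])  # route init from the Block constructor

lane0 = lane0 + r[0] * attn_delta + r[2] * mlp_delta
lane1 = lane1 + r[1] * attn_delta + r[3] * mlp_delta
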
@@ -783,7 +876,9 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
 def classify_param(name: str) -> str:
     if "tok_emb" in name or "lm_head" in name:
         return "embed"
-    if ".mlp." in name:
+    if "bigram" in name or "smear" in name:
+        return "embed"
+    if "recur_mlps" in name or ".mlp." in name:
         return "mlp"
     if ".attn." in name or (".proj." in name and ".mlp." not in name):
         return "attn"
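
How the extended classifier buckets the new parameter names; the trailing fallback is not shown in the hunk and is assumed here, as are the submodule names in the asserts:

def classify_param(name: str) -> str:
    if "tok_emb" in name or "lm_head" in name:
        return "embed"
    if "bigram" in name or "smear" in name:
        return "embed"
    if "recur_mlps" in name or ".mlp." in name:
        return "mlp"
    if ".attn." in name or (".proj." in name and ".mlp." not in name):
        return "attn"
    return "other"  # assumed fallback

assert classify_param("bigram.embed.weight") == "embed"
assert classify_param("smear.gate") == "embed"
assert classify_param("recur_mlps.3.fc_in.weight") == "mlp"
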
@@ -890,8 +985,24 @@ def __init__(self, h: Hyperparameters, base_model: GPT):
         if base_model.lane_merge is not None:
             scalar_params.append(base_model.lane_merge)

+        # PROTEUS: recur_untie_mlp params
+        if base_model.recur_mlps is not None:
+            for name, p in base_model.recur_mlps.named_parameters():
+                if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS):
+                    matrix_params.append(p)
+                else:
+                    scalar_params.append(p)
+
         token_lr = h.tied_embed_lr if h.tie_embeddings else h.embed_lr
         tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}]
+        # BigramHash params
+        if base_model.bigram is not None:
+            tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr})
+            if base_model.bigram.proj is not None:
+                matrix_params.append(base_model.bigram.proj.weight)
+            scalar_params.append(base_model.bigram.scale)
+        if base_model.smear is not None:
+            scalar_params.append(base_model.smear.gate)
         if base_model.ve_shared is not None:
             tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr})
             if base_model.ve_shared.proj is not None:
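
The grouping rule for the untied MLPs in isolation: 2-D weights join the matrix optimizer unless their name matches a control-tensor pattern; everything else lands in scalar_params. A sketch under assumed toy parameter names and shapes:

import torch

CONTROL_TENSOR_NAME_PATTERNS = ("attn_scale", "mlp_scale", "resid_mix", "route")  # abridged
matrix_params, scalar_params = [], []
named = [("3.fc_in.weight", torch.nn.Parameter(torch.randn(32, 8))),   # hypothetical names
         ("3.gate_bias", torch.nn.Parameter(torch.zeros(32)))]
for name, p in named:
    if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS):
        matrix_params.append(p)
    else:
        scalar_params.append(p)
assert len(matrix_params) == 1 and len(scalar_params) == 1
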
@@ -955,7 +1066,7 @@ def step(self):
     pattern
     for pattern in os.environ.get(
         "CONTROL_TENSOR_NAME_PATTERNS",
-        "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,skip_gates,ve_layer_scales,ve_shared.scale,lane_merge",
+        "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,resid_mix_mlp,route,q_gain,skip_weight,skip_weights,skip_gates,ve_layer_scales,ve_shared.scale,lane_merge,bigram.scale,smear.gate",
     ).split(",")
     if pattern
 )