@@ -103,6 +103,13 @@ class Hyperparameters():
     # Parallel Residuals (Modification 5)
     parallel_start_layer = int(os.environ.get("PARALLEL_START_LAYER", "7"))

+    # BigramHash (from #1019 pattern)
+    bigram_vocab_size = int(os.environ.get('BIGRAM_VOCAB_SIZE', '3072'))
+    bigram_dim = int(os.environ.get('BIGRAM_DIM', '112'))
+
+    # PROTEUS: Untied MLP for recurrence layers
+    recur_untie_mlp = bool(int(os.environ.get('RECUR_UNTIE_MLP', '1')))
+
     # TTT (Modification 4)
     ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0")))
     ttt_lr = float(os.environ.get("TTT_LR", 0.002))
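All of the new knobs are plain environment variables read at class definition time ('0'/'1' strings go through bool(int(...))). A minimal sketch of turning the new features off, assuming the overrides are set before this module is imported (values are illustrative):

import os

os.environ["BIGRAM_VOCAB_SIZE"] = "0"   # 0 disables BigramHashEmbedding and SmearGate below
os.environ["RECUR_UNTIE_MLP"] = "0"     # keep recurrence MLP weights tied to the block's own MLP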
@@ -556,6 +563,40 @@ def forward(self, token_ids: Tensor) -> Tensor:
         return h * self.scale.to(dtype=h.dtype)


+class SmearGate(nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+        self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
+    def forward(self, x: Tensor) -> Tensor:
+        g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :]
+        x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
+        return (1 - g) * x + g * x_prev
+
+
+class BigramHashEmbedding(nn.Module):
+    def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int):
+        super().__init__()
+        self.bigram_vocab_size = bigram_vocab_size
+        self.embed = nn.Embedding(bigram_vocab_size, bigram_dim)
+        nn.init.zeros_(self.embed.weight)
+        self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None
+        if self.proj is not None:
+            nn.init.zeros_(self.proj.weight)
+        self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32))
+    def bigram_hash(self, tokens: Tensor) -> Tensor:
+        t = tokens.to(torch.int32)
+        mod = self.bigram_vocab_size - 1
+        out = torch.empty_like(t)
+        out[..., 0] = mod
+        out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod
+        return out.long()
+    def forward(self, token_ids: Tensor) -> Tensor:
+        h = self.embed(self.bigram_hash(token_ids))
+        if self.proj is not None:
+            h = self.proj(h)
+        return h * self.scale.to(dtype=h.dtype)
+
+
 class MLP(nn.Module):
     def __init__(self, dim: int, mlp_mult: int):
         super().__init__()
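The hashed-bigram lookup maps each (previous token, current token) pair to one of bigram_vocab_size - 1 buckets, reserving the last index for position 0, which has no predecessor; since both the table and its projection are zero-initialized and gated by a small learned scale, the module starts as a near no-op. Note that SmearGate's zero-initialized logits give a sigmoid gate of 0.5, i.e. an even mix of the current and previous position at init. A standalone sketch of the hash with toy token ids (vocab size and values are illustrative only):

import torch

# Toy reproduction of BigramHashEmbedding.bigram_hash with bigram_vocab_size = 3072.
tokens = torch.tensor([[17, 93, 93, 4021]])       # illustrative token ids
t = tokens.to(torch.int32)
mod = 3072 - 1                                    # last bucket reserved for the first position
idx = torch.empty_like(t)
idx[..., 0] = mod
idx[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod
print(idx)  # bucket ids in [0, 3070] after position 0; identical bigrams always hash to the same bucket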
@@ -571,7 +612,7 @@ def forward(self, x: Tensor) -> Tensor:
 class Block(nn.Module):
     def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int,
                  rope_base: float, qk_gain_init: float, train_seq_len: int,
-                 layer_idx: int = 0, ln_scale: bool = False):
+                 layer_idx: int = 0, ln_scale: bool = False, parallel: bool = False):
         super().__init__()
         self.attn_norm = RMSNorm()
         self.mlp_norm = RMSNorm()
@@ -581,13 +622,18 @@ def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int,
         self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
         self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())
         self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0
+        self.parallel = parallel
+        if parallel:
+            self.resid_mix_mlp = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())
+            self.route = nn.Parameter(torch.tensor([1.0, 1.0, 1.0, 1.0]))

-    def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None) -> Tensor:
+    def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None, mlp_override: nn.Module | None = None) -> Tensor:
+        mlp_fn = mlp_override if mlp_override is not None else self.mlp
         mix = self.resid_mix.to(dtype=x.dtype)
         x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
         attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed)
         x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out
-        x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor)
+        x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * mlp_fn(self.mlp_norm(x_out) * self.ln_scale_factor)
         return x_out


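The mlp_override argument lets the caller substitute a different MLP for a single forward call while leaving the block's own weights untouched; forward_logits uses it further down to swap in untied MLPs on repeated (recurrence) passes through a layer. A minimal sketch of the pattern using a hypothetical TinyBlock, not a class from this repo:

import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    """Hypothetical illustration of the optional-override pattern used by Block.forward."""
    def __init__(self, dim: int):
        super().__init__()
        self.mlp = nn.Linear(dim, dim)
    def forward(self, x: torch.Tensor, mlp_override: nn.Module | None = None) -> torch.Tensor:
        mlp_fn = mlp_override if mlp_override is not None else self.mlp
        return x + mlp_fn(x)

block = TinyBlock(8)
x = torch.randn(2, 8)
y_tied = block(x)                       # uses the block's own MLP
y_untied = block(x, nn.Linear(8, 8))    # same block, substitute weights for this one pass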
@@ -601,6 +647,8 @@ def __init__(self, h: Hyperparameters):
         self.tied_embed_init_std = h.tied_embed_init_std
         self.logit_softcap = h.logit_softcap
         self.tok_emb = nn.Embedding(h.vocab_size, h.embedding_dim)
+        self.bigram = BigramHashEmbedding(h.bigram_vocab_size, h.bigram_dim, h.model_dim) if h.bigram_vocab_size > 0 else None
+        self.smear = SmearGate(h.model_dim) if h.bigram_vocab_size > 0 else None
         if h.embedding_dim != h.model_dim:
             self.embed_proj = CastedLinear(h.embedding_dim, h.model_dim, bias=False)
             self.head_proj = CastedLinear(h.model_dim, h.embedding_dim, bias=False)
@@ -614,7 +662,8 @@ def __init__(self, h: Hyperparameters):
         self.skip_gates = nn.Parameter(torch.zeros(self.num_skip_weights, h.model_dim, dtype=torch.float32)) if h.skip_gates_enabled else None
         self.blocks = nn.ModuleList([
             Block(h.model_dim, h.num_heads, h.num_kv_heads, h.mlp_mult, h.rope_base,
-                  h.qk_gain_init, h.train_seq_len, layer_idx=i, ln_scale=h.ln_scale)
+                  h.qk_gain_init, h.train_seq_len, layer_idx=i, ln_scale=h.ln_scale,
+                  parallel=(h.parallel_start_layer > 0 and i >= h.parallel_start_layer))
             for i in range(h.num_layers)
         ])
         if h.rope_dims > 0:
@@ -652,6 +701,17 @@ def __init__(self, h: Hyperparameters):
         else:
             self.lane_merge = None

+        # PROTEUS: Untied MLP weights for recurrence layers
+        self.recur_untie_mlp = h.recur_untie_mlp
+        self._recur_layer_set = set(self.recur_layers)
+        if self.recur_layers and h.recur_untie_mlp:
+            self.recur_mlps = nn.ModuleDict({
+                str(lid): MLP(h.model_dim, h.mlp_mult)
+                for lid in self.recur_layers
+            })
+        else:
+            self.recur_mlps = None
+
         self._init_weights()

     def set_recurrence_active(self, active: bool) -> None:
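Because nn.ModuleDict only accepts string keys, the untied MLPs are stored under str(layer_id) and looked up the same way in _get_recur_mlp below. A small self-contained sketch of that keying scheme (dims and layer ids are illustrative, and nn.Linear stands in for the repo's MLP):

import torch.nn as nn

recur_layers = [3, 5]                       # illustrative recurrence layer ids
recur_mlps = nn.ModuleDict({str(lid): nn.Linear(16, 16) for lid in recur_layers})

def get_recur_mlp(phys_idx: int) -> nn.Module | None:
    # Mirrors GPT._get_recur_mlp's lookup; the caller only invokes it on a repeated pass.
    key = str(phys_idx)
    return recur_mlps[key] if key in recur_mlps else None

assert get_recur_mlp(5) is recur_mlps["5"]
assert get_recur_mlp(0) is None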
@@ -697,11 +757,21 @@ def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None
         ve_idx = self.ve_layer_indices.index(layer_idx)
         return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype)

+    def _get_recur_mlp(self, phys_idx: int) -> nn.Module | None:
+        """Return untied MLP for recurrence pass, or None for normal pass."""
+        if self.recur_mlps is not None and str(phys_idx) in self.recur_mlps:
+            return self.recur_mlps[str(phys_idx)]
+        return None
+
     def forward_logits(self, input_ids: Tensor) -> Tensor:
         x = self.tok_emb(input_ids)
+        if self.bigram is not None:
+            x = x + self.bigram(input_ids)
         x = F.rms_norm(x, (x.size(-1),))
         if self.embed_proj is not None:
             x = self.embed_proj(x)
+        if self.smear is not None:
+            x = self.smear(x)
         x0 = x

         virtual_layers = self._get_virtual_layers()
@@ -718,22 +788,36 @@ def forward_logits(self, input_ids: Tensor) -> Tensor:
         lane0 = None  # attention lane
         lane1 = None  # MLP lane

+        # Track which physical layers have been visited (for recurrence detection)
+        visited_layers: set[int] = set()
+
         # Encoder phase
         for vi in range(num_enc):
             phys_idx = virtual_layers[vi]
+            is_recur_pass = phys_idx in visited_layers
             ve = self._get_ve(phys_idx, input_ids, ve_cache)
-            x = self.blocks[phys_idx](x, x0, v_embed=ve)
+            recur_mlp = self._get_recur_mlp(phys_idx) if is_recur_pass else None
+            x = self.blocks[phys_idx](x, x0, v_embed=ve, mlp_override=recur_mlp)
+            visited_layers.add(phys_idx)
             skips.append(x)

         # Decoder phase with U-Net skip connections
         for vi in range(num_dec):
             phys_idx = virtual_layers[num_enc + vi]
+            is_recur_pass = phys_idx in visited_layers
+
             if skips and vi < self.num_skip_weights:
-                scaled_skip = self.skip_weights[vi].to(dtype=x.dtype)[None, None, :] * skips.pop()
-                if self.skip_gates is not None:
+                if is_parallel_mode:
+                    # In parallel mode, add skip to both lanes (per #1289 PROTEUS)
+                    scaled_skip = self.skip_weights[vi].to(dtype=lane0.dtype)[None, None, :] * skips.pop()
+                    lane0 = lane0 + scaled_skip
+                    lane1 = lane1 + scaled_skip
+                elif self.skip_gates is not None:
+                    scaled_skip = self.skip_weights[vi].to(dtype=x.dtype)[None, None, :] * skips.pop()
                     g = torch.sigmoid(self.skip_gates[vi].to(dtype=x.dtype))[None, None, :]
                     x = torch.lerp(scaled_skip, x, g)
                 else:
+                    scaled_skip = self.skip_weights[vi].to(dtype=x.dtype)[None, None, :] * skips.pop()
                     x = x + scaled_skip

             # Check if we should enter parallel mode
@@ -745,20 +829,29 @@ def forward_logits(self, input_ids: Tensor) -> Tensor:
             if is_parallel_mode:
                 block = self.blocks[phys_idx]
                 ve = self._get_ve(phys_idx, input_ids, ve_cache)
+                recur_mlp = self._get_recur_mlp(phys_idx) if is_recur_pass else None
+                mlp_fn = recur_mlp if recur_mlp is not None else block.mlp

-                # Attention operates on lane0
-                mix = block.resid_mix.to(dtype=lane0.dtype)
-                attn_in = mix[0][None, None, :] * lane0 + mix[1][None, None, :] * x0
+                # PROTEUS routing: separate residual mixing for each lane
+                mix_attn = block.resid_mix.to(dtype=lane0.dtype)
+                attn_in = mix_attn[0][None, None, :] * lane0 + mix_attn[1][None, None, :] * x0
                 attn_out = block.attn(block.attn_norm(attn_in) * block.ln_scale_factor, v_embed=ve)
-                lane0 = attn_in + block.attn_scale.to(dtype=attn_in.dtype)[None, None, :] * attn_out
+                attn_delta = block.attn_scale.to(dtype=attn_in.dtype)[None, None, :] * attn_out
+
+                mix_mlp = block.resid_mix_mlp.to(dtype=lane1.dtype)
+                mlp_in = mix_mlp[0][None, None, :] * lane1 + mix_mlp[1][None, None, :] * x0
+                mlp_delta = block.mlp_scale.to(dtype=lane1.dtype)[None, None, :] * mlp_fn(block.mlp_norm(mlp_in) * block.ln_scale_factor)

-                # MLP operates on lane1
-                mlp_in = block.mlp_norm(lane1) * block.ln_scale_factor
-                mlp_out = block.mlp(mlp_in)
-                lane1 = lane1 + block.mlp_scale.to(dtype=lane1.dtype)[None, None, :] * mlp_out
+                # 4-way routing
+                r = block.route.to(dtype=lane0.dtype)
+                lane0 = lane0 + r[0] * attn_delta + r[2] * mlp_delta
+                lane1 = lane1 + r[1] * attn_delta + r[3] * mlp_delta
             else:
                 ve = self._get_ve(phys_idx, input_ids, ve_cache)
-                x = self.blocks[phys_idx](x, x0, v_embed=ve)
+                recur_mlp = self._get_recur_mlp(phys_idx) if is_recur_pass else None
+                x = self.blocks[phys_idx](x, x0, v_embed=ve, mlp_override=recur_mlp)
+
+            visited_layers.add(phys_idx)

         # Merge parallel lanes if active
         if is_parallel_mode:
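In parallel mode each block now produces an attention delta and an MLP delta, and the learned 4-vector route decides how much of each delta flows into each lane (lane0 is the attention lane, lane1 the MLP lane). With the all-ones initialization both lanes receive both deltas. A toy sketch of just the routing arithmetic (shapes and values are illustrative):

import torch

dim = 4                                        # illustrative width
lane0 = torch.zeros(1, 1, dim)                 # attention lane
lane1 = torch.zeros(1, 1, dim)                 # MLP lane
attn_delta = torch.full((1, 1, dim), 1.0)      # stand-in for the scaled attention output
mlp_delta = torch.full((1, 1, dim), 2.0)       # stand-in for the scaled MLP output
route = torch.tensor([1.0, 1.0, 1.0, 1.0])     # same init as block.route above

lane0 = lane0 + route[0] * attn_delta + route[2] * mlp_delta
lane1 = lane1 + route[1] * attn_delta + route[3] * mlp_delta
# With this init both lanes end up at 3.0 everywhere; training can specialize the four weights.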
@@ -783,7 +876,9 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
 def classify_param(name: str) -> str:
     if "tok_emb" in name or "lm_head" in name:
         return "embed"
-    if ".mlp." in name:
+    if "bigram" in name or "smear" in name:
+        return "embed"
+    if "recur_mlps" in name or ".mlp." in name:
         return "mlp"
     if ".attn." in name or (".proj." in name and ".mlp." not in name):
         return "attn"
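The classifier routes the new modules into existing groups: bigram/smear parameters are treated like embeddings, and the untied recurrence MLPs join the regular MLP group; the optimizer setup below additionally special-cases the scale/gate scalars by name. A compact standalone check of the new branches (parameter names are modeled on the diff; the recur_mlps submodule name is hypothetical):

def classify_param_sketch(name: str) -> str:
    # Condensed copy of only the branches relevant to the new modules.
    if "tok_emb" in name or "lm_head" in name:
        return "embed"
    if "bigram" in name or "smear" in name:
        return "embed"
    if "recur_mlps" in name or ".mlp." in name:
        return "mlp"
    return "other"

assert classify_param_sketch("bigram.embed.weight") == "embed"
assert classify_param_sketch("smear.gate") == "embed"
assert classify_param_sketch("recur_mlps.5.up_proj.weight") == "mlp"   # hypothetical submodule name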
@@ -890,8 +985,24 @@ def __init__(self, h: Hyperparameters, base_model: GPT):
         if base_model.lane_merge is not None:
             scalar_params.append(base_model.lane_merge)

+        # PROTEUS: recur_untie_mlp params
+        if base_model.recur_mlps is not None:
+            for name, p in base_model.recur_mlps.named_parameters():
+                if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS):
+                    matrix_params.append(p)
+                else:
+                    scalar_params.append(p)
+
         token_lr = h.tied_embed_lr if h.tie_embeddings else h.embed_lr
         tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}]
+        # BigramHash params
+        if base_model.bigram is not None:
+            tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr})
+            if base_model.bigram.proj is not None:
+                matrix_params.append(base_model.bigram.proj.weight)
+            scalar_params.append(base_model.bigram.scale)
+        if base_model.smear is not None:
+            scalar_params.append(base_model.smear.gate)
         if base_model.ve_shared is not None:
             tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr})
             if base_model.ve_shared.proj is not None:
@@ -955,7 +1066,7 @@ def step(self):
     pattern
     for pattern in os.environ.get(
        "CONTROL_TENSOR_NAME_PATTERNS",
-        "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,skip_gates,ve_layer_scales,ve_shared.scale,lane_merge",
+        "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,resid_mix_mlp,route,q_gain,skip_weight,skip_weights,skip_gates,ve_layer_scales,ve_shared.scale,lane_merge,bigram.scale,smear.gate",
     ).split(",")
     if pattern
 )
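The updated default adds the new control tensors (resid_mix_mlp, route, bigram.scale, smear.gate) to the comma-separated pattern list, which is matched by simple substring tests against parameter names. A sketch of that matching with a shortened, illustrative pattern list (is_control_tensor is a hypothetical helper; the repo inlines the same any(...) test where it is needed):

import os

patterns = tuple(
    p for p in os.environ.get(
        "CONTROL_TENSOR_NAME_PATTERNS",
        "resid_mix,resid_mix_mlp,route,bigram.scale,smear.gate",   # shortened for illustration
    ).split(",")
    if p
)

def is_control_tensor(name: str) -> bool:
    return any(p in name for p in patterns)

assert is_control_tensor("blocks.9.resid_mix_mlp")
assert is_control_tensor("blocks.9.route")
assert not is_control_tensor("blocks.9.attn.qkv.weight")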