@@ -108,8 +108,11 @@ class Hyperparameters:
108108 doc_isolated_eval = bool (int (os .environ .get ("DOC_ISOLATED_EVAL" , "1" ))) # eval per-document, no cross-doc context
109109 smear_gate = bool (int (os .environ .get ("SMEAR_GATE" , "1" ))) # cheap bigram context at embedding layer
110110 bigram_hash = bool (int (os .environ .get ("BIGRAM_HASH" , "1" ))) # hash-based bigram embedding
111- bigram_hash_buckets = int (os .environ .get ("BIGRAM_HASH_BUCKETS" , 2048 ))
111+ bigram_hash_buckets = int (os .environ .get ("BIGRAM_HASH_BUCKETS" , 4096 )) # doubled from 2048
112112 bigram_hash_dim = int (os .environ .get ("BIGRAM_HASH_DIM" , 128 ))
113+ trigram_hash = bool (int (os .environ .get ("TRIGRAM_HASH" , "1" ))) # 3-token context hash embedding
114+ trigram_hash_buckets = int (os .environ .get ("TRIGRAM_HASH_BUCKETS" , 4096 ))
115+ trigram_hash_dim = int (os .environ .get ("TRIGRAM_HASH_DIM" , 32 )) # smaller dim than bigram (additive)
113116 swa = bool (int (os .environ .get ("SWA" , "0" ))) # stochastic weight averaging (disabled: EMA preferred)
114117 swa_start_frac = float (os .environ .get ("SWA_START_FRAC" , 0.5 ))
115118 xsa_last_n = int (os .environ .get ("XSA_LAST_N" , 4 )) # XSA on last N layers (0=disabled)
@@ -181,6 +184,7 @@ class Hyperparameters:
181184 gated_attention = bool (int (os .environ .get ("GATED_ATTENTION" , "1" ))) # per-head sigmoid gate after SDPA
182185 # Per-layer lr: MLP proj (high quant damage) gets higher lr, MLP fc (low damage) gets lower lr
183186 # Based on our 34-config ablation showing 3.4x damage ratio between proj and fc weights
187+ star_relu = bool (int (os .environ .get ("STAR_RELU" , "1" ))) # Star-ReLU: learnable scale+bias on relu²
184188 perlayer_train_lr = bool (int (os .environ .get ("PERLAYER_TRAIN_LR" , "1" )))
185189 proj_lr_mult = float (os .environ .get ("PROJ_LR_MULT" , "1.5" )) # multiplier for mlp.proj (high quant damage)
186190 fc_lr_mult = float (os .environ .get ("FC_LR_MULT" , "0.7" )) # multiplier for mlp.fc (low quant damage)
@@ -1001,17 +1005,24 @@ def forward(self, x: Tensor, lora: AttentionLoRA | None = None, v_embed: Tensor
10011005
10021006
10031007class MLP (nn .Module ):
1004- # relu^2 MLP from the original modded-nanogpt setup
1005- def __init__ (self , dim : int , mlp_mult : int , mlp_hidden : int = 0 ):
1008+ # Star-ReLU MLP: relu(x) ^2 with learnable per-channel scale+bias (MetaFormer)
1009+ def __init__ (self , dim : int , mlp_mult : int , mlp_hidden : int = 0 , star_relu : bool = False ):
10061010 super ().__init__ ()
10071011 hidden = mlp_hidden if mlp_hidden > 0 else mlp_mult * dim
10081012 self .fc = CastedLinear (dim , hidden , bias = False )
10091013 self .proj = CastedLinear (hidden , dim , bias = False )
10101014 self .proj ._zero_init = True
1015+ self .star_relu = star_relu
1016+ if star_relu :
1017+ self .star_scale = nn .Parameter (torch .ones (hidden , dtype = torch .float32 ))
1018+ self .star_bias = nn .Parameter (torch .zeros (hidden , dtype = torch .float32 ))
10111019
10121020 def forward (self , x : Tensor ) -> Tensor :
10131021 x = torch .relu (self .fc (x ))
1014- return self .proj (x .square ())
1022+ x = x .square ()
1023+ if self .star_relu :
1024+ x = x * self .star_scale .to (dtype = x .dtype ) + self .star_bias .to (dtype = x .dtype )
1025+ return self .proj (x )
10151026
10161027
10171028class Block (nn .Module ):
@@ -1030,14 +1041,15 @@ def __init__(
10301041 ln_scale : bool = False ,
10311042 value_residual : bool = False ,
10321043 gated_attention : bool = False ,
1044+ star_relu : bool = False ,
10331045 ):
10341046 super ().__init__ ()
10351047 self .attn_norm = RMSNorm ()
10361048 self .mlp_norm = RMSNorm ()
10371049 self .attn = CausalSelfAttention (dim , num_heads , num_kv_heads , rope_base , qk_gain_init ,
10381050 ntk_base_seq_len = ntk_base_seq_len , rope_dims = rope_dims ,
10391051 value_residual = value_residual , gated_attention = gated_attention )
1040- self .mlp = MLP (dim , mlp_mult , mlp_hidden = mlp_hidden )
1052+ self .mlp = MLP (dim , mlp_mult , mlp_hidden = mlp_hidden , star_relu = star_relu )
10411053 self .attn_scale = nn .Parameter (torch .ones (dim , dtype = torch .float32 ))
10421054 self .mlp_scale = nn .Parameter (torch .ones (dim , dtype = torch .float32 ))
10431055 self .resid_mix = nn .Parameter (torch .stack ((torch .ones (dim ), torch .zeros (dim ))).float ())
@@ -1114,6 +1126,29 @@ def forward(self, input_ids: Tensor) -> Tensor:
11141126 return self .scale .to (dtype = input_ids .dtype if input_ids .is_floating_point () else torch .bfloat16 ) * self .proj (self .embed (bucket_ids ))
11151127
11161128
1129+ # ── TrigramHash: hash-based trigram embedding ────────────────────────────────
1130+ class TrigramHashEmbedding (nn .Module ):
1131+ """Maps consecutive token triples to embeddings via xor hash.
1132+ Extends BigramHash to 3-token context window."""
1133+ def __init__ (self , num_buckets : int , hash_dim : int , model_dim : int ):
1134+ super ().__init__ ()
1135+ self .num_buckets = num_buckets
1136+ self .embed = nn .Embedding (num_buckets , hash_dim )
1137+ self .proj = nn .Linear (hash_dim , model_dim , bias = False )
1138+ self .scale = nn .Parameter (torch .tensor (0.05 , dtype = torch .float32 ))
1139+ nn .init .normal_ (self .embed .weight , std = 0.01 )
1140+ nn .init .zeros_ (self .proj .weight )
1141+
1142+ def forward (self , input_ids : Tensor ) -> Tensor :
1143+ # 3-token xor hash: h(t-2, t-1, t) = (p1*t-2) ^ (p2*t-1) ^ (p3*t) mod buckets
1144+ ids = input_ids .long ()
1145+ prev1 = torch .cat ([torch .zeros_like (ids [:, :1 ]), ids [:, :- 1 ]], dim = 1 )
1146+ prev2 = torch .cat ([torch .zeros_like (ids [:, :2 ]), ids [:, :- 2 ]], dim = 1 )
1147+ bucket_ids = (torch .bitwise_xor (torch .bitwise_xor (
1148+ 48271 * prev2 , 36313 * prev1 ), 27191 * ids ) % max (self .num_buckets - 1 , 1 ))
1149+ return self .scale .to (dtype = torch .bfloat16 ) * self .proj (self .embed (bucket_ids ))
1150+
1151+
11171152class GPT (nn .Module ):
11181153 def __init__ (
11191154 self ,
@@ -1135,6 +1170,9 @@ def __init__(
11351170 bigram_hash : bool = False ,
11361171 bigram_hash_buckets : int = 4096 ,
11371172 bigram_hash_dim : int = 128 ,
1173+ trigram_hash : bool = False ,
1174+ trigram_hash_buckets : int = 4096 ,
1175+ trigram_hash_dim : int = 32 ,
11381176 ortho_init : bool = True ,
11391177 xsa_last_n : int = 0 ,
11401178 ntk_base_seq_len : int = 0 ,
@@ -1146,6 +1184,7 @@ def __init__(
11461184 ve_layers : str = "9,10" ,
11471185 value_residual : bool = False ,
11481186 gated_attention : bool = False ,
1187+ star_relu : bool = False ,
11491188 ):
11501189 super ().__init__ ()
11511190 if logit_softcap <= 0.0 :
@@ -1160,6 +1199,7 @@ def __init__(
11601199 self .tok_emb = nn .Embedding (vocab_size , model_dim )
11611200 self .smear_gate = SmearGate (model_dim ) if smear_gate else None
11621201 self .bigram_hash = BigramHashEmbedding (bigram_hash_buckets , bigram_hash_dim , model_dim ) if bigram_hash else None
1202+ self .trigram_hash = TrigramHashEmbedding (trigram_hash_buckets , trigram_hash_dim , model_dim ) if trigram_hash else None
11631203 # Shared Value Embedding: one table, added to V in selected layers
11641204 self .ve_layer_indices = [int (x ) for x in ve_layers .split ("," ) if x .strip ()] if ve_enabled else []
11651205 kv_dim = num_kv_heads * (model_dim // num_heads )
@@ -1190,6 +1230,7 @@ def __init__(
11901230 ln_scale = ln_scale ,
11911231 value_residual = value_residual ,
11921232 gated_attention = gated_attention ,
1233+ star_relu = star_relu ,
11931234 )
11941235 for i in range (num_layers )
11951236 ]
@@ -1239,6 +1280,8 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
12391280 x = self .tok_emb (input_ids )
12401281 if self .bigram_hash is not None :
12411282 x = x + self .bigram_hash (input_ids )
1283+ if self .trigram_hash is not None :
1284+ x = x + self .trigram_hash (input_ids )
12421285 x = F .rms_norm (x , (x .size (- 1 ),))
12431286 if self .smear_gate is not None :
12441287 x = self .smear_gate (x )
@@ -1295,6 +1338,8 @@ def forward_logits(self, input_ids: Tensor) -> Tensor:
12951338 x = self .tok_emb (input_ids )
12961339 if self .bigram_hash is not None :
12971340 x = x + self .bigram_hash (input_ids )
1341+ if self .trigram_hash is not None :
1342+ x = x + self .trigram_hash (input_ids )
12981343 x = F .rms_norm (x , (x .size (- 1 ),))
12991344 if self .smear_gate is not None :
13001345 x = self .smear_gate (x )
@@ -1340,6 +1385,8 @@ def forward_logits_cached(
13401385 x = self .tok_emb (input_ids )
13411386 if self .bigram_hash is not None :
13421387 x = x + self .bigram_hash (input_ids )
1388+ if self .trigram_hash is not None :
1389+ x = x + self .trigram_hash (input_ids )
13431390 x = F .rms_norm (x , (x .size (- 1 ),))
13441391 if self .smear_gate is not None :
13451392 x = self .smear_gate (x )
@@ -2245,6 +2292,9 @@ def log0(msg: str, console: bool = True) -> None:
22452292 bigram_hash = args .bigram_hash ,
22462293 bigram_hash_buckets = args .bigram_hash_buckets ,
22472294 bigram_hash_dim = args .bigram_hash_dim ,
2295+ trigram_hash = args .trigram_hash ,
2296+ trigram_hash_buckets = args .trigram_hash_buckets ,
2297+ trigram_hash_dim = args .trigram_hash_dim ,
22482298 ortho_init = args .ortho_init ,
22492299 xsa_last_n = args .xsa_last_n ,
22502300 ntk_base_seq_len = args .train_seq_len if args .eval_seq_len > args .train_seq_len else 0 ,
@@ -2256,6 +2306,7 @@ def log0(msg: str, console: bool = True) -> None:
22562306 ve_layers = args .ve_layers ,
22572307 value_residual = args .value_residual ,
22582308 gated_attention = args .gated_attention ,
2309+ star_relu = args .star_relu ,
22592310 ).to (device ).bfloat16 ()
22602311 if args ._tier2 :
22612312 log0 (f"*** TIER2_MODE: proxy run max={ args .max_wallclock_seconds :.0f} s iters={ args .iterations } "
@@ -2339,6 +2390,8 @@ def log0(msg: str, console: bool = True) -> None:
23392390 # (when perlayer_train_lr, these are added to muon_param_groups directly)
23402391 if base_model .bigram_hash is not None and not args .perlayer_train_lr :
23412392 matrix_params .append (base_model .bigram_hash .proj .weight )
2393+ if base_model .trigram_hash is not None and not args .perlayer_train_lr :
2394+ matrix_params .append (base_model .trigram_hash .proj .weight )
23422395 scalar_params = [
23432396 p
23442397 for name , p in block_named_params
@@ -2352,6 +2405,8 @@ def log0(msg: str, console: bool = True) -> None:
23522405 # bigram_hash.scale is a learned scalar — AdamW at scalar_lr
23532406 if base_model .bigram_hash is not None :
23542407 scalar_params .append (base_model .bigram_hash .scale )
2408+ if base_model .trigram_hash is not None :
2409+ scalar_params .append (base_model .trigram_hash .scale )
23552410 # VE: scales go to scalar, proj to matrix, embed to tok group
23562411 if base_model .ve_shared is not None :
23572412 scalar_params .append (base_model .ve_shared .scale )
@@ -2365,6 +2420,8 @@ def log0(msg: str, console: bool = True) -> None:
23652420 embed_params = [base_model .tok_emb .weight ]
23662421 if base_model .bigram_hash is not None :
23672422 embed_params .append (base_model .bigram_hash .embed .weight )
2423+ if base_model .trigram_hash is not None :
2424+ embed_params .append (base_model .trigram_hash .embed .weight )
23682425 if base_model .ve_shared is not None :
23692426 embed_params .append (base_model .ve_shared .embed .weight )
23702427 optimizer_tok = torch .optim .AdamW (
@@ -2385,6 +2442,8 @@ def log0(msg: str, console: bool = True) -> None:
23852442 # Add bigram_hash.proj to "other" group
23862443 if base_model .bigram_hash is not None :
23872444 muon_param_groups [2 ]["params" ].append (base_model .bigram_hash .proj .weight )
2445+ if base_model .trigram_hash is not None :
2446+ muon_param_groups [2 ]["params" ].append (base_model .trigram_hash .proj .weight )
23882447 if base_model .ve_shared is not None and base_model .ve_shared .proj is not None :
23892448 muon_param_groups [2 ]["params" ].append (base_model .ve_shared .proj .weight )
23902449 muon_param_groups = [g for g in muon_param_groups if g ["params" ]]
0 commit comments