Skip to content

Commit 56550e4

Browse files
committed
Add Star-ReLU, TrigramHash, and BigramHash 4096 buckets
Star-ReLU: learnable per-channel scale+bias on relu² activation (MetaFormer). Same architecture as PR openai#505's "SwiGLU" — 2 weight matrices, not gated MLP. Zero step time overhead, ~34K params (66KB fp16). TrigramHash: 3-token xor hash embedding extending BigramHash to trigram context. 4096 buckets, 32-dim, ~147K params (108KB int6). Independent contribution. BigramHash doubled to 4096 buckets (from 2048) for less collision. All features env-var controlled and default ON. Artifact headroom: ~466KB remaining (well within 16MB cap).
1 parent bfc61f1 commit 56550e4

1 file changed

Lines changed: 64 additions & 5 deletions

File tree

  • records/track_10min_16mb/2026-03-21_11L_XSA_EMA_TTT

records/track_10min_16mb/2026-03-21_11L_XSA_EMA_TTT/train_gpt.py

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,11 @@ class Hyperparameters:
108108
doc_isolated_eval = bool(int(os.environ.get("DOC_ISOLATED_EVAL", "1"))) # eval per-document, no cross-doc context
109109
smear_gate = bool(int(os.environ.get("SMEAR_GATE", "1"))) # cheap bigram context at embedding layer
110110
bigram_hash = bool(int(os.environ.get("BIGRAM_HASH", "1"))) # hash-based bigram embedding
111-
bigram_hash_buckets = int(os.environ.get("BIGRAM_HASH_BUCKETS", 2048))
111+
bigram_hash_buckets = int(os.environ.get("BIGRAM_HASH_BUCKETS", 4096)) # doubled from 2048
112112
bigram_hash_dim = int(os.environ.get("BIGRAM_HASH_DIM", 128))
113+
trigram_hash = bool(int(os.environ.get("TRIGRAM_HASH", "1"))) # 3-token context hash embedding
114+
trigram_hash_buckets = int(os.environ.get("TRIGRAM_HASH_BUCKETS", 4096))
115+
trigram_hash_dim = int(os.environ.get("TRIGRAM_HASH_DIM", 32)) # smaller dim than bigram (additive)
113116
swa = bool(int(os.environ.get("SWA", "0"))) # stochastic weight averaging (disabled: EMA preferred)
114117
swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.5))
115118
xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) # XSA on last N layers (0=disabled)
@@ -181,6 +184,7 @@ class Hyperparameters:
181184
gated_attention = bool(int(os.environ.get("GATED_ATTENTION", "1"))) # per-head sigmoid gate after SDPA
182185
# Per-layer lr: MLP proj (high quant damage) gets higher lr, MLP fc (low damage) gets lower lr
183186
# Based on our 34-config ablation showing 3.4x damage ratio between proj and fc weights
187+
star_relu = bool(int(os.environ.get("STAR_RELU", "1"))) # Star-ReLU: learnable scale+bias on relu²
184188
perlayer_train_lr = bool(int(os.environ.get("PERLAYER_TRAIN_LR", "1")))
185189
proj_lr_mult = float(os.environ.get("PROJ_LR_MULT", "1.5")) # multiplier for mlp.proj (high quant damage)
186190
fc_lr_mult = float(os.environ.get("FC_LR_MULT", "0.7")) # multiplier for mlp.fc (low quant damage)
@@ -1001,17 +1005,24 @@ def forward(self, x: Tensor, lora: AttentionLoRA | None = None, v_embed: Tensor
10011005

10021006

10031007
class MLP(nn.Module):
1004-
# relu^2 MLP from the original modded-nanogpt setup
1005-
def __init__(self, dim: int, mlp_mult: int, mlp_hidden: int = 0):
1008+
# Star-ReLU MLP: relu(x)^2 with learnable per-channel scale+bias (MetaFormer)
1009+
def __init__(self, dim: int, mlp_mult: int, mlp_hidden: int = 0, star_relu: bool = False):
10061010
super().__init__()
10071011
hidden = mlp_hidden if mlp_hidden > 0 else mlp_mult * dim
10081012
self.fc = CastedLinear(dim, hidden, bias=False)
10091013
self.proj = CastedLinear(hidden, dim, bias=False)
10101014
self.proj._zero_init = True
1015+
self.star_relu = star_relu
1016+
if star_relu:
1017+
self.star_scale = nn.Parameter(torch.ones(hidden, dtype=torch.float32))
1018+
self.star_bias = nn.Parameter(torch.zeros(hidden, dtype=torch.float32))
10111019

10121020
def forward(self, x: Tensor) -> Tensor:
10131021
x = torch.relu(self.fc(x))
1014-
return self.proj(x.square())
1022+
x = x.square()
1023+
if self.star_relu:
1024+
x = x * self.star_scale.to(dtype=x.dtype) + self.star_bias.to(dtype=x.dtype)
1025+
return self.proj(x)
10151026

10161027

10171028
class Block(nn.Module):
@@ -1030,14 +1041,15 @@ def __init__(
10301041
ln_scale: bool = False,
10311042
value_residual: bool = False,
10321043
gated_attention: bool = False,
1044+
star_relu: bool = False,
10331045
):
10341046
super().__init__()
10351047
self.attn_norm = RMSNorm()
10361048
self.mlp_norm = RMSNorm()
10371049
self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init,
10381050
ntk_base_seq_len=ntk_base_seq_len, rope_dims=rope_dims,
10391051
value_residual=value_residual, gated_attention=gated_attention)
1040-
self.mlp = MLP(dim, mlp_mult, mlp_hidden=mlp_hidden)
1052+
self.mlp = MLP(dim, mlp_mult, mlp_hidden=mlp_hidden, star_relu=star_relu)
10411053
self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
10421054
self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
10431055
self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())
@@ -1114,6 +1126,29 @@ def forward(self, input_ids: Tensor) -> Tensor:
11141126
return self.scale.to(dtype=input_ids.dtype if input_ids.is_floating_point() else torch.bfloat16) * self.proj(self.embed(bucket_ids))
11151127

11161128

1129+
# ── TrigramHash: hash-based trigram embedding ────────────────────────────────
1130+
class TrigramHashEmbedding(nn.Module):
1131+
"""Maps consecutive token triples to embeddings via xor hash.
1132+
Extends BigramHash to 3-token context window."""
1133+
def __init__(self, num_buckets: int, hash_dim: int, model_dim: int):
1134+
super().__init__()
1135+
self.num_buckets = num_buckets
1136+
self.embed = nn.Embedding(num_buckets, hash_dim)
1137+
self.proj = nn.Linear(hash_dim, model_dim, bias=False)
1138+
self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32))
1139+
nn.init.normal_(self.embed.weight, std=0.01)
1140+
nn.init.zeros_(self.proj.weight)
1141+
1142+
def forward(self, input_ids: Tensor) -> Tensor:
1143+
# 3-token xor hash: h(t-2, t-1, t) = (p1*t-2) ^ (p2*t-1) ^ (p3*t) mod buckets
1144+
ids = input_ids.long()
1145+
prev1 = torch.cat([torch.zeros_like(ids[:, :1]), ids[:, :-1]], dim=1)
1146+
prev2 = torch.cat([torch.zeros_like(ids[:, :2]), ids[:, :-2]], dim=1)
1147+
bucket_ids = (torch.bitwise_xor(torch.bitwise_xor(
1148+
48271 * prev2, 36313 * prev1), 27191 * ids) % max(self.num_buckets - 1, 1))
1149+
return self.scale.to(dtype=torch.bfloat16) * self.proj(self.embed(bucket_ids))
1150+
1151+
11171152
class GPT(nn.Module):
11181153
def __init__(
11191154
self,
@@ -1135,6 +1170,9 @@ def __init__(
11351170
bigram_hash: bool = False,
11361171
bigram_hash_buckets: int = 4096,
11371172
bigram_hash_dim: int = 128,
1173+
trigram_hash: bool = False,
1174+
trigram_hash_buckets: int = 4096,
1175+
trigram_hash_dim: int = 32,
11381176
ortho_init: bool = True,
11391177
xsa_last_n: int = 0,
11401178
ntk_base_seq_len: int = 0,
@@ -1146,6 +1184,7 @@ def __init__(
11461184
ve_layers: str = "9,10",
11471185
value_residual: bool = False,
11481186
gated_attention: bool = False,
1187+
star_relu: bool = False,
11491188
):
11501189
super().__init__()
11511190
if logit_softcap <= 0.0:
@@ -1160,6 +1199,7 @@ def __init__(
11601199
self.tok_emb = nn.Embedding(vocab_size, model_dim)
11611200
self.smear_gate = SmearGate(model_dim) if smear_gate else None
11621201
self.bigram_hash = BigramHashEmbedding(bigram_hash_buckets, bigram_hash_dim, model_dim) if bigram_hash else None
1202+
self.trigram_hash = TrigramHashEmbedding(trigram_hash_buckets, trigram_hash_dim, model_dim) if trigram_hash else None
11631203
# Shared Value Embedding: one table, added to V in selected layers
11641204
self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else []
11651205
kv_dim = num_kv_heads * (model_dim // num_heads)
@@ -1190,6 +1230,7 @@ def __init__(
11901230
ln_scale=ln_scale,
11911231
value_residual=value_residual,
11921232
gated_attention=gated_attention,
1233+
star_relu=star_relu,
11931234
)
11941235
for i in range(num_layers)
11951236
]
@@ -1239,6 +1280,8 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
12391280
x = self.tok_emb(input_ids)
12401281
if self.bigram_hash is not None:
12411282
x = x + self.bigram_hash(input_ids)
1283+
if self.trigram_hash is not None:
1284+
x = x + self.trigram_hash(input_ids)
12421285
x = F.rms_norm(x, (x.size(-1),))
12431286
if self.smear_gate is not None:
12441287
x = self.smear_gate(x)
@@ -1295,6 +1338,8 @@ def forward_logits(self, input_ids: Tensor) -> Tensor:
12951338
x = self.tok_emb(input_ids)
12961339
if self.bigram_hash is not None:
12971340
x = x + self.bigram_hash(input_ids)
1341+
if self.trigram_hash is not None:
1342+
x = x + self.trigram_hash(input_ids)
12981343
x = F.rms_norm(x, (x.size(-1),))
12991344
if self.smear_gate is not None:
13001345
x = self.smear_gate(x)
@@ -1340,6 +1385,8 @@ def forward_logits_cached(
13401385
x = self.tok_emb(input_ids)
13411386
if self.bigram_hash is not None:
13421387
x = x + self.bigram_hash(input_ids)
1388+
if self.trigram_hash is not None:
1389+
x = x + self.trigram_hash(input_ids)
13431390
x = F.rms_norm(x, (x.size(-1),))
13441391
if self.smear_gate is not None:
13451392
x = self.smear_gate(x)
@@ -2245,6 +2292,9 @@ def log0(msg: str, console: bool = True) -> None:
22452292
bigram_hash=args.bigram_hash,
22462293
bigram_hash_buckets=args.bigram_hash_buckets,
22472294
bigram_hash_dim=args.bigram_hash_dim,
2295+
trigram_hash=args.trigram_hash,
2296+
trigram_hash_buckets=args.trigram_hash_buckets,
2297+
trigram_hash_dim=args.trigram_hash_dim,
22482298
ortho_init=args.ortho_init,
22492299
xsa_last_n=args.xsa_last_n,
22502300
ntk_base_seq_len=args.train_seq_len if args.eval_seq_len > args.train_seq_len else 0,
@@ -2256,6 +2306,7 @@ def log0(msg: str, console: bool = True) -> None:
22562306
ve_layers=args.ve_layers,
22572307
value_residual=args.value_residual,
22582308
gated_attention=args.gated_attention,
2309+
star_relu=args.star_relu,
22592310
).to(device).bfloat16()
22602311
if args._tier2:
22612312
log0(f"*** TIER2_MODE: proxy run max={args.max_wallclock_seconds:.0f}s iters={args.iterations} "
@@ -2339,6 +2390,8 @@ def log0(msg: str, console: bool = True) -> None:
23392390
# (when perlayer_train_lr, these are added to muon_param_groups directly)
23402391
if base_model.bigram_hash is not None and not args.perlayer_train_lr:
23412392
matrix_params.append(base_model.bigram_hash.proj.weight)
2393+
if base_model.trigram_hash is not None and not args.perlayer_train_lr:
2394+
matrix_params.append(base_model.trigram_hash.proj.weight)
23422395
scalar_params = [
23432396
p
23442397
for name, p in block_named_params
@@ -2352,6 +2405,8 @@ def log0(msg: str, console: bool = True) -> None:
23522405
# bigram_hash.scale is a learned scalar — AdamW at scalar_lr
23532406
if base_model.bigram_hash is not None:
23542407
scalar_params.append(base_model.bigram_hash.scale)
2408+
if base_model.trigram_hash is not None:
2409+
scalar_params.append(base_model.trigram_hash.scale)
23552410
# VE: scales go to scalar, proj to matrix, embed to tok group
23562411
if base_model.ve_shared is not None:
23572412
scalar_params.append(base_model.ve_shared.scale)
@@ -2365,6 +2420,8 @@ def log0(msg: str, console: bool = True) -> None:
23652420
embed_params = [base_model.tok_emb.weight]
23662421
if base_model.bigram_hash is not None:
23672422
embed_params.append(base_model.bigram_hash.embed.weight)
2423+
if base_model.trigram_hash is not None:
2424+
embed_params.append(base_model.trigram_hash.embed.weight)
23682425
if base_model.ve_shared is not None:
23692426
embed_params.append(base_model.ve_shared.embed.weight)
23702427
optimizer_tok = torch.optim.AdamW(
@@ -2385,6 +2442,8 @@ def log0(msg: str, console: bool = True) -> None:
23852442
# Add bigram_hash.proj to "other" group
23862443
if base_model.bigram_hash is not None:
23872444
muon_param_groups[2]["params"].append(base_model.bigram_hash.proj.weight)
2445+
if base_model.trigram_hash is not None:
2446+
muon_param_groups[2]["params"].append(base_model.trigram_hash.proj.weight)
23882447
if base_model.ve_shared is not None and base_model.ve_shared.proj is not None:
23892448
muon_param_groups[2]["params"].append(base_model.ve_shared.proj.weight)
23902449
muon_param_groups = [g for g in muon_param_groups if g["params"]]

0 commit comments

Comments
 (0)