
Commit 3b02f8f

phase-5: add xsa to deepest three layers
Add XSA_LAST_N env var (default 0 = disabled). When set (e.g. 3), the last N layers use Exclusive Self-Attention, which subtracts the projection of the attention output onto the normalized value vector, encouraging those layers to capture information orthogonal to the values. Zero parameters added. Used by all top-5 submissions.
1 parent 1e8fb29 commit 3b02f8f
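
For intuition, here is a minimal, self-contained sketch of the per-position operation the new use_xsa branch performs. The helper name xsa_project_out and the toy tensor shapes are illustrative only, not part of the commit:

import torch
import torch.nn.functional as F

def xsa_project_out(y: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Subtract the component of the attention output y that lies along the
    # normalized value vector v, keeping only the part orthogonal to v.
    v_norm = F.normalize(v, dim=-1)
    return y - (y * v_norm).sum(dim=-1, keepdim=True) * v_norm

# Toy shapes: (batch, heads, seq, head_dim); values chosen only for the check.
y = torch.randn(2, 4, 8, 16)
v = torch.randn(2, 4, 8, 16)
y_xsa = xsa_project_out(y, v)
# Every output position is now (numerically) orthogonal to its value vector.
assert (y_xsa * F.normalize(v, dim=-1)).sum(-1).abs().max() < 1e-5

Because the operation is a pure projection of existing tensors, it adds no learnable parameters, which matches the claim in the commit message.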

1 file changed: train_gpt.py (14 additions & 4 deletions)
@@ -105,6 +105,7 @@ class Hyperparameters:
     neural_temp = float(os.environ.get("NEURAL_TEMP", 0.85))
     ln_scale = bool(int(os.environ.get("LN_SCALE", "0")))
     rope_dims = int(os.environ.get("ROPE_DIMS", 0))
+    xsa_last_n = int(os.environ.get("XSA_LAST_N", 0))

 # -----------------------------
 # MUON OPTIMIZER
@@ -645,6 +646,7 @@ def __init__(
         rope_base: float,
         qk_gain_init: float,
         rope_dims: int = 0,
+        use_xsa: bool = False,
     ):
         super().__init__()
         if dim % num_heads != 0:
@@ -654,6 +656,7 @@ def __init__(
         self.num_heads = num_heads
         self.num_kv_heads = num_kv_heads
         self.head_dim = dim // num_heads
+        self.use_xsa = use_xsa
         if self.head_dim % 2 != 0:
             raise ValueError("head_dim must be even for RoPE")
         self.rope_dims = rope_dims if rope_dims > 0 else self.head_dim
@@ -686,6 +689,10 @@ def forward(self, x: Tensor) -> Tensor:
             is_causal=True,
             enable_gqa=(self.num_kv_heads != self.num_heads),
         )
+        if self.use_xsa:
+            v_expanded = v.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) if self.num_kv_heads != self.num_heads else v
+            v_norm = F.normalize(v_expanded, dim=-1)
+            y = y - (y * v_norm).sum(dim=-1, keepdim=True) * v_norm
         y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
         return self.proj(y)

@@ -740,11 +747,13 @@ def __init__(
         qk_gain_init: float,
         ln_scale: float = 1.0,
         rope_dims: int = 0,
+        use_xsa: bool = False,
     ):
         super().__init__()
         self.attn_norm = RMSNorm()
         self.mlp_norm = RMSNorm()
-        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_dims=rope_dims)
+        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init,
+                                        rope_dims=rope_dims, use_xsa=use_xsa)
         self.mlp = MLP(dim, mlp_mult)
         self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
         self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
@@ -766,7 +775,7 @@ def __init__(self, vocab_size: int, num_layers: int, model_dim: int, num_heads:
                  num_kv_heads: int, mlp_mult: int, tie_embeddings: bool, tied_embed_init_std: float,
                  logit_softcap: float, rope_base: float, qk_gain_init: float,
                  bigram_vocab_size: int = 0, bigram_dim: int = 128,
-                 ln_scale: bool = False, rope_dims: int = 0):
+                 ln_scale: bool = False, rope_dims: int = 0, xsa_last_n: int = 0):
         super().__init__()
         if logit_softcap <= 0.0:
             raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
@@ -781,7 +790,8 @@ def __init__(self, vocab_size: int, num_layers: int, model_dim: int, num_heads:
         self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
         self.blocks = nn.ModuleList([
             Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init,
-                  ln_scale=1.0 / (i + 1) ** 0.5 if ln_scale else 1.0, rope_dims=rope_dims)
+                  ln_scale=1.0 / (i + 1) ** 0.5 if ln_scale else 1.0, rope_dims=rope_dims,
+                  use_xsa=(i >= num_layers - xsa_last_n))
             for i in range(num_layers)
         ])
         self.smear_gate = SmearGate(model_dim)
@@ -956,7 +966,7 @@ def log0(msg: str, console: bool = True) -> None:
     tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std,
     logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init,
     bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim,
-    ln_scale=args.ln_scale, rope_dims=args.rope_dims,
+    ln_scale=args.ln_scale, rope_dims=args.rope_dims, xsa_last_n=args.xsa_last_n,
 ).to(device).bfloat16()
 for module in base_model.modules():
     if isinstance(module, CastedLinear):
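
To make the layer gating concrete, here is a small sketch of how use_xsa is assigned per block; num_layers = 12 and XSA_LAST_N = 3 are illustrative values, not read from this run's config:

num_layers, xsa_last_n = 12, 3  # illustrative; xsa_last_n is parsed from the XSA_LAST_N env var
use_xsa_flags = [i >= num_layers - xsa_last_n for i in range(num_layers)]
# Only the deepest three blocks (indices 9, 10, 11) get use_xsa=True.
assert use_xsa_flags == [False] * 9 + [True] * 3

With the default XSA_LAST_N=0 the comparison is never true, so every block keeps the unmodified attention path; enabling the feature is a matter of exporting XSA_LAST_N before launching train_gpt.py.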
