@@ -105,6 +105,7 @@ class Hyperparameters:
     neural_temp = float(os.environ.get("NEURAL_TEMP", 0.85))
     ln_scale = bool(int(os.environ.get("LN_SCALE", "0")))
     rope_dims = int(os.environ.get("ROPE_DIMS", 0))
+    xsa_last_n = int(os.environ.get("XSA_LAST_N", 0))
 
 # -----------------------------
 # MUON OPTIMIZER
@@ -645,6 +646,7 @@ def __init__(
         rope_base: float,
         qk_gain_init: float,
         rope_dims: int = 0,
+        use_xsa: bool = False,
     ):
         super().__init__()
         if dim % num_heads != 0:
@@ -654,6 +656,7 @@ def __init__(
         self.num_heads = num_heads
         self.num_kv_heads = num_kv_heads
         self.head_dim = dim // num_heads
+        self.use_xsa = use_xsa
         if self.head_dim % 2 != 0:
             raise ValueError("head_dim must be even for RoPE")
         self.rope_dims = rope_dims if rope_dims > 0 else self.head_dim
@@ -686,6 +689,10 @@ def forward(self, x: Tensor) -> Tensor:
             is_causal=True,
             enable_gqa=(self.num_kv_heads != self.num_heads),
         )
+        if self.use_xsa:
+            v_expanded = v.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) if self.num_kv_heads != self.num_heads else v
+            v_norm = F.normalize(v_expanded, dim=-1)
+            y = y - (y * v_norm).sum(dim=-1, keepdim=True) * v_norm
         y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
         return self.proj(y)
 
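For readers skimming the diff: the new `use_xsa` branch rejects (removes) the component of the attention output that lies along each position's own value direction, after expanding the KV heads to match the query heads under GQA. Below is a minimal sketch of just that step on dummy tensors; the shapes and variable names are illustrative and not taken from the training script.

```python
import torch
import torch.nn.functional as F

bsz, num_heads, num_kv_heads, seqlen, head_dim = 2, 8, 2, 16, 64
y = torch.randn(bsz, num_heads, seqlen, head_dim)      # stand-in for the SDPA output
v = torch.randn(bsz, num_kv_heads, seqlen, head_dim)   # stand-in for the per-KV-head values

# Expand KV heads to match the query heads (the GQA case handled in the diff).
v_expanded = v.repeat_interleave(num_heads // num_kv_heads, dim=1)
v_norm = F.normalize(v_expanded, dim=-1)

# Subtract the projection of y onto the (unit-norm) value direction at each position.
y_xsa = y - (y * v_norm).sum(dim=-1, keepdim=True) * v_norm

# The result is orthogonal to v_norm at every (batch, head, position).
residual = (y_xsa * v_norm).sum(dim=-1)
print(torch.allclose(residual, torch.zeros_like(residual), atol=1e-4))  # True
```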
@@ -740,11 +747,13 @@ def __init__(
         qk_gain_init: float,
         ln_scale: float = 1.0,
         rope_dims: int = 0,
+        use_xsa: bool = False,
     ):
         super().__init__()
         self.attn_norm = RMSNorm()
         self.mlp_norm = RMSNorm()
-        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_dims=rope_dims)
+        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init,
+                                        rope_dims=rope_dims, use_xsa=use_xsa)
         self.mlp = MLP(dim, mlp_mult)
         self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
         self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
@@ -766,7 +775,7 @@ def __init__(self, vocab_size: int, num_layers: int, model_dim: int, num_heads:
                  num_kv_heads: int, mlp_mult: int, tie_embeddings: bool, tied_embed_init_std: float,
                  logit_softcap: float, rope_base: float, qk_gain_init: float,
                  bigram_vocab_size: int = 0, bigram_dim: int = 128,
-                 ln_scale: bool = False, rope_dims: int = 0):
+                 ln_scale: bool = False, rope_dims: int = 0, xsa_last_n: int = 0):
         super().__init__()
         if logit_softcap <= 0.0:
             raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
@@ -781,7 +790,8 @@ def __init__(self, vocab_size: int, num_layers: int, model_dim: int, num_heads:
         self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
         self.blocks = nn.ModuleList([
             Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init,
-                  ln_scale=1.0 / (i + 1) ** 0.5 if ln_scale else 1.0, rope_dims=rope_dims)
+                  ln_scale=1.0 / (i + 1) ** 0.5 if ln_scale else 1.0, rope_dims=rope_dims,
+                  use_xsa=(i >= num_layers - xsa_last_n))
             for i in range(num_layers)
         ])
         self.smear_gate = SmearGate(model_dim)
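The per-block flag comes entirely from the predicate `i >= num_layers - xsa_last_n`: only the last `xsa_last_n` blocks enable the rejection step, and the default of `XSA_LAST_N=0` leaves every block unchanged since the condition is never true. A quick standalone check of that predicate (the layer count below is made up for illustration):

```python
num_layers = 12  # hypothetical depth, just to show which blocks get the flag

for xsa_last_n in (0, 2):
    flags = [i >= num_layers - xsa_last_n for i in range(num_layers)]
    # xsa_last_n=0 -> no blocks enabled; xsa_last_n=2 -> only blocks 10 and 11.
    print(xsa_last_n, [i for i, on in enumerate(flags) if on])
```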
@@ -956,7 +966,7 @@ def log0(msg: str, console: bool = True) -> None:
     tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std,
     logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init,
     bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim,
-    ln_scale=args.ln_scale, rope_dims=args.rope_dims,
+    ln_scale=args.ln_scale, rope_dims=args.rope_dims, xsa_last_n=args.xsa_last_n,
 ).to(device).bfloat16()
 for module in base_model.modules():
     if isinstance(module, CastedLinear):