@@ -104,6 +104,7 @@ class Hyperparameters:
     ema_decay = float(os.environ.get("EMA_DECAY", 0.0))
     neural_temp = float(os.environ.get("NEURAL_TEMP", 0.85))
     ln_scale = bool(int(os.environ.get("LN_SCALE", "0")))
+    rope_dims = int(os.environ.get("ROPE_DIMS", 0))
 
 # -----------------------------
 # MUON OPTIMIZER
@@ -643,6 +644,7 @@ def __init__(
         num_kv_heads: int,
         rope_base: float,
         qk_gain_init: float,
+        rope_dims: int = 0,
     ):
         super().__init__()
         if dim % num_heads != 0:
@@ -654,14 +656,15 @@ def __init__(
         self.head_dim = dim // num_heads
         if self.head_dim % 2 != 0:
             raise ValueError("head_dim must be even for RoPE")
+        self.rope_dims = rope_dims if rope_dims > 0 else self.head_dim
         kv_dim = self.num_kv_heads * self.head_dim
         self.c_q = CastedLinear(dim, dim, bias=False)
         self.c_k = CastedLinear(dim, kv_dim, bias=False)
         self.c_v = CastedLinear(dim, kv_dim, bias=False)
         self.proj = CastedLinear(dim, dim, bias=False)
         self.proj._zero_init = True
         self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
-        self.rotary = Rotary(self.head_dim, base=rope_base)
+        self.rotary = Rotary(self.rope_dims, base=rope_base)
 
     def forward(self, x: Tensor) -> Tensor:
         bsz, seqlen, dim = x.shape
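
With the change above, Rotary is constructed with self.rope_dims rather than the full head_dim, so its cos/sin tables only cover the channels that actually get rotated. The Rotary module itself is not part of this diff; a minimal stand-in with the same call signature (seqlen, device, dtype), written under standard-RoPE assumptions, is sketched below.

import torch
from torch import nn, Tensor

class RotarySketch(nn.Module):
    """Hypothetical stand-in for the repository's Rotary module.

    It only needs to emit cos/sin tables sized for the rotated channel
    count (rope_dims), not the full head_dim.
    """
    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        # Standard RoPE inverse frequencies over `dim` channels (dim must be even).
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, seqlen: int, device, dtype) -> tuple[Tensor, Tensor]:
        t = torch.arange(seqlen, device=device, dtype=torch.float32)
        freqs = torch.outer(t, self.inv_freq.to(device))  # (seqlen, dim // 2)
        return freqs.cos().to(dtype), freqs.sin().to(dtype)
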
@@ -670,9 +673,10 @@ def forward(self, x: Tensor) -> Tensor:
         v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
         q = F.rms_norm(q, (q.size(-1),))
         k = F.rms_norm(k, (k.size(-1),))
+        rd = self.rope_dims
         cos, sin = self.rotary(seqlen, x.device, q.dtype)
-        q = apply_rotary_emb(q, cos, sin)
-        k = apply_rotary_emb(k, cos, sin)
+        q = torch.cat((apply_rotary_emb(q[..., :rd], cos, sin), q[..., rd:]), dim=-1)
+        k = torch.cat((apply_rotary_emb(k[..., :rd], cos, sin), k[..., rd:]), dim=-1)
         q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
         y = F.scaled_dot_product_attention(
             q,
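
The forward pass now rotates only the first rd = rope_dims channels of each head and concatenates the remaining head_dim - rd channels unchanged; with the default ROPE_DIMS=0, rd equals head_dim and the behavior matches the previous full-head rotation. A standalone sketch of that split, using a generic half-split apply_rotary_emb stand-in (the script's own helper may pair channels differently), is shown below.

import torch

def apply_rotary_emb(x, cos, sin):
    # Generic stand-in: split the last dimension in half and rotate the pairs.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)

head_dim, rd, seqlen = 64, 32, 8
q = torch.randn(1, 4, seqlen, head_dim)              # (batch, heads, seq, head_dim)
theta = 1.0 / (10000 ** (torch.arange(0, rd, 2) / rd))
angles = torch.arange(seqlen)[:, None] * theta[None, :]
cos, sin = angles.cos(), angles.sin()                # (seqlen, rd // 2)

# Rotate only the first rd channels, pass the rest through unchanged.
q_rot = torch.cat((apply_rotary_emb(q[..., :rd], cos, sin), q[..., rd:]), dim=-1)
assert q_rot.shape == q.shape                        # shape is preserved
assert torch.equal(q_rot[..., rd:], q[..., rd:])     # tail channels untouched
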
@@ -735,11 +739,12 @@ def __init__(
         rope_base: float,
         qk_gain_init: float,
         ln_scale: float = 1.0,
+        rope_dims: int = 0,
     ):
         super().__init__()
         self.attn_norm = RMSNorm()
         self.mlp_norm = RMSNorm()
-        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
+        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_dims=rope_dims)
         self.mlp = MLP(dim, mlp_mult)
         self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
         self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
@@ -760,7 +765,8 @@ class GPT(nn.Module):
     def __init__(self, vocab_size: int, num_layers: int, model_dim: int, num_heads: int,
                  num_kv_heads: int, mlp_mult: int, tie_embeddings: bool, tied_embed_init_std: float,
                  logit_softcap: float, rope_base: float, qk_gain_init: float,
-                 bigram_vocab_size: int = 0, bigram_dim: int = 128, ln_scale: bool = False):
+                 bigram_vocab_size: int = 0, bigram_dim: int = 128,
+                 ln_scale: bool = False, rope_dims: int = 0):
         super().__init__()
         if logit_softcap <= 0.0:
             raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
@@ -775,7 +781,7 @@ def __init__(self, vocab_size: int, num_layers: int, model_dim: int, num_heads:
         self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
         self.blocks = nn.ModuleList([
             Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init,
-                  ln_scale=1.0 / (i + 1) ** 0.5 if ln_scale else 1.0)
+                  ln_scale=1.0 / (i + 1) ** 0.5 if ln_scale else 1.0, rope_dims=rope_dims)
             for i in range(num_layers)
         ])
         self.smear_gate = SmearGate(model_dim)
@@ -949,7 +955,8 @@ def log0(msg: str, console: bool = True) -> None:
     num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult,
     tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std,
     logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init,
-    bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, ln_scale=args.ln_scale,
+    bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim,
+    ln_scale=args.ln_scale, rope_dims=args.rope_dims,
 ).to(device).bfloat16()
 for module in base_model.modules():
     if isinstance(module, CastedLinear):
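
Because rope_dims is read from the environment as a Hyperparameters class attribute, it is evaluated when the training module is imported, so ROPE_DIMS has to be in the environment before launch; the default of 0 keeps full-head RoPE as before. A hypothetical way to drive it from Python (the actual launch command is not part of this diff):

import os

# ROPE_DIMS must be set before the training module is imported:
# Hyperparameters.rope_dims is evaluated at class-definition time.
os.environ["ROPE_DIMS"] = "32"   # rotate only the first 32 channels per head; 0 keeps full-head RoPE
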