@@ -621,12 +621,6 @@ def __init__(
         self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
         self.rotary = Rotary(self.head_dim, base=rope_base)
 
-    def _repeat_kv(self, k: Tensor, v: Tensor, bsz: int, seqlen: int) -> tuple[Tensor, Tensor]:
-        reps = self.num_heads // self.num_kv_heads
-        k = k[:, :, None, :, :].expand(-1, -1, reps, -1, -1).reshape(bsz, self.num_heads, seqlen, self.head_dim)
-        v = v[:, :, None, :, :].expand(-1, -1, reps, -1, -1).reshape(bsz, self.num_heads, seqlen, self.head_dim)
-        return k, v
-
     def forward(self, x: Tensor) -> Tensor:
         bsz, seqlen, dim = x.shape
         q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
@@ -638,10 +632,9 @@ def forward(self, x: Tensor) -> Tensor:
         q = apply_rotary_emb(q, cos, sin)
         k = apply_rotary_emb(k, cos, sin)
         q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
-        if self.num_kv_heads != self.num_heads:
-            k, v = self._repeat_kv(k, v, bsz, seqlen)
         y = F.scaled_dot_product_attention(
             q, k, v, attn_mask=None, is_causal=True,
             enable_gqa=(self.num_kv_heads != self.num_heads),
         ).transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
         return self.proj(y)
 
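For context: the change drops the module's hand-rolled `_repeat_kv` expansion and lets `F.scaled_dot_product_attention` broadcast the KV heads itself via its `enable_gqa` flag. Below is a minimal standalone sketch, not part of this commit, assuming PyTorch >= 2.5 (where `enable_gqa` was added), checking that the two paths agree numerically.

import torch
import torch.nn.functional as F

bsz, seqlen, head_dim = 2, 16, 64
num_heads, num_kv_heads = 8, 2
reps = num_heads // num_kv_heads

q = torch.randn(bsz, num_heads, seqlen, head_dim)
k = torch.randn(bsz, num_kv_heads, seqlen, head_dim)
v = torch.randn(bsz, num_kv_heads, seqlen, head_dim)

# Manual path: expand each KV head `reps` times (repeat_interleave semantics)
# so K/V have as many heads as Q, as the removed _repeat_kv did.
k_rep = k[:, :, None, :, :].expand(-1, -1, reps, -1, -1).reshape(bsz, num_heads, seqlen, head_dim)
v_rep = v[:, :, None, :, :].expand(-1, -1, reps, -1, -1).reshape(bsz, num_heads, seqlen, head_dim)
y_manual = F.scaled_dot_product_attention(q, k_rep, v_rep, is_causal=True)

# GQA path: pass the un-expanded K/V and let SDPA broadcast heads internally.
y_gqa = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)

torch.testing.assert_close(y_manual, y_gqa)

Besides being shorter, the `enable_gqa` path avoids materializing the expanded K/V views and lets fused attention backends handle the head grouping directly.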