Skip to content

Commit 380d818

Browse files
committed
perf: use native enable_gqa — matches PR openai#114 attention path, 12ms/step faster
1 parent 6ca6087 commit 380d818

1 file changed

Lines changed: 1 addition & 8 deletions

File tree

records/track_10min_16mb/2026-03-20_StackedV1/train_gpt.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -621,12 +621,6 @@ def __init__(
621621
self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
622622
self.rotary = Rotary(self.head_dim, base=rope_base)
623623

624-
def _repeat_kv(self, k: Tensor, v: Tensor, bsz: int, seqlen: int) -> tuple[Tensor, Tensor]:
625-
reps = self.num_heads // self.num_kv_heads
626-
k = k[:, :, None, :, :].expand(-1, -1, reps, -1, -1).reshape(bsz, self.num_heads, seqlen, self.head_dim)
627-
v = v[:, :, None, :, :].expand(-1, -1, reps, -1, -1).reshape(bsz, self.num_heads, seqlen, self.head_dim)
628-
return k, v
629-
630624
def forward(self, x: Tensor) -> Tensor:
631625
bsz, seqlen, dim = x.shape
632626
q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
@@ -638,10 +632,9 @@ def forward(self, x: Tensor) -> Tensor:
638632
q = apply_rotary_emb(q, cos, sin)
639633
k = apply_rotary_emb(k, cos, sin)
640634
q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
641-
if self.num_kv_heads != self.num_heads:
642-
k, v = self._repeat_kv(k, v, bsz, seqlen)
643635
y = F.scaled_dot_product_attention(
644636
q, k, v, attn_mask=None, is_causal=True,
637+
enable_gqa=(self.num_kv_heads != self.num_heads),
645638
).transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
646639
return self.proj(y)
647640

0 commit comments

Comments (0)