
Commit 1e8fb29

phase-4: add partial rope support
Add a ROPE_DIMS env var (default 0 = full head_dim). When set (e.g. 16), only the first ROPE_DIMS dimensions of each head receive rotary embeddings; the remaining dimensions attend without positional bias. The top submission uses 16/64 (25%). Zero parameters added.
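
For context, a minimal standalone sketch of the partial-RoPE idea this commit implements: rotate only the first rope_dims channels of each head and pass the remaining channels through untouched. The helpers below (rotary_cos_sin, apply_rotary, partial_rope) and the half-rotation pairing are illustrative assumptions, not the Rotary/apply_rotary_emb code in train_gpt.py.

import torch

def rotary_cos_sin(seq_len: int, dims: int, base: float = 10000.0):
    # cos/sin tables for the rotated channels only; dims must be even.
    inv_freq = 1.0 / (base ** (torch.arange(0, dims, 2).float() / dims))
    angles = torch.outer(torch.arange(seq_len).float(), inv_freq)  # (seq_len, dims // 2)
    return angles.cos(), angles.sin()

def apply_rotary(x, cos, sin):
    # Half-rotation pairing (an assumption; the repo's apply_rotary_emb may pair channels differently).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)

def partial_rope(q: torch.Tensor, rope_dims: int) -> torch.Tensor:
    # q: (batch, heads, seq_len, head_dim). Only the first rope_dims channels are rotated
    # by position; the remaining head_dim - rope_dims channels carry no positional bias.
    seq_len = q.size(-2)
    cos, sin = rotary_cos_sin(seq_len, rope_dims)
    return torch.cat((apply_rotary(q[..., :rope_dims], cos, sin), q[..., rope_dims:]), dim=-1)

if __name__ == "__main__":
    q = torch.randn(2, 8, 128, 64)        # head_dim = 64
    out = partial_rope(q, rope_dims=16)   # 16/64 channels rotated, matching the top submission
    assert out.shape == q.shape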
1 parent 4949e3d commit 1e8fb29

1 file changed: train_gpt.py (14 additions, 7 deletions)
@@ -104,6 +104,7 @@ class Hyperparameters:
     ema_decay = float(os.environ.get("EMA_DECAY", 0.0))
     neural_temp = float(os.environ.get("NEURAL_TEMP", 0.85))
     ln_scale = bool(int(os.environ.get("LN_SCALE", "0")))
+    rope_dims = int(os.environ.get("ROPE_DIMS", 0))
 
 # -----------------------------
 # MUON OPTIMIZER
@@ -643,6 +644,7 @@ def __init__(
         num_kv_heads: int,
         rope_base: float,
         qk_gain_init: float,
+        rope_dims: int = 0,
     ):
         super().__init__()
         if dim % num_heads != 0:
@@ -654,14 +656,15 @@ def __init__(
         self.head_dim = dim // num_heads
         if self.head_dim % 2 != 0:
             raise ValueError("head_dim must be even for RoPE")
+        self.rope_dims = rope_dims if rope_dims > 0 else self.head_dim
         kv_dim = self.num_kv_heads * self.head_dim
         self.c_q = CastedLinear(dim, dim, bias=False)
         self.c_k = CastedLinear(dim, kv_dim, bias=False)
         self.c_v = CastedLinear(dim, kv_dim, bias=False)
         self.proj = CastedLinear(dim, dim, bias=False)
         self.proj._zero_init = True
         self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
-        self.rotary = Rotary(self.head_dim, base=rope_base)
+        self.rotary = Rotary(self.rope_dims, base=rope_base)
 
     def forward(self, x: Tensor) -> Tensor:
         bsz, seqlen, dim = x.shape
@@ -670,9 +673,10 @@ def forward(self, x: Tensor) -> Tensor:
         v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
         q = F.rms_norm(q, (q.size(-1),))
         k = F.rms_norm(k, (k.size(-1),))
+        rd = self.rope_dims
         cos, sin = self.rotary(seqlen, x.device, q.dtype)
-        q = apply_rotary_emb(q, cos, sin)
-        k = apply_rotary_emb(k, cos, sin)
+        q = torch.cat((apply_rotary_emb(q[..., :rd], cos, sin), q[..., rd:]), dim=-1)
+        k = torch.cat((apply_rotary_emb(k[..., :rd], cos, sin), k[..., rd:]), dim=-1)
         q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
         y = F.scaled_dot_product_attention(
             q,
@@ -735,11 +739,12 @@ def __init__(
         rope_base: float,
         qk_gain_init: float,
         ln_scale: float = 1.0,
+        rope_dims: int = 0,
     ):
         super().__init__()
         self.attn_norm = RMSNorm()
         self.mlp_norm = RMSNorm()
-        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
+        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_dims=rope_dims)
         self.mlp = MLP(dim, mlp_mult)
         self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
         self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
@@ -760,7 +765,8 @@ class GPT(nn.Module):
     def __init__(self, vocab_size: int, num_layers: int, model_dim: int, num_heads: int,
                  num_kv_heads: int, mlp_mult: int, tie_embeddings: bool, tied_embed_init_std: float,
                  logit_softcap: float, rope_base: float, qk_gain_init: float,
-                 bigram_vocab_size: int = 0, bigram_dim: int = 128, ln_scale: bool = False):
+                 bigram_vocab_size: int = 0, bigram_dim: int = 128,
+                 ln_scale: bool = False, rope_dims: int = 0):
         super().__init__()
         if logit_softcap <= 0.0:
             raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
@@ -775,7 +781,7 @@ def __init__(self, vocab_size: int, num_layers: int, model_dim: int, num_heads:
         self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
         self.blocks = nn.ModuleList([
             Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init,
-                  ln_scale=1.0 / (i + 1) ** 0.5 if ln_scale else 1.0)
+                  ln_scale=1.0 / (i + 1) ** 0.5 if ln_scale else 1.0, rope_dims=rope_dims)
             for i in range(num_layers)
         ])
         self.smear_gate = SmearGate(model_dim)
@@ -949,7 +955,8 @@ def log0(msg: str, console: bool = True) -> None:
     num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult,
     tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std,
     logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init,
-    bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, ln_scale=args.ln_scale,
+    bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim,
+    ln_scale=args.ln_scale, rope_dims=args.rope_dims,
 ).to(device).bfloat16()
 for module in base_model.modules():
     if isinstance(module, CastedLinear):
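
With this diff, partial RoPE stays opt-in: leaving ROPE_DIMS unset (or 0) keeps full-head rotary, so existing runs are unaffected, while enabling it (e.g. ROPE_DIMS=16 python train_gpt.py, assuming the script is launched directly; the repo's actual launcher is not shown here) only shrinks the rotary table and adds no parameters. A tiny sketch of the env-var plumbing the Hyperparameters change introduces (the head_dim value is a placeholder):

import os

# Mirrors the new Hyperparameters field: 0 (the default) means "rotate the full
# head_dim"; any positive value rotates only that many leading channels per head.
rope_dims = int(os.environ.get("ROPE_DIMS", 0))
head_dim = 64                                    # placeholder; really dim // num_heads
effective = rope_dims if rope_dims > 0 else head_dim
print(f"rotating {effective}/{head_dim} channels per head")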
