Skip to content

Commit 31e777e

Browse files
albertorkive and claude
committed
WIP: sparse attention + recursive weight sharing skeleton
Adds three research levers to train_gpt.py (1177 lines, under 1500 cap):

- Sliding window attention (WINDOW_SIZE env var, default 0 = dense)
  Causal attention restricted to a local window; reduces compute and may allow
  wider/deeper models within the 16MB artifact budget.
- Recursive weight tying (NUM_PHYSICAL_LAYERS env var)
  N physical blocks reused across NUM_LAYERS logical layers. E.g.
  NUM_PHYSICAL_LAYERS=3 with NUM_LAYERS=9 gives 3x effective depth at the same
  parameter count.
- Dev/mini-run support (DEV_MODE, SKIP_QUANT env vars)
  DEV_MODE=1 allows MPS/CPU fallback for local architecture testing.
  SKIP_QUANT=1 skips int8+zlib serialization for fast iteration.

Also adds mini_run.sh convenience script for quick local runs.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent fe16a5a commit 31e777e

2 files changed

Lines changed: 123 additions & 34 deletions

File tree

mini_run.sh

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#!/usr/bin/env bash
2+
# mini_run.sh — fast local architecture testing (no CUDA needed on Mac with DEV_MODE=1)
3+
# Usage:
4+
# ./mini_run.sh # baseline mini-run
5+
# WINDOW_SIZE=128 ./mini_run.sh # sliding window attention
6+
# NUM_PHYSICAL_LAYERS=3 ./mini_run.sh # weight tying (3 blocks x 3 recurrences)
7+
# WINDOW_SIZE=128 NUM_PHYSICAL_LAYERS=3 ./mini_run.sh # both
8+
9+
set -euo pipefail
10+
11+
ITERATIONS="${ITERATIONS:-200}"
12+
VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-50}"
13+
SKIP_QUANT="${SKIP_QUANT:-1}"
14+
WINDOW_SIZE="${WINDOW_SIZE:-0}"
15+
NUM_PHYSICAL_LAYERS="${NUM_PHYSICAL_LAYERS:-0}"
16+
DEV_MODE="${DEV_MODE:-0}"
17+
18+
# Detect if CUDA is available; if not, enable dev mode automatically.
19+
python -c "import torch; exit(0 if torch.cuda.is_available() else 1)" 2>/dev/null \
20+
|| DEV_MODE=1
21+
22+
echo "=== mini_run ==="
23+
echo " ITERATIONS=$ITERATIONS VAL_LOSS_EVERY=$VAL_LOSS_EVERY"
24+
echo " WINDOW_SIZE=$WINDOW_SIZE NUM_PHYSICAL_LAYERS=$NUM_PHYSICAL_LAYERS"
25+
echo " SKIP_QUANT=$SKIP_QUANT DEV_MODE=$DEV_MODE"
26+
echo "================="
27+
28+
ITERATIONS="$ITERATIONS" \
29+
VAL_LOSS_EVERY="$VAL_LOSS_EVERY" \
30+
SKIP_QUANT="$SKIP_QUANT" \
31+
WINDOW_SIZE="$WINDOW_SIZE" \
32+
NUM_PHYSICAL_LAYERS="$NUM_PHYSICAL_LAYERS" \
33+
DEV_MODE="$DEV_MODE" \
34+
python train_gpt.py 2>&1 | tee mini_run.log
35+
36+
echo ""
37+
echo "=== Results ==="
38+
grep "val_bpb" mini_run.log | tail -5

train_gpt.py

Lines changed: 85 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,17 @@ class Hyperparameters:
7070
rope_base = float(os.environ.get("ROPE_BASE", 10000.0))
7171
logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0))
7272

73+
# Research extensions: sparse attention and recurrent weight tying.
74+
# WINDOW_SIZE=0 → dense attention (default). WINDOW_SIZE=128 → sliding window of 128.
75+
window_size = int(os.environ.get("WINDOW_SIZE", 0))
76+
# NUM_PHYSICAL_LAYERS controls weight tying: if < num_layers, blocks are cycled.
77+
# e.g. num_layers=9, num_physical_layers=3 → 3 unique blocks, each reused 3 times.
78+
num_physical_layers = int(os.environ.get("NUM_PHYSICAL_LAYERS", 0)) # 0 = same as num_layers
79+
80+
# Dev/mini-run flags.
81+
skip_quant = bool(int(os.environ.get("SKIP_QUANT", "0"))) # Skip int8+zlib for fast local runs.
82+
dev_mode = bool(int(os.environ.get("DEV_MODE", "0"))) # Allow MPS/CPU (no CUDA required).
83+
7384
# Optimizer hyperparameters.
7485
embed_lr = float(os.environ.get("EMBED_LR", 0.6))
7586
head_lr = float(os.environ.get("HEAD_LR", 0.008))
@@ -255,7 +266,7 @@ def eval_val(
255266
local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
256267
x = local[:-1].reshape(-1, args.train_seq_len)
257268
y = local[1:].reshape(-1, args.train_seq_len)
258-
with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
269+
with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=True):
259270
batch_loss = model(x, y).detach()
260271
batch_token_count = float(y.numel())
261272
val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
@@ -560,6 +571,7 @@ def __init__(
560571
num_kv_heads: int,
561572
rope_base: float,
562573
qk_gain_init: float,
574+
window_size: int = 0,
563575
):
564576
super().__init__()
565577
if dim % num_heads != 0:
@@ -579,6 +591,7 @@ def __init__(
579591
self.proj._zero_init = True
580592
self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
581593
self.rotary = Rotary(self.head_dim, base=rope_base)
594+
self.window_size = window_size
582595

583596
def forward(self, x: Tensor) -> Tensor:
584597
bsz, seqlen, dim = x.shape
@@ -591,12 +604,22 @@ def forward(self, x: Tensor) -> Tensor:
591604
q = apply_rotary_emb(q, cos, sin)
592605
k = apply_rotary_emb(k, cos, sin)
593606
q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
607+
# Build sliding-window causal mask when window_size > 0.
608+
# Dense causal attention (window_size=0) uses the faster is_causal=True path.
609+
if self.window_size > 0:
610+
idx = torch.arange(seqlen, device=x.device)
611+
row, col = idx.unsqueeze(1), idx.unsqueeze(0)
612+
blocked = (col > row) | (row - col >= self.window_size)
613+
attn_mask = torch.zeros(seqlen, seqlen, device=x.device, dtype=q.dtype)
614+
attn_mask = attn_mask.masked_fill(blocked, float("-inf"))
615+
else:
616+
attn_mask = None
594617
y = F.scaled_dot_product_attention(
595618
q,
596619
k,
597620
v,
598-
attn_mask=None,
599-
is_causal=True,
621+
attn_mask=attn_mask,
622+
is_causal=(attn_mask is None),
600623
enable_gqa=(self.num_kv_heads != self.num_heads),
601624
)
602625
y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
@@ -626,11 +649,12 @@ def __init__(
626649
mlp_mult: int,
627650
rope_base: float,
628651
qk_gain_init: float,
652+
window_size: int = 0,
629653
):
630654
super().__init__()
631655
self.attn_norm = RMSNorm()
632656
self.mlp_norm = RMSNorm()
633-
self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
657+
self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, window_size)
634658
self.mlp = MLP(dim, mlp_mult)
635659
self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
636660
self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
@@ -659,10 +683,16 @@ def __init__(
659683
logit_softcap: float,
660684
rope_base: float,
661685
qk_gain_init: float,
686+
num_physical_layers: int = 0,
687+
window_size: int = 0,
662688
):
663689
super().__init__()
664690
if logit_softcap <= 0.0:
665691
raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
692+
# num_physical_layers < num_layers enables weight tying: blocks are cycled.
693+
if num_physical_layers <= 0:
694+
num_physical_layers = num_layers
695+
self.num_physical_layers = num_physical_layers
666696
self.tie_embeddings = tie_embeddings
667697
self.tied_embed_init_std = tied_embed_init_std
668698
self.logit_softcap = logit_softcap
@@ -680,8 +710,9 @@ def __init__(
680710
mlp_mult,
681711
rope_base,
682712
qk_gain_init,
713+
window_size,
683714
)
684-
for i in range(num_layers)
715+
for i in range(num_physical_layers)
685716
]
686717
)
687718
self.final_norm = RMSNorm()
@@ -704,13 +735,14 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
704735
skips: list[Tensor] = []
705736

706737
# First half stores skips; second half reuses them in reverse order.
738+
# When num_physical_layers < num_layers, blocks are cycled (weight tying).
707739
for i in range(self.num_encoder_layers):
708-
x = self.blocks[i](x, x0)
740+
x = self.blocks[i % self.num_physical_layers](x, x0)
709741
skips.append(x)
710742
for i in range(self.num_decoder_layers):
711743
if skips:
712744
x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
713-
x = self.blocks[self.num_encoder_layers + i](x, x0)
745+
x = self.blocks[(self.num_encoder_layers + i) % self.num_physical_layers](x, x0)
714746

715747
x = self.final_norm(x).reshape(-1, x.size(-1))
716748
targets = target_ids.reshape(-1)
@@ -750,23 +782,30 @@ def main() -> None:
750782
grad_accum_steps = 8 // world_size
751783
grad_scale = 1.0 / grad_accum_steps
752784
if not torch.cuda.is_available():
753-
raise RuntimeError("CUDA is required")
754-
device = torch.device("cuda", local_rank)
755-
torch.cuda.set_device(device)
785+
if not args.dev_mode:
786+
raise RuntimeError("CUDA is required (set DEV_MODE=1 for local MPS/CPU testing)")
787+
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
788+
else:
789+
device = torch.device("cuda", local_rank)
790+
torch.cuda.set_device(device)
756791
if distributed:
757792
dist.init_process_group(backend="nccl", device_id=device)
758793
dist.barrier()
759794
master_process = rank == 0
760795

761-
# Fast math knobs
762-
torch.backends.cuda.matmul.allow_tf32 = True
763-
torch.backends.cudnn.allow_tf32 = True
764-
from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp
796+
# Fast math knobs (CUDA only)
797+
if device.type == "cuda":
798+
torch.backends.cuda.matmul.allow_tf32 = True
799+
torch.backends.cudnn.allow_tf32 = True
800+
from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp
801+
enable_cudnn_sdp(False)
802+
enable_flash_sdp(True)
803+
enable_mem_efficient_sdp(False)
804+
enable_math_sdp(False)
765805

766-
enable_cudnn_sdp(False)
767-
enable_flash_sdp(True)
768-
enable_mem_efficient_sdp(False)
769-
enable_math_sdp(False)
806+
def sync() -> None:
    """Device-agnostic timing barrier.

    Only CUDA launches kernels asynchronously and needs an explicit
    synchronize before reading the wall clock; on MPS/CPU dev runs this
    is a no-op. The original body called sync() itself — infinite
    recursion on the first timing barrier — instead of the CUDA call it
    replaces at the former torch.cuda.synchronize() sites.
    """
    if device.type == "cuda":
        torch.cuda.synchronize()
770809

771810
logfile = None
772811
if master_process:
@@ -800,7 +839,8 @@ def log0(msg: str, console: bool = True) -> None:
800839
random.seed(args.seed)
801840
np.random.seed(args.seed)
802841
torch.manual_seed(args.seed)
803-
torch.cuda.manual_seed_all(args.seed)
842+
if device.type == "cuda":
843+
torch.cuda.manual_seed_all(args.seed)
804844

805845
if not args.tokenizer_path.endswith(".model"):
806846
raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}")
@@ -835,12 +875,15 @@ def log0(msg: str, console: bool = True) -> None:
835875
logit_softcap=args.logit_softcap,
836876
rope_base=args.rope_base,
837877
qk_gain_init=args.qk_gain_init,
878+
num_physical_layers=args.num_physical_layers,
879+
window_size=args.window_size,
838880
).to(device).bfloat16()
839881
for module in base_model.modules():
840882
if isinstance(module, CastedLinear):
841883
module.float()
842884
restore_low_dim_params_to_fp32(base_model)
843-
compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True)
885+
# Skip torch.compile in dev mode (MPS/CPU) to avoid backend issues.
886+
compiled_model = base_model if args.dev_mode else torch.compile(base_model, dynamic=False, fullgraph=True)
844887
model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model
845888

846889
# Optimizer split:
@@ -866,7 +909,7 @@ def log0(msg: str, console: bool = True) -> None:
866909
[{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}],
867910
betas=(args.beta1, args.beta2),
868911
eps=args.adam_eps,
869-
fused=True,
912+
fused=(device.type == "cuda"),
870913
)
871914
optimizer_muon = Muon(
872915
matrix_params,
@@ -880,15 +923,15 @@ def log0(msg: str, console: bool = True) -> None:
880923
[{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}],
881924
betas=(args.beta1, args.beta2),
882925
eps=args.adam_eps,
883-
fused=True,
926+
fused=(device.type == "cuda"),
884927
)
885928
optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar]
886929
if base_model.lm_head is not None:
887930
optimizer_head = torch.optim.Adam(
888931
[{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}],
889932
betas=(args.beta1, args.beta2),
890933
eps=args.adam_eps,
891-
fused=True,
934+
fused=(device.type == "cuda"),
892935
)
893936
optimizers.insert(1, optimizer_head)
894937

@@ -944,7 +987,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
944987
if distributed:
945988
model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
946989
x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
947-
with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
990+
with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=True):
948991
warmup_loss = model(x, y)
949992
(warmup_loss * grad_scale).backward()
950993
for opt in optimizers:
@@ -966,7 +1009,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
9661009

9671010
training_time_ms = 0.0
9681011
stop_after_step: int | None = None
969-
torch.cuda.synchronize()
1012+
sync()
9701013
t0 = time.perf_counter()
9711014

9721015
step = 0
@@ -975,7 +1018,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
9751018

9761019
should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)
9771020
if should_validate:
978-
torch.cuda.synchronize()
1021+
sync()
9791022
training_time_ms += 1000.0 * (time.perf_counter() - t0)
9801023
val_loss, val_bpb = eval_val(
9811024
args,
@@ -993,7 +1036,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
9931036
f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} "
9941037
f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms"
9951038
)
996-
torch.cuda.synchronize()
1039+
sync()
9971040
t0 = time.perf_counter()
9981041

9991042
if last_step:
@@ -1012,7 +1055,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
10121055
if distributed:
10131056
model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
10141057
x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
1015-
with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
1058+
with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=True):
10161059
loss = model(x, y)
10171060
train_loss += loss.detach()
10181061
(loss * grad_scale).backward()
@@ -1054,16 +1097,24 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
10541097
if stop_after_step is None and reached_cap:
10551098
stop_after_step = step
10561099

1057-
log0(
1058-
f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
1059-
f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB"
1060-
)
1100+
if device.type == "cuda":
1101+
log0(
1102+
f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
1103+
f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB"
1104+
)
10611105

10621106
# -----------------------------
10631107
# SERIALIZATION + ROUNDTRIP VALIDATION
10641108
# -----------------------------
10651109
# Save the raw state (useful for debugging/loading in PyTorch directly), then always produce
10661110
# the compressed int8+zlib artifact and validate the round-tripped weights.
1111+
# SKIP_QUANT=1 skips this section entirely (fast local dev / mini-runs).
1112+
if args.skip_quant:
1113+
if master_process:
1114+
log0("skip_quant=True: skipping serialization and roundtrip validation")
1115+
if distributed:
1116+
dist.destroy_process_group()
1117+
return
10671118

10681119
if master_process:
10691120
torch.save(base_model.state_dict(), "final_model.pt")
@@ -1097,7 +1148,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
10971148
quant_blob_disk = f.read()
10981149
quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu")
10991150
base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True)
1100-
torch.cuda.synchronize()
1151+
sync()
11011152
t_qeval = time.perf_counter()
11021153
q_val_loss, q_val_bpb = eval_val(
11031154
args,
@@ -1111,7 +1162,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
11111162
has_leading_space_lut,
11121163
is_boundary_token_lut,
11131164
)
1114-
torch.cuda.synchronize()
1165+
sync()
11151166
log0(
11161167
f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} "
11171168
f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms"

0 commit comments

Comments
 (0)