@@ -70,6 +70,17 @@ class Hyperparameters:
7070 rope_base = float (os .environ .get ("ROPE_BASE" , 10000.0 ))
7171 logit_softcap = float (os .environ .get ("LOGIT_SOFTCAP" , 30.0 ))
7272
73+ # Research extensions: sparse attention and recurrent weight tying.
74+ # WINDOW_SIZE=0 → dense attention (default). WINDOW_SIZE=128 → sliding window of 128.
75+ window_size = int (os .environ .get ("WINDOW_SIZE" , 0 ))
76+ # NUM_PHYSICAL_LAYERS controls weight tying: if < num_layers, blocks are cycled.
77+ # e.g. num_layers=9, num_physical_layers=3 → 3 unique blocks, each reused 3 times.
78+ num_physical_layers = int (os .environ .get ("NUM_PHYSICAL_LAYERS" , 0 )) # 0 = same as num_layers
79+
80+ # Dev/mini-run flags.
81+ skip_quant = bool (int (os .environ .get ("SKIP_QUANT" , "0" ))) # Skip int8+zlib for fast local runs.
82+ dev_mode = bool (int (os .environ .get ("DEV_MODE" , "0" ))) # Allow MPS/CPU (no CUDA required).
83+
7384 # Optimizer hyperparameters.
7485 embed_lr = float (os .environ .get ("EMBED_LR" , 0.6 ))
7586 head_lr = float (os .environ .get ("HEAD_LR" , 0.008 ))
@@ -255,7 +266,7 @@ def eval_val(
255266 local = val_tokens [raw_start :raw_end ].to (device = device , dtype = torch .int64 , non_blocking = True )
256267 x = local [:- 1 ].reshape (- 1 , args .train_seq_len )
257268 y = local [1 :].reshape (- 1 , args .train_seq_len )
258- with torch .autocast (device_type = "cuda" , dtype = torch .bfloat16 , enabled = True ):
269+ with torch .autocast (device_type = device . type , dtype = torch .bfloat16 , enabled = True ):
259270 batch_loss = model (x , y ).detach ()
260271 batch_token_count = float (y .numel ())
261272 val_loss_sum += batch_loss .to (torch .float64 ) * batch_token_count
@@ -560,6 +571,7 @@ def __init__(
560571 num_kv_heads : int ,
561572 rope_base : float ,
562573 qk_gain_init : float ,
574+ window_size : int = 0 ,
563575 ):
564576 super ().__init__ ()
565577 if dim % num_heads != 0 :
@@ -579,6 +591,7 @@ def __init__(
579591 self .proj ._zero_init = True
580592 self .q_gain = nn .Parameter (torch .full ((num_heads ,), qk_gain_init , dtype = torch .float32 ))
581593 self .rotary = Rotary (self .head_dim , base = rope_base )
594+ self .window_size = window_size
582595
583596 def forward (self , x : Tensor ) -> Tensor :
584597 bsz , seqlen , dim = x .shape
@@ -591,12 +604,22 @@ def forward(self, x: Tensor) -> Tensor:
591604 q = apply_rotary_emb (q , cos , sin )
592605 k = apply_rotary_emb (k , cos , sin )
593606 q = q * self .q_gain .to (dtype = q .dtype )[None , :, None , None ]
607+ # Build sliding-window causal mask when window_size > 0.
608+ # Dense causal attention (window_size=0) uses the faster is_causal=True path.
609+ if self .window_size > 0 :
610+ idx = torch .arange (seqlen , device = x .device )
611+ row , col = idx .unsqueeze (1 ), idx .unsqueeze (0 )
612+ blocked = (col > row ) | (row - col >= self .window_size )
613+ attn_mask = torch .zeros (seqlen , seqlen , device = x .device , dtype = q .dtype )
614+ attn_mask = attn_mask .masked_fill (blocked , float ("-inf" ))
615+ else :
616+ attn_mask = None
594617 y = F .scaled_dot_product_attention (
595618 q ,
596619 k ,
597620 v ,
598- attn_mask = None ,
599- is_causal = True ,
621+ attn_mask = attn_mask ,
622+ is_causal = ( attn_mask is None ) ,
600623 enable_gqa = (self .num_kv_heads != self .num_heads ),
601624 )
602625 y = y .transpose (1 , 2 ).contiguous ().reshape (bsz , seqlen , dim )
@@ -626,11 +649,12 @@ def __init__(
626649 mlp_mult : int ,
627650 rope_base : float ,
628651 qk_gain_init : float ,
652+ window_size : int = 0 ,
629653 ):
630654 super ().__init__ ()
631655 self .attn_norm = RMSNorm ()
632656 self .mlp_norm = RMSNorm ()
633- self .attn = CausalSelfAttention (dim , num_heads , num_kv_heads , rope_base , qk_gain_init )
657+ self .attn = CausalSelfAttention (dim , num_heads , num_kv_heads , rope_base , qk_gain_init , window_size )
634658 self .mlp = MLP (dim , mlp_mult )
635659 self .attn_scale = nn .Parameter (torch .ones (dim , dtype = torch .float32 ))
636660 self .mlp_scale = nn .Parameter (torch .ones (dim , dtype = torch .float32 ))
@@ -659,10 +683,16 @@ def __init__(
659683 logit_softcap : float ,
660684 rope_base : float ,
661685 qk_gain_init : float ,
686+ num_physical_layers : int = 0 ,
687+ window_size : int = 0 ,
662688 ):
663689 super ().__init__ ()
664690 if logit_softcap <= 0.0 :
665691 raise ValueError (f"logit_softcap must be positive, got { logit_softcap } " )
692+ # num_physical_layers < num_layers enables weight tying: blocks are cycled.
693+ if num_physical_layers <= 0 :
694+ num_physical_layers = num_layers
695+ self .num_physical_layers = num_physical_layers
666696 self .tie_embeddings = tie_embeddings
667697 self .tied_embed_init_std = tied_embed_init_std
668698 self .logit_softcap = logit_softcap
@@ -680,8 +710,9 @@ def __init__(
680710 mlp_mult ,
681711 rope_base ,
682712 qk_gain_init ,
713+ window_size ,
683714 )
684- for i in range (num_layers )
715+ for i in range (num_physical_layers )
685716 ]
686717 )
687718 self .final_norm = RMSNorm ()
@@ -704,13 +735,14 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
704735 skips : list [Tensor ] = []
705736
706737 # First half stores skips; second half reuses them in reverse order.
738+ # When num_physical_layers < num_layers, blocks are cycled (weight tying).
707739 for i in range (self .num_encoder_layers ):
708- x = self .blocks [i ](x , x0 )
740+ x = self .blocks [i % self . num_physical_layers ](x , x0 )
709741 skips .append (x )
710742 for i in range (self .num_decoder_layers ):
711743 if skips :
712744 x = x + self .skip_weights [i ].to (dtype = x .dtype )[None , None , :] * skips .pop ()
713- x = self .blocks [self .num_encoder_layers + i ](x , x0 )
745+ x = self .blocks [( self .num_encoder_layers + i ) % self . num_physical_layers ](x , x0 )
714746
715747 x = self .final_norm (x ).reshape (- 1 , x .size (- 1 ))
716748 targets = target_ids .reshape (- 1 )
@@ -750,23 +782,30 @@ def main() -> None:
750782 grad_accum_steps = 8 // world_size
751783 grad_scale = 1.0 / grad_accum_steps
752784 if not torch .cuda .is_available ():
753- raise RuntimeError ("CUDA is required" )
754- device = torch .device ("cuda" , local_rank )
755- torch .cuda .set_device (device )
785+ if not args .dev_mode :
786+ raise RuntimeError ("CUDA is required (set DEV_MODE=1 for local MPS/CPU testing)" )
787+ device = torch .device ("mps" if torch .backends .mps .is_available () else "cpu" )
788+ else :
789+ device = torch .device ("cuda" , local_rank )
790+ torch .cuda .set_device (device )
756791 if distributed :
757792 dist .init_process_group (backend = "nccl" , device_id = device )
758793 dist .barrier ()
759794 master_process = rank == 0
760795
761- # Fast math knobs
762- torch .backends .cuda .matmul .allow_tf32 = True
763- torch .backends .cudnn .allow_tf32 = True
764- from torch .backends .cuda import enable_cudnn_sdp , enable_flash_sdp , enable_math_sdp , enable_mem_efficient_sdp
796+ # Fast math knobs (CUDA only)
797+ if device .type == "cuda" :
798+ torch .backends .cuda .matmul .allow_tf32 = True
799+ torch .backends .cudnn .allow_tf32 = True
800+ from torch .backends .cuda import enable_cudnn_sdp , enable_flash_sdp , enable_math_sdp , enable_mem_efficient_sdp
801+ enable_cudnn_sdp (False )
802+ enable_flash_sdp (True )
803+ enable_mem_efficient_sdp (False )
804+ enable_math_sdp (False )
765805
766- enable_cudnn_sdp (False )
767- enable_flash_sdp (True )
768- enable_mem_efficient_sdp (False )
769- enable_math_sdp (False )
806+ def sync () -> None :
807+ if device .type == "cuda" :
808+ torch .cuda .synchronize ()
770809
771810 logfile = None
772811 if master_process :
@@ -800,7 +839,8 @@ def log0(msg: str, console: bool = True) -> None:
800839 random .seed (args .seed )
801840 np .random .seed (args .seed )
802841 torch .manual_seed (args .seed )
803- torch .cuda .manual_seed_all (args .seed )
842+ if device .type == "cuda" :
843+ torch .cuda .manual_seed_all (args .seed )
804844
805845 if not args .tokenizer_path .endswith (".model" ):
806846 raise ValueError (f"Script only setup for SentencePiece .model file: { args .tokenizer_path } " )
@@ -835,12 +875,15 @@ def log0(msg: str, console: bool = True) -> None:
835875 logit_softcap = args .logit_softcap ,
836876 rope_base = args .rope_base ,
837877 qk_gain_init = args .qk_gain_init ,
878+ num_physical_layers = args .num_physical_layers ,
879+ window_size = args .window_size ,
838880 ).to (device ).bfloat16 ()
839881 for module in base_model .modules ():
840882 if isinstance (module , CastedLinear ):
841883 module .float ()
842884 restore_low_dim_params_to_fp32 (base_model )
843- compiled_model = torch .compile (base_model , dynamic = False , fullgraph = True )
885+ # Skip torch.compile in dev mode (MPS/CPU) to avoid backend issues.
886+ compiled_model = base_model if args .dev_mode else torch .compile (base_model , dynamic = False , fullgraph = True )
844887 model : nn .Module = DDP (compiled_model , device_ids = [local_rank ], broadcast_buffers = False ) if distributed else compiled_model
845888
846889 # Optimizer split:
@@ -866,7 +909,7 @@ def log0(msg: str, console: bool = True) -> None:
866909 [{"params" : [base_model .tok_emb .weight ], "lr" : token_lr , "base_lr" : token_lr }],
867910 betas = (args .beta1 , args .beta2 ),
868911 eps = args .adam_eps ,
869- fused = True ,
912+ fused = ( device . type == "cuda" ) ,
870913 )
871914 optimizer_muon = Muon (
872915 matrix_params ,
@@ -880,15 +923,15 @@ def log0(msg: str, console: bool = True) -> None:
880923 [{"params" : scalar_params , "lr" : args .scalar_lr , "base_lr" : args .scalar_lr }],
881924 betas = (args .beta1 , args .beta2 ),
882925 eps = args .adam_eps ,
883- fused = True ,
926+ fused = ( device . type == "cuda" ) ,
884927 )
885928 optimizers : list [torch .optim .Optimizer ] = [optimizer_tok , optimizer_muon , optimizer_scalar ]
886929 if base_model .lm_head is not None :
887930 optimizer_head = torch .optim .Adam (
888931 [{"params" : [base_model .lm_head .weight ], "lr" : args .head_lr , "base_lr" : args .head_lr }],
889932 betas = (args .beta1 , args .beta2 ),
890933 eps = args .adam_eps ,
891- fused = True ,
934+ fused = ( device . type == "cuda" ) ,
892935 )
893936 optimizers .insert (1 , optimizer_head )
894937
@@ -944,7 +987,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
944987 if distributed :
945988 model .require_backward_grad_sync = micro_step == grad_accum_steps - 1
946989 x , y = train_loader .next_batch (args .train_batch_tokens , args .train_seq_len , grad_accum_steps )
947- with torch .autocast (device_type = "cuda" , dtype = torch .bfloat16 , enabled = True ):
990+ with torch .autocast (device_type = device . type , dtype = torch .bfloat16 , enabled = True ):
948991 warmup_loss = model (x , y )
949992 (warmup_loss * grad_scale ).backward ()
950993 for opt in optimizers :
@@ -966,7 +1009,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
9661009
9671010 training_time_ms = 0.0
9681011 stop_after_step : int | None = None
969- torch . cuda . synchronize ()
1012+ sync ()
9701013 t0 = time .perf_counter ()
9711014
9721015 step = 0
@@ -975,7 +1018,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
9751018
9761019 should_validate = last_step or (args .val_loss_every > 0 and step % args .val_loss_every == 0 )
9771020 if should_validate :
978- torch . cuda . synchronize ()
1021+ sync ()
9791022 training_time_ms += 1000.0 * (time .perf_counter () - t0 )
9801023 val_loss , val_bpb = eval_val (
9811024 args ,
@@ -993,7 +1036,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
9931036 f"step:{ step } /{ args .iterations } val_loss:{ val_loss :.4f} val_bpb:{ val_bpb :.4f} "
9941037 f"train_time:{ training_time_ms :.0f} ms step_avg:{ training_time_ms / max (step , 1 ):.2f} ms"
9951038 )
996- torch . cuda . synchronize ()
1039+ sync ()
9971040 t0 = time .perf_counter ()
9981041
9991042 if last_step :
@@ -1012,7 +1055,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
10121055 if distributed :
10131056 model .require_backward_grad_sync = micro_step == grad_accum_steps - 1
10141057 x , y = train_loader .next_batch (args .train_batch_tokens , args .train_seq_len , grad_accum_steps )
1015- with torch .autocast (device_type = "cuda" , dtype = torch .bfloat16 , enabled = True ):
1058+ with torch .autocast (device_type = device . type , dtype = torch .bfloat16 , enabled = True ):
10161059 loss = model (x , y )
10171060 train_loss += loss .detach ()
10181061 (loss * grad_scale ).backward ()
@@ -1054,16 +1097,24 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
10541097 if stop_after_step is None and reached_cap :
10551098 stop_after_step = step
10561099
1057- log0 (
1058- f"peak memory allocated: { torch .cuda .max_memory_allocated () // 1024 // 1024 } MiB "
1059- f"reserved: { torch .cuda .max_memory_reserved () // 1024 // 1024 } MiB"
1060- )
1100+ if device .type == "cuda" :
1101+ log0 (
1102+ f"peak memory allocated: { torch .cuda .max_memory_allocated () // 1024 // 1024 } MiB "
1103+ f"reserved: { torch .cuda .max_memory_reserved () // 1024 // 1024 } MiB"
1104+ )
10611105
10621106 # -----------------------------
10631107 # SERIALIZATION + ROUNDTRIP VALIDATION
10641108 # -----------------------------
10651109 # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce
10661110 # the compressed int8+zlib artifact and validate the round-tripped weights.
1111+ # SKIP_QUANT=1 skips this section entirely (fast local dev / mini-runs).
1112+ if args .skip_quant :
1113+ if master_process :
1114+ log0 ("skip_quant=True: skipping serialization and roundtrip validation" )
1115+ if distributed :
1116+ dist .destroy_process_group ()
1117+ return
10671118
10681119 if master_process :
10691120 torch .save (base_model .state_dict (), "final_model.pt" )
@@ -1097,7 +1148,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
10971148 quant_blob_disk = f .read ()
10981149 quant_state = torch .load (io .BytesIO (zlib .decompress (quant_blob_disk )), map_location = "cpu" )
10991150 base_model .load_state_dict (dequantize_state_dict_int8 (quant_state ), strict = True )
1100- torch . cuda . synchronize ()
1151+ sync ()
11011152 t_qeval = time .perf_counter ()
11021153 q_val_loss , q_val_bpb = eval_val (
11031154 args ,
@@ -1111,7 +1162,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
11111162 has_leading_space_lut ,
11121163 is_boundary_token_lut ,
11131164 )
1114- torch . cuda . synchronize ()
1165+ sync ()
11151166 log0 (
11161167 f"final_int8_zlib_roundtrip val_loss:{ q_val_loss :.4f} val_bpb:{ q_val_bpb :.4f} "
11171168 f"eval_time:{ 1000.0 * (time .perf_counter () - t_qeval ):.0f} ms"
0 commit comments