[tx] Stack weights — Qwen3 #1079
Conversation
Code Review
This pull request introduces a significant and well-designed refactoring to use stacked weights for transformer layers, which is a great optimization for performance and memory. The new StackedDecoderLayers and ArrayRef abstractions are cleverly implemented to achieve this while maintaining compatibility with existing checkpointing and LoRA operations. The refactoring of test utilities into a shared file is also a good improvement.
I've found a critical issue in the unstack_paths implementation that would prevent checkpointing from working correctly with the new stacked layers. I've also included a suggestion to make the unstack_state function more generic and maintainable. Addressing these points will make this excellent contribution even more robust.
```python
if key not in tensors:
    continue
```
🔴 Expert weights silently skipped due to premature `key not in tensors` check in `load_safetensors`
Expert (MoE) weights are never loaded because the `key not in tensors` guard at line 167 runs before the expert-specific key construction at lines 169-170.
Root Cause
For expert parameters, the path tuple contains "experts" as a single element (e.g., `("model", "layers", "0", "mlp", "experts", "gate_proj", "weight")`). The `get_param_key` function joins this into `"model.layers.0.mlp.experts.gate_proj.weight"`. However, the actual keys in the safetensors checkpoint file are per-expert: `"model.layers.0.mlp.experts.0.gate_proj.weight"`, `"model.layers.0.mlp.experts.1.gate_proj.weight"`, etc.
The new code at line 167 checks `if key not in tensors: continue`, which evaluates to `True` for expert paths (the generic key doesn't exist in the tensor dict), causing expert weights to be silently skipped. The expert-specific key construction via `get_expert_key` at line 170 is never reached.
The old code did not have this guard (`git show HEAD~5:skyrl-tx/tx/utils/models.py`), so expert weights were loaded correctly before this PR.
Impact: All MoE model weights (e.g., Qwen3 MoE, DeepseekV3) fail to load expert parameters, resulting in models running with randomly initialized expert weights. This would produce garbage outputs for any MoE model.
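To make the mismatch concrete, here is a small self-contained sketch; the dict contents and the key-joining are illustrative stand-ins for `get_param_key` and the checkpoint file, not the actual implementation:

```python
# Sketch of the key mismatch described above (illustrative stand-ins only)
path = ("model", "layers", "0", "mlp", "experts", "gate_proj", "weight")
generic_key = ".".join(path)  # what joining the path tuple produces

# The checkpoint stores per-expert keys, never the generic joined key
tensors = {
    f"model.layers.0.mlp.experts.{i}.gate_proj.weight": object()
    for i in range(4)
}

# The premature guard fires for every expert path, so the expert-specific
# lookup is never reached and the weights are silently skipped
skipped = generic_key not in tensors
print(generic_key, "->", "skipped" if skipped else "loaded")
```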
Suggested change:
```diff
-if key not in tensors:
-    continue
+if key not in tensors and "experts" not in path:
+    continue
```
Add stacked decoder layer infrastructure that stores transformer layer weights in `(num_layers, ...)` format for efficient `jax.lax.scan`-based forward passes.

- tx/layers/stacked.py: StackedDecoderLayers, ArrayRef, unstack_state()
- tx/layers/util.py: Fix shard_map_ep to truncate PartitionSpec for stacked layers
- tx/utils/generator.py: Add KVCache.num_layers, .batch_size, .seq_len properties
- tx/models/types.py: Allow kv_cache to be None (during training)
- tx/models/configs.py: Update gradient_checkpointing docstring

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Adapt weight loading/saving and LoRA operations to work with the stacked layer format using unstack_state() and ArrayRef write-through.

- tx/utils/models.py: Add is_stacked_path, get_adapter_idx, get_lora_adapter_slice. Rewrite load/save_safetensors to use unstack_state(). Simplify LoRA checkpoint I/O.
- tx/layers/lora.py: init/clear_lora_adapter use get_adapter_idx for stacked indexing
- tx/utils/storage.py: Remove rank param from pack_and_upload, use jax.process_index()
- tx/tinker/backends/jax.py: AccumulatedGradients uses get_adapter_idx with map_with_path

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Replace nnx.List with StackedDecoderLayers for the Qwen3 model. The forward pass now uses self.layers() (scan-based) instead of a Python for-loop.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- tests/models/lora_test_utils.py: Add shared LoRA test utilities
- test_qwen3.py: Fix layer iteration for stacked layers
- test_qwen3_lora_training.py: Fix import ordering
- test_models_common.py: Add create_model helper, type hints
- test_models.py: Add test_is_stacked_path, test_extract_insert_adapter_state_roundtrip

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
543cdc1 to f228229
Co-authored-by: devin-ai-integration[bot] <158243242+devin-ai-integration[bot]@users.noreply.github.com>
🔴 Training forward pass doesn't set `is_training=True`, causing unnecessary KV cache accumulation in scan
The `_model_forward` function in the JAX backend calls `model(...)` without `is_training=True` during training. This means the scan body in `StackedDecoderLayers` accumulates KV cache arrays for all layers, wasting significant memory.
Detailed Explanation
At tx/tinker/backends/jax.py:274, the training forward pass calls:

```python
output = model(
    input_ids,
    attention_mask=attention_mask,
    adapter_indices=adapter_indices,
)
```

Since `is_training` defaults to `False` in `Qwen3ForCausalLM.__call__` (tx/models/qwen3.py:419), the scan body in `StackedDecoderLayers.__call__` (tx/layers/stacked.py:281-282) does NOT set `k = v = None`, causing full KV cache tensors to be accumulated through the scan as secondary outputs.
The `is_training=True` flag was specifically added in this PR to skip KV accumulation during training (tx/layers/stacked.py:280-282), but it is never actually used in the training path.
Impact: During training, the model unnecessarily stores KV cache arrays for all layers, consuming significant additional memory. For large models, this can lead to OOM errors or require smaller batch sizes than necessary.
(Refers to lines 274-278)
🔴 Training forward pass does not set `is_training=True`, causing unnecessary KV cache allocation in scan
The `_model_forward` function in the JAX backend calls `model(input_ids, attention_mask=..., adapter_indices=...)` without passing `is_training=True`. Since this PR converts the Qwen3 model from a Python for-loop to `jax.lax.scan`, the scan body now accumulates KV cache arrays for all layers as stacked scan outputs when `is_training` is `False`.
Root cause and memory impact
In tx/layers/stacked.py:281-283, the scan body only skips KV accumulation when `is_training` is `True`:

```python
if is_training:
    k = v = None
```

Without `is_training=True`, the scan allocates and returns stacked key and value arrays for all layers (tx/layers/stacked.py:289-297), then creates a full `KVCache` object that is never used by the training pipeline. For large models, this means allocating `2 * num_layers * batch * seq_len * num_kv_heads * head_dim` floats of GPU memory unnecessarily, which could cause OOM during training.
The is_training parameter was specifically added by this PR to Qwen3ForCausalLM.__call__ (tx/models/qwen3.py:419) and Qwen3Model.__call__ (tx/models/qwen3.py:349) for this purpose, but the backend at tx/tinker/backends/jax.py:274-278 was not updated to pass it.
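To put a rough number on the formula above, here is a back-of-the-envelope calculation; the model and batch sizes are illustrative assumptions, not measurements from this PR:

```python
# Hypothetical sizes, chosen only to illustrate the formula above
num_layers, batch, seq_len, num_kv_heads, head_dim = 32, 8, 4096, 8, 128
bytes_per_elem = 2  # bf16

# keys + values (factor of 2), for every layer, accumulated by the scan
kv_bytes = 2 * num_layers * batch * seq_len * num_kv_heads * head_dim * bytes_per_elem
print(f"unused KV cache: {kv_bytes / 2**30:.1f} GiB")  # → unused KV cache: 4.0 GiB
```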
(Refers to lines 274-278)
```python
if key not in tensors:
    continue
```
🔴 Early `key not in tensors` check silently skips all expert weight loading for MoE models
The newly added `if key not in tensors: continue` guard causes all expert weights to be silently skipped during checkpoint loading for MoE models (e.g., Qwen3-MoE).
Root Cause
For expert params, `get_param_key` at tx/utils/models.py:121-127 produces a key like `model.layers.0.mlp.experts.gate_proj.weight` (without a per-expert index). But HuggingFace checkpoint files store expert weights with per-expert keys like `model.layers.0.mlp.experts.0.gate_proj.weight`.
The old code did not have this early-exit check, so it would always reach the `if "experts" in path:` block (line 169), which correctly uses `get_expert_key` to look up each expert individually. The new `key not in tensors` check at line 167 short-circuits before that expert-handling logic ever runs.
Impact: All expert base weights in MoE models are silently skipped during loading, resulting in zeros/random weights for all expert layers. The model would produce incorrect outputs. This is a regression from the old code.
Prompt for agents
In `load_safetensors` (tx/utils/models.py), the check at lines 167-168 (`if key not in tensors: continue`) must be moved AFTER the expert-handling block. For expert paths, the key from `get_param_key` does not exist directly in the tensors dict; instead, per-expert keys are looked up via `get_expert_key` inside the `if "experts" in path:` block on line 169. The fix is to reorder the logic: first check for experts (which constructs the tensor from per-expert keys), then for non-expert paths check if the key exists in tensors. Something like:

```python
if "experts" in path:
    expert_key = get_expert_key(path, 0)
    if expert_key not in tensors:
        continue
    tensor = np.stack(
        [tensors[get_expert_key(path, i)].T for i in range(config.get_num_experts())],
        axis=0,
    )
elif key not in tensors:
    continue
else:
    tensor = tensors[key] if "embed_tokens" in key else tensors[key].T
```
/gemini review
Code Review
This pull request introduces a significant and well-designed infrastructure for stacking transformer layer weights using StackedDecoderLayers. This change, which replaces nnx.List with a jax.lax.scan-based approach for the forward pass, is a great step towards improving performance, especially for training and prefill. The use of ArrayRef to maintain compatibility with existing checkpointing and LoRA operations is particularly clever.
The refactoring of the Qwen3 model and various utility functions to support this new stacked format is thorough. However, I've found a critical and repeated bug in tx/utils/models.py related to how NNX state paths are processed. The checks for keys within these paths are incorrect, which will cause LoRA weight loading and saving to fail. I've left detailed comments on this issue. Once that is addressed, this PR will be in excellent shape.
8c5507c to 743c82b
743c82b to 8c5507c
## Summary
- Convert Llama3 model from `nnx.List` + Python for-loop to `StackedDecoderLayers` + scan-based forward pass
- Same pattern as the Qwen3 conversion in #1079

Stacks on #1079.

## Test plan
- [x] `pytest tests/models/test_llama3.py tests/models/test_llama3_lora_training.py -x`

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Philipp Moritz <pcmoritz@gmail.com>
See #1079
Stacks on #1078.
Summary
- `StackedDecoderLayers` infrastructure that stores transformer layer weights in `(num_layers, ...)` format for efficient `jax.lax.scan`-based forward passes (tx/layers/stacked.py)
- Weight loading/saving and LoRA operations adapted via `unstack_state()` + `ArrayRef` write-through
- Qwen3 converted from `nnx.List` + Python for-loop to `StackedDecoderLayers` + scan

Key concepts:
- Forward pass runs via `jax.lax.scan` (prefill/training) or a Python loop (decode, for KV cache donation)
- Existing checkpoint and LoRA code works without changes
- `unstack_state()` maps stacked paths (`layers._stacked.xxx`) to per-layer paths (`layers.0.xxx`) with `ArrayRef` for checkpoint compatibility
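The core stacking idea can be illustrated with a plain-numpy stand-in for `jax.lax.scan` (a sketch, not the actual `StackedDecoderLayers` code): per-layer weights live in one array with a leading `num_layers` axis, and folding over that axis gives the same result as looping over a list of layer weights.

```python
import numpy as np

num_layers, d = 4, 3
rng = np.random.default_rng(0)
per_layer = [rng.standard_normal((d, d)) for _ in range(num_layers)]
stacked = np.stack(per_layer)  # (num_layers, d, d) — the stacked format

x = rng.standard_normal(d)

# Old style: Python for-loop over a list of layer modules (nnx.List)
h_loop = x
for w in per_layer:
    h_loop = np.tanh(w @ h_loop)

# New style: fold over the leading axis, as lax.scan does with its carry
h_scan = x
for w in stacked:
    h_scan = np.tanh(w @ h_scan)

assert np.allclose(h_loop, h_scan)
```

Storing the weights this way is what lets the real implementation trace the layer body once and scan it, instead of unrolling one program per layer.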