
[tx] Stack weights — Qwen3#1079

Merged
pcmoritz merged 20 commits into NovaSky-AI:main from raulchen:stack-weights-stack-qwen3
Feb 12, 2026
Conversation

@raulchen (Contributor) commented Feb 11, 2026

Stacks on #1078.

Summary

  • Add StackedDecoderLayers infrastructure that stores transformer layer weights in (num_layers, ...) format for efficient
    jax.lax.scan-based forward passes (tx/layers/stacked.py)
  • Adapt checkpoint I/O and LoRA operations to work with stacked format via unstack_state() + ArrayRef write-through
  • Convert Qwen3 model from nnx.List + Python for-loop to StackedDecoderLayers + scan
  • This is needed for the per-layer gradient checkpointing optimization.

Key concepts:

  • StackedDecoderLayers: Stores all layer weights as stacked arrays and runs forward via jax.lax.scan (prefill/training)
    or Python loop (decode, for KV cache donation)
  • ArrayRef: NNX Variable subclass providing write-through views into indexed slices of stacked arrays, so existing
    checkpoint and LoRA code works without changes
  • unstack_state(): Converts stacked paths (layers._stacked.xxx) to per-layer paths (layers.0.xxx) with ArrayRef for
    checkpoint compatibility
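The stacked layout and the write-through view can be illustrated with a minimal NumPy sketch (illustrative only; the real implementation uses flax.nnx and jax.lax.scan, and every name below is a stand-in, not the tx API):

```python
import numpy as np

# Per-layer weights stored as one (num_layers, ...) array; the forward pass
# folds over the leading axis. In tx this fold is jax.lax.scan; a Python loop
# stands in for it here.
num_layers, d = 4, 8
rng = np.random.default_rng(0)
stacked_w = rng.standard_normal((num_layers, d, d))

def layer_fn(x, w):
    # Stand-in for one decoder layer.
    return np.tanh(x @ w)

x = rng.standard_normal((2, d))
for w in stacked_w:  # "scan" over the layer axis
    x = layer_fn(x, w)
print(x.shape)  # (2, 8)

# ArrayRef-style write-through: a view of layer i's slice mutates the stack,
# which is what lets per-layer checkpoint/LoRA code update stacked storage.
layer0_view = stacked_w[0]  # NumPy basic indexing returns a view
layer0_view[...] = 0.0
print(stacked_w[0].max())  # 0.0
```

In the real ArrayRef the write-through goes through an NNX Variable rather than a raw array view, but the contract is the same: code that addresses layer `i` individually updates the shared stacked array.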

@gemini-code-assist bot left a comment

Code Review

This pull request introduces a significant and well-designed refactoring to use stacked weights for transformer layers, which is a great optimization for performance and memory. The new StackedDecoderLayers and ArrayRef abstractions are cleverly implemented to achieve this while maintaining compatibility with existing checkpointing and LoRA operations. The refactoring of test utilities into a shared file is also a good improvement.

I've found a critical issue in the unstack_paths implementation that would prevent checkpointing from working correctly with the new stacked layers. I've also included a suggestion to make the unstack_state function more generic and maintainable. Addressing these points will make this excellent contribution even more robust.

@devin-ai-integration bot left a comment

Devin Review found 2 potential issues.

View 8 additional findings in Devin Review.


Comment on lines +167 to +168:

```python
if key not in tensors:
    continue
```

🔴 Expert weights silently skipped due to premature key not in tensors check in load_safetensors

Expert (MoE) weights are never loaded because the key not in tensors guard at line 167 runs before the expert-specific key construction at line 169-170.

Root Cause

For expert parameters, the path tuple contains "experts" as a single element (e.g., ("model", "layers", "0", "mlp", "experts", "gate_proj", "weight")). The get_param_key function joins this into "model.layers.0.mlp.experts.gate_proj.weight". However, the actual keys in the safetensors checkpoint file are per-expert: "model.layers.0.mlp.experts.0.gate_proj.weight", "model.layers.0.mlp.experts.1.gate_proj.weight", etc.

The new code at line 167 checks `if key not in tensors: continue`, which evaluates to True for expert paths (the generic key doesn't exist in the tensor dict), causing expert weights to be silently skipped. The expert-specific key construction via get_expert_key at line 170 is never reached.

The old code did not have this guard (git show HEAD~5:skyrl-tx/tx/utils/models.py), so expert weights were loaded correctly before this PR.

Impact: All MoE model weights (e.g., Qwen3 MoE, DeepseekV3) fail to load expert parameters, resulting in models running with randomly initialized expert weights. This would produce garbage outputs for any MoE model.

Suggested change:

```diff
-if key not in tensors:
-    continue
+if key not in tensors and "experts" not in path:
+    continue
```
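The key mismatch driving this bug can be reproduced in isolation (get_param_key and get_expert_key here are simplified sketches of the helpers in tx/utils/models.py, not the real implementations):

```python
def get_param_key(path):
    # Simplified sketch: join the NNX path into a dotted checkpoint key.
    return ".".join(path)

def get_expert_key(path, expert_idx):
    # Simplified sketch: insert the expert index after the "experts" element.
    i = path.index("experts")
    return ".".join(path[: i + 1] + (str(expert_idx),) + path[i + 1 :])

path = ("model", "layers", "0", "mlp", "experts", "gate_proj", "weight")
tensors = {
    "model.layers.0.mlp.experts.0.gate_proj.weight": "w0",
    "model.layers.0.mlp.experts.1.gate_proj.weight": "w1",
}
print(get_param_key(path) in tensors)      # False: the generic key never matches
print(get_expert_key(path, 0) in tensors)  # True: per-expert keys do
```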


raulchen and others added 4 commits February 11, 2026 15:39
Add stacked decoder layer infrastructure that stores transformer layer
weights in (num_layers, ...) format for efficient jax.lax.scan-based
forward passes.

- tx/layers/stacked.py: StackedDecoderLayers, ArrayRef, unstack_state()
- tx/layers/util.py: Fix shard_map_ep to truncate PartitionSpec for stacked layers
- tx/utils/generator.py: Add KVCache.num_layers, .batch_size, .seq_len properties
- tx/models/types.py: Allow kv_cache to be None (during training)
- tx/models/configs.py: Update gradient_checkpointing docstring

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Adapt weight loading/saving and LoRA operations to work with stacked
layer format using unstack_state() and ArrayRef write-through.

- tx/utils/models.py: Add is_stacked_path, get_adapter_idx, get_lora_adapter_slice.
  Rewrite load/save_safetensors to use unstack_state(). Simplify lora checkpoint I/O.
- tx/layers/lora.py: init/clear_lora_adapter use get_adapter_idx for stacked indexing
- tx/utils/storage.py: Remove rank param from pack_and_upload, use jax.process_index()
- tx/tinker/backends/jax.py: AccumulatedGradients uses get_adapter_idx with map_with_path

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Replace nnx.List with StackedDecoderLayers for Qwen3 model. Forward pass
now uses self.layers() (scan-based) instead of a Python for-loop.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- tests/models/lora_test_utils.py: Add shared LoRA test utilities
- test_qwen3.py: Fix layer iteration for stacked layers
- test_qwen3_lora_training.py: Fix import ordering
- test_models_common.py: Add create_model helper, type hints
- test_models.py: Add test_is_stacked_path, test_extract_insert_adapter_state_roundtrip

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@raulchen force-pushed the stack-weights-stack-qwen3 branch from 543cdc1 to f228229 on February 11, 2026 23:39
@devin-ai-integration bot left a comment

Devin Review found 1 new potential issue.

View 13 additional findings in Devin Review.


pcmoritz and others added 4 commits February 12, 2026 01:08
Co-authored-by: devin-ai-integration[bot] <158243242+devin-ai-integration[bot]@users.noreply.github.com>
@devin-ai-integration bot left a comment
Devin Review found 1 new potential issue.

View 19 additional findings in Devin Review.



🔴 Training forward pass doesn't set is_training=True, causing unnecessary KV cache accumulation in scan

The _model_forward function in the JAX backend calls model(...) without is_training=True during training. This means the scan body in StackedDecoderLayers accumulates KV cache arrays for all layers, wasting significant memory.

Detailed Explanation

At tx/tinker/backends/jax.py:274, the training forward pass calls:

```python
output = model(
    input_ids,
    attention_mask=attention_mask,
    adapter_indices=adapter_indices,
)
```

Since is_training defaults to False in Qwen3ForCausalLM.__call__ (tx/models/qwen3.py:419), the scan body in StackedDecoderLayers.__call__ (tx/layers/stacked.py:281-282) does NOT set k = v = None, causing full KV cache tensors to be accumulated through the scan as secondary outputs.

The is_training=True flag was specifically added in this PR to skip KV accumulation during training (tx/layers/stacked.py:280-282), but it is never actually used in the training path.

Impact: During training, the model unnecessarily stores KV cache arrays for all layers, consuming significant additional memory. For large models, this can lead to OOM errors or require smaller batch sizes than necessary.

(Refers to lines 274-278)
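The accumulation behaviour described above can be sketched with a plain-Python stand-in for the scan (hypothetical shapes and names; the real logic lives in tx/layers/stacked.py):

```python
def scan_body(hidden, layer_w, is_training):
    # Stand-ins for this layer's key/value projections.
    k = v = hidden * layer_w
    if is_training:
        k = v = None  # mirrors the is_training branch in the scan body
    return hidden + layer_w, (k, v)

def run_layers(stacked_w, hidden, is_training=False):
    kv_outputs = []  # jax.lax.scan would stack these as secondary outputs
    for w in stacked_w:  # stands in for scanning over the layer axis
        hidden, kv = scan_body(hidden, w, is_training)
        kv_outputs.append(kv)
    return hidden, kv_outputs

stacked_w = [1.0, 2.0, 3.0]
_, kvs_train = run_layers(stacked_w, 0.0, is_training=True)
_, kvs_decode = run_layers(stacked_w, 0.0, is_training=False)
print(all(kv == (None, None) for kv in kvs_train))   # True: nothing accumulated
print(any(kv == (None, None) for kv in kvs_decode))  # False: KV kept per layer
```

In jax.lax.scan, returning None for a secondary output means no per-layer array is allocated at all, which is exactly the saving the is_training flag was meant to buy.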



@devin-ai-integration bot left a comment

Devin Review found 2 potential issues.

View 7 additional findings in Devin Review.



🔴 Training forward pass does not set is_training=True, causing unnecessary KV cache allocation in scan

The _model_forward function in the JAX backend calls model(input_ids, attention_mask=..., adapter_indices=...) without passing is_training=True. Since this PR converts the Qwen3 model from a Python for-loop to jax.lax.scan, the scan body now accumulates KV cache arrays for all layers as stacked scan outputs when is_training is False.

Root cause and memory impact

In tx/layers/stacked.py:281-283, the scan body only skips KV accumulation when is_training is True:

```python
if is_training:
    k = v = None
```

Without is_training=True, the scan allocates and returns stacked key and value arrays for all layers (tx/layers/stacked.py:289-297), and then creates a full KVCache object that is never used by the training pipeline. For large models, this means allocating 2 * num_layers * batch * seq_len * num_kv_heads * head_dim floats of GPU memory unnecessarily, which could cause OOM during training.

The is_training parameter was specifically added by this PR to Qwen3ForCausalLM.__call__ (tx/models/qwen3.py:419) and Qwen3Model.__call__ (tx/models/qwen3.py:349) for this purpose, but the backend at tx/tinker/backends/jax.py:274-278 was not updated to pass it.

(Refers to lines 274-278)
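For scale, the memory formula above can be evaluated with illustrative Qwen3-like numbers (all values are assumptions for the example, not measurements from the PR):

```python
# 2 (K and V) * num_layers * batch * seq_len * num_kv_heads * head_dim elements,
# at 2 bytes per element for bfloat16.
num_layers, batch, seq_len = 32, 8, 4096
num_kv_heads, head_dim = 8, 128
bytes_per_elem = 2  # bfloat16
kv_bytes = 2 * num_layers * batch * seq_len * num_kv_heads * head_dim * bytes_per_elem
print(f"{kv_bytes / 2**30:.1f} GiB")  # 4.0 GiB allocated for nothing during training
```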



@devin-ai-integration bot left a comment

Devin Review found 1 new potential issue.

View 9 additional findings in Devin Review.


Comment on lines +167 to +168:

```python
if key not in tensors:
    continue
```

🔴 Early key not in tensors check silently skips all expert weight loading for MoE models

The newly added `if key not in tensors: continue` guard causes all expert weights to be silently skipped during checkpoint loading for MoE models (e.g., Qwen3-MoE).

Root Cause

For expert params, get_param_key at tx/utils/models.py:121-127 produces a key like model.layers.0.mlp.experts.gate_proj.weight (without per-expert index). But HuggingFace checkpoint files store expert weights with per-expert keys like model.layers.0.mlp.experts.0.gate_proj.weight.

The old code did not have this early-exit check, so it would always reach the if "experts" in path: block (line 169) which correctly uses get_expert_key to look up each expert individually. The new key not in tensors check at line 167 short-circuits before that expert-handling logic ever runs.

Impact: All expert base weights in MoE models are silently skipped during loading, resulting in zeros/random weights for all expert layers. The model would produce incorrect outputs. This is a regression from the old code.

Prompt for agents
In load_safetensors (tx/utils/models.py), the `if key not in tensors: continue` check at lines 167-168 must be moved after the expert-handling block. For expert paths, the key from get_param_key does not exist directly in the tensors dict; instead, per-expert keys are looked up via get_expert_key inside the `if "experts" in path` block on line 169. The fix is to reorder the logic: first handle experts (constructing the tensor from per-expert keys), then for non-expert paths check whether the key exists in tensors. Something like:

```python
if "experts" in path:
    expert_key = get_expert_key(path, 0)
    if expert_key not in tensors:
        continue
    tensor = np.stack([tensors[get_expert_key(path, i)].T for i in range(config.get_num_experts())], axis=0)
elif key not in tensors:
    continue
else:
    tensor = tensors[key] if "embed_tokens" in key else tensors[key].T
```


@pcmoritz

/gemini review

@gemini-code-assist bot left a comment

Code Review

This pull request introduces a significant and well-designed infrastructure for stacking transformer layer weights using StackedDecoderLayers. This change, which replaces nnx.List with a jax.lax.scan-based approach for the forward pass, is a great step towards improving performance, especially for training and prefill. The use of ArrayRef to maintain compatibility with existing checkpointing and LoRA operations is particularly clever.

The refactoring of the Qwen3 model and various utility functions to support this new stacked format is thorough. However, I've found a critical and repeated bug in tx/utils/models.py related to how NNX state paths are processed. The checks for keys within these paths are incorrect, which will cause LoRA weight loading and saving to fail. I've left detailed comments on this issue. Once that is addressed, this PR will be in excellent shape.

@raulchen force-pushed the stack-weights-stack-qwen3 branch from 8c5507c to 743c82b on February 12, 2026 17:58
@pcmoritz force-pushed the stack-weights-stack-qwen3 branch from 743c82b to 8c5507c on February 12, 2026 18:15
@pcmoritz pcmoritz merged commit c7e5c51 into NovaSky-AI:main Feb 12, 2026
3 of 5 checks passed
@raulchen raulchen deleted the stack-weights-stack-qwen3 branch February 12, 2026 22:38
pcmoritz added a commit that referenced this pull request Feb 14, 2026
## Summary
- Convert Llama3 model from `nnx.List` + Python for-loop to
`StackedDecoderLayers` + scan-based forward pass
- Same pattern as the Qwen3 conversion in #1079

Stacks on #1079.

## Test plan
- [x] `pytest tests/models/test_llama3.py tests/models/test_llama3_lora_training.py -x`

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Philipp Moritz <pcmoritz@gmail.com>
tanmaysachan pushed a commit to tanmaysachan/SkyRL that referenced this pull request Feb 14, 2026
## Summary
- Convert Llama3 model from `nnx.List` + Python for-loop to
`StackedDecoderLayers` + scan-based forward pass
- Same pattern as the Qwen3 conversion in NovaSky-AI#1079

Stacks on NovaSky-AI#1079.

## Test plan
- [x] `pytest tests/models/test_llama3.py tests/models/test_llama3_lora_training.py -x`

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Philipp Moritz <pcmoritz@gmail.com>
pcmoritz added a commit that referenced this pull request Feb 15, 2026
See #1079
