24 changes: 17 additions & 7 deletions skyrl-tx/tests/models/lora_test_utils.py
@@ -1,20 +1,29 @@
 import jax
 import jax.numpy as jnp
 
+from tx.utils.models import get_adapter_idx
+
 
 def get_adapter_params(params, adapter_idx):
     """Extract adapter params at specific index."""
-    return jax.tree.map(lambda p: p[adapter_idx].copy(), params)
+
+    def extract(path, p):
+        idx = get_adapter_idx(path, adapter_idx)
+        return p[idx].copy()
+
+    return jax.tree.map_with_path(extract, params)
 
 
 def get_out_of_rank_params(params, adapter_idx, rank):
     """Extract out-of-rank params for an adapter."""
 
     def slice_param(path, p):
-        if "lora_A" in str(path):
-            return p[adapter_idx, :, rank:].copy()
-        elif "lora_B" in str(path):
-            return p[adapter_idx, rank:, :].copy()
+        path_str = str(path)
+        idx = get_adapter_idx(path, adapter_idx)
+        if "lora_A" in path_str:
+            return p[idx + (..., slice(rank, None))].copy()
+        elif "lora_B" in path_str:
+            return p[idx + (..., slice(rank, None), slice(None))].copy()
         return p
 
     return jax.tree.map_with_path(slice_param, params)
@@ -54,12 +63,13 @@ def slice_param(path, p):
         else:
             effective_rank = rank
 
+        idx = get_adapter_idx(path, adapter_idx)
         if "lora_A" in path_str:
             # lora_A shape: [adapters, ..., max_rank] - slice last dim
-            return p[adapter_idx, ..., effective_rank:].copy()
+            return p[idx + (..., slice(effective_rank, None))].copy()
        elif "lora_B" in path_str:
             # lora_B shape: [adapters, ..., max_rank, out] - slice second-to-last dim
-            return p[adapter_idx, ..., effective_rank:, :].copy()
+            return p[idx + (..., slice(effective_rank, None), slice(None))].copy()
         return p
 
     return jax.tree.map_with_path(slice_param, params)
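Editor's note on the indexing convention above: the change from p[adapter_idx, ...] to p[idx + (..., slice(rank, None))] only makes sense if get_adapter_idx returns a tuple index, so the same slicing code works for plain params of shape (adapters, ...) and stacked per-layer params of shape (layers, adapters, ...). A minimal sketch of that convention, with a hand-rolled stand-in for get_adapter_idx (the stand-in is an assumption for illustration, not the function's actual implementation):

    import jax.numpy as jnp

    def adapter_idx_sketch(path, adapter_idx):
        # Assumed behavior: stacked params carry a leading layer axis,
        # so the adapter index moves from axis 0 to axis 1.
        if "_stacked" in str(path):
            return (slice(None), adapter_idx)  # like p[:, adapter_idx]
        return (adapter_idx,)                  # like p[adapter_idx]

    rank = 4
    plain = jnp.zeros((4, 16, 8))       # lora_A: (adapters, in_dim, max_rank)
    stacked = jnp.zeros((2, 4, 16, 8))  # stacked: (layers, adapters, in_dim, max_rank)

    for path, p in [("mlp/lora_A", plain), ("layers/_stacked/mlp/lora_A", stacked)]:
        idx = adapter_idx_sketch(path, adapter_idx=1)
        # The composed tuple selects the adapter first, then applies the
        # out-of-rank slice to the trailing dimensions.
        out_of_rank = p[idx + (..., slice(rank, None))]
        print(path, out_of_rank.shape)  # (16, 4), then (2, 16, 4)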
65 changes: 52 additions & 13 deletions skyrl-tx/tests/models/test_models_common.py
@@ -1,4 +1,5 @@
 import tempfile
+from typing import Any
 
 from flax import nnx
 import jax
@@ -7,9 +8,10 @@
 import pytest
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 
-from tx.models.configs import Llama3Config, Qwen3Config
+from tx.models.configs import Llama3Config, ModelConfig, Qwen3Config
 from tx.models.llama3 import Llama3ForCausalLM
 from tx.models.qwen3 import Qwen3ForCausalLM
+from tx.models.types import ModelForCausalLM
 from tx.utils.models import load_safetensors
 
 MODEL_PARAMS = [
@@ -19,26 +21,57 @@
 MODEL_IDS = ["llama3", "qwen3"]
 
 
-def load_model(tmp_dir, model_name, config_cls, model_cls, mesh_axes, *, loss_chunk_size=0):
-    """Load model from pre-saved weights directory."""
+def create_model(
+    model_name: str,
+    config_cls: type[ModelConfig],
+    model_cls: type[ModelForCausalLM],
+    mesh_axes: tuple[str, str],
+    *,
+    mesh_axis_types: tuple[jax.sharding.AxisType, ...] | None = None,
+    seed: int = 0,
+    **config_kwargs: Any,
+) -> tuple[ModelForCausalLM, ModelConfig]:
+    """Create model with random weights for testing."""
     base_config = AutoConfig.from_pretrained(model_name)
-    config = config_cls(
-        base_config,
-        max_lora_adapters=1,
-        max_lora_rank=1,
-        shard_attention_heads=True,
-        loss_chunk_size=loss_chunk_size,
-        gradient_checkpointing=False,
-    )
-    mesh = jax.make_mesh((1, 1), mesh_axes, axis_types=(jax.sharding.AxisType.Auto,) * 2)
+    config = config_cls(base_config, max_lora_adapters=1, max_lora_rank=1, shard_attention_heads=True, **config_kwargs)
+    # Default to Auto axis types to avoid sharding resolution errors
+    if mesh_axis_types is None:
+        mesh_axis_types = (jax.sharding.AxisType.Auto,) * len(mesh_axes)
+    mesh = jax.make_mesh((1, 1), mesh_axes, axis_types=mesh_axis_types)
     with jax.set_mesh(mesh):
-        model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(0))
-        load_safetensors(tmp_dir, config, model)
+        model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(seed))
+    return model, config
+
+
+def load_model(
+    tmp_dir: str,
+    model_name: str,
+    config_cls: type[ModelConfig],
+    model_cls: type[ModelForCausalLM],
+    mesh_axes: tuple[str, str],
+    *,
+    loss_chunk_size: int = 0,
+) -> ModelForCausalLM:
+    """Load model from pre-saved weights directory."""
+    model, config = create_model(
+        model_name,
+        config_cls,
+        model_cls,
+        mesh_axes,
+        loss_chunk_size=loss_chunk_size,
+        gradient_checkpointing=False,
+    )
+    load_safetensors(tmp_dir, config, model)
     return model
 
 
 @pytest.mark.parametrize("model_name,config_cls,model_cls,mesh_axes", MODEL_PARAMS, ids=MODEL_IDS)
-def test_compute_logits(model_name, config_cls, model_cls, mesh_axes):
+def test_compute_logits(
+    model_name: str,
+    config_cls: type[ModelConfig],
+    model_cls: type[ModelForCausalLM],
+    mesh_axes: tuple[str, str],
+) -> None:
     """Test that model.compute_logits matches HuggingFace logits."""
     tokenizer = AutoTokenizer.from_pretrained(model_name)
 
@@ -65,7 +98,13 @@ def test_compute_logits(model_name, config_cls, model_cls, mesh_axes):
 
 @pytest.mark.parametrize("model_name,config_cls,model_cls,mesh_axes", MODEL_PARAMS, ids=MODEL_IDS)
 @pytest.mark.parametrize("chunk_size", [8, 16, 32])
-def test_chunked_logprobs(model_name, config_cls, model_cls, mesh_axes, chunk_size):
+def test_chunked_logprobs(
+    model_name: str,
+    config_cls: type[ModelConfig],
+    model_cls: type[ModelForCausalLM],
+    mesh_axes: tuple[str, str],
+    chunk_size: int,
+) -> None:
     """Test that chunked and non-chunked compute_logprobs produce identical results."""
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     inputs = ["The capital of France is", "Hello world"]
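The refactor splits model construction from checkpoint loading, so tests that only need random weights can call create_model directly. A hypothetical usage sketch (the mesh axis names and the Explicit axis type below are placeholders for illustration, not taken from this PR):

    # Build a randomly initialized model, overriding the Auto axis-type default.
    model, config = create_model(
        "Qwen/Qwen3-0.6B",
        Qwen3Config,
        Qwen3ForCausalLM,
        ("dp", "tp"),  # placeholder axis names
        mesh_axis_types=(jax.sharding.AxisType.Explicit,) * 2,
        seed=42,
    )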
10 changes: 6 additions & 4 deletions skyrl-tx/tests/models/test_qwen3.py
@@ -246,16 +246,18 @@ def test_qwen3_lora():
     )
 
     # Load layer LoRA weights
-    for i, layer in enumerate(model.model.layers):
+    for i in range(config.num_hidden_layers):
         hf_layer = hf_model.base_model.model.model.layers[i]
-        for module, projections in [
+        jax_layer = model.model.layers[i]
+        for module_name, projections in [
             ("mlp", ["gate_proj", "up_proj", "down_proj"]),
             ("self_attn", ["q_proj", "k_proj", "v_proj", "o_proj"]),
         ]:
             for proj_name in projections:
-                hf_proj = getattr(getattr(hf_layer, module), proj_name)
+                hf_proj = getattr(getattr(hf_layer, module_name), proj_name)
+                jax_proj = getattr(getattr(jax_layer, module_name), proj_name)
                 load_lora_weights(
-                    getattr(getattr(layer, module), proj_name),
+                    jax_proj,
                     adapter_idx=adapter_idx,
                     lora_A_weights=hf_proj.lora_A["default"].weight.detach().numpy().T,
                     lora_B_weights=hf_proj.lora_B["default"].weight.detach().numpy().T,
101 changes: 101 additions & 0 deletions skyrl-tx/tests/utils/test_models.py
@@ -11,11 +11,14 @@
 from peft import PeftModel
 from transformers import AutoConfig, AutoModelForCausalLM
 
+from jax.tree_util import DictKey
+
 from tx.layers.lora import init_lora_adapter
 from tx.models.configs import Qwen3Config
 from tx.models.qwen3 import Qwen3ForCausalLM
 from tx.tinker.types import LoraConfig
 from tx.utils import models
+from tx.utils.models import extract_adapter_state, insert_adapter_state, is_stacked_path
 from tx.utils.storage import download_and_unpack
@@ -82,3 +85,101 @@ def test_save_load_lora_checkpoint(storage_type: str, monkeypatch, tmp_path: Path):
 
     assert torch.allclose(lora_A, torch.from_numpy(expected_lora_A), atol=1e-6)
     assert torch.allclose(lora_B, torch.from_numpy(expected_lora_B), atol=1e-6)
+
+
+@pytest.mark.parametrize(
+    "path,expected",
+    [
+        # Stacked paths (DictKey) — real NNX paths include _stacked
+        (
+            (
+                DictKey(key="model"),
+                DictKey(key="layers"),
+                DictKey(key="_stacked"),
+                DictKey(key="self_attn"),
+                DictKey(key="lora_A"),
+            ),
+            True,
+        ),
+        (
+            (
+                DictKey(key="model"),
+                DictKey(key="layers"),
+                DictKey(key="layer_groups"),
+                DictKey(key="_stacked"),
+                DictKey(key="self_attn"),
+                DictKey(key="lora_A"),
+            ),
+            True,
+        ),
+        # Non-stacked paths (DictKey)
+        ((DictKey(key="model"), DictKey(key="embed_tokens"), DictKey(key="lora_A")), False),
+        ((DictKey(key="lm_head"), DictKey(key="lora_A")), False),
+        # String paths
+        (("model", "layers", "_stacked", "self_attn", "lora_A"), True),
+        (("model", "embed_tokens", "lora_A"), False),
+    ],
+    ids=["stacked_layers", "multi_stacked_layers", "embed_tokens", "lm_head", "str_stacked", "str_embed"],
+)
+def test_is_stacked_path(path, expected):
+    """Test is_stacked_path correctly identifies stacked vs non-stacked paths."""
+    assert is_stacked_path(path) is expected
+
+
+def test_extract_insert_adapter_state_roundtrip():
+    """Test that extract_adapter_state and insert_adapter_state are inverses."""
+    base_model_name = "Qwen/Qwen3-0.6B"
+    rank, alpha, adapter_index = 8, 16, 2
+    _, _, model = create_test_model(base_model_name, rank, alpha, adapter_index)
+
+    # Set LoRA weights to random values
+    q_proj = model.model.layers[0].self_attn.q_proj
+    rng1, rng2 = jax.random.split(jax.random.PRNGKey(123))
+    q_proj.lora_A[...] = jax.random.normal(rng1, q_proj.lora_A[...].shape)
+    q_proj.lora_B[...] = jax.random.normal(rng2, q_proj.lora_B[...].shape)
+
+    # Split model to get lora_params
+    _, lora_params, _ = nnx.split(model, model.is_lora_param, ...)
+
+    # Store original values for comparison
+    original_lora_A = np.array(q_proj.lora_A[...][adapter_index, :, :rank])
+    original_lora_B = np.array(q_proj.lora_B[...][adapter_index, :rank, :])
+
+    # Extract adapter state
+    extracted = extract_adapter_state(adapter_index, lora_params, rank)
+
+    # Verify extracted shape is correct (no adapter dimension)
+    for path, leaf in jax.tree.leaves_with_path(extracted):
+        key = path[-2].key if hasattr(path[-2], "key") else str(path[-2])
+        if key in {"lora_A", "lora_B"}:
+            # Stacked: should have (num_layers, ...) not (num_layers, num_adapters, ...)
+            if is_stacked_path(path):
+                assert leaf.shape[0] == 1  # num_layers
+                assert leaf.ndim == 3  # (layers, in_dim, rank) or (layers, rank, out_dim)
+
+    # Zero out the adapter's weights
+    q_proj.lora_A[...] = q_proj.lora_A[...].at[adapter_index].set(0)
+    q_proj.lora_B[...] = q_proj.lora_B[...].at[adapter_index].set(0)
+
+    # Verify weights are zeroed
+    assert np.allclose(q_proj.lora_A[...][adapter_index], 0)
+    assert np.allclose(q_proj.lora_B[...][adapter_index], 0)
+
+    # Re-split to get updated lora_params
+    _, lora_params, _ = nnx.split(model, model.is_lora_param, ...)
+
+    # Insert extracted state back (modifies lora_params in-place via nnx.update)
+    insert_adapter_state(adapter_index, lora_params, extracted, rank)
+
+    # Verify weights are restored by checking lora_params directly
+    for path, leaf in jax.tree.leaves_with_path(lora_params):
+        key = path[-2].key if hasattr(path[-2], "key") else str(path[-2])
+        # leaf is a state wrapper with .value, or can be an array directly
+        arr = leaf.value if hasattr(leaf, "value") else leaf
+        if "q_proj" in str(path) and key == "lora_A":
+            restored_lora_A = np.array(arr[0, adapter_index, :, :rank])
+        elif "q_proj" in str(path) and key == "lora_B":
+            restored_lora_B = np.array(arr[0, adapter_index, :rank, :])
+
+    assert np.allclose(original_lora_A, restored_lora_A), "lora_A not restored correctly"
+    assert np.allclose(original_lora_B, restored_lora_B), "lora_B not restored correctly"
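The parametrized cases above pin down the predicate's contract: a path is "stacked" exactly when one of its components is the key "_stacked", whether the components are DictKeys or plain strings. A conforming sketch (the real implementation lives in tx.utils.models and may differ in detail):

    from jax.tree_util import DictKey

    def is_stacked_path_sketch(path) -> bool:
        # Treat DictKey entries and plain strings uniformly.
        return any(
            (entry.key if isinstance(entry, DictKey) else entry) == "_stacked"
            for entry in path
        )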
26 changes: 13 additions & 13 deletions skyrl-tx/tx/layers/lora.py
@@ -2,7 +2,7 @@
 import jax
 from jax import numpy as jnp
 
-from tx.utils.models import filter_lora
+from tx.utils.models import filter_lora, get_adapter_idx
 from tx.layers.util import Param, prepare_routing, ragged_dot
 from tx.models.types import ModelForCausalLM
 from tx.tinker.types import LoraConfig
@@ -345,21 +345,22 @@ def init_adapter(path, value):
         if not filter_lora(lora_config, normalized_path):
             effective_rank = 0
 
+        idx = get_adapter_idx(path, adapter_index)
+
         key_name = path[-2].key
         if key_name == "lora_ranks":
-            return value.at[adapter_index].set(effective_rank)
+            return value.at[idx].set(effective_rank)
         if key_name == "lora_scaling":
             # Set scaling to 0.0 if rank is 0
-            return value.at[adapter_index].set(lora_config.alpha / effective_rank if effective_rank > 0 else 0.0)
+            scaling = lora_config.alpha / effective_rank if effective_rank > 0 else 0.0
+            return value.at[idx].set(scaling)
         if key_name == "lora_A":
             # Reinitialize with he_uniform, then zero columns beyond rank
-            shape = value[adapter_index].shape
-            new_A = nnx.initializers.he_uniform()(rngs.params(), shape, value.dtype)
+            new_A = nnx.initializers.he_uniform()(rngs.params(), value[idx].shape, value.dtype)
             new_A = new_A.at[..., effective_rank:].set(0.0)
-            return value.at[adapter_index].set(new_A)
+            return value.at[idx].set(new_A)
         if key_name == "lora_B":
             # Explicitly zero lora_B
-            return value.at[adapter_index].set(0.0)
+            return value.at[idx].set(0.0)
         return value
 
     updated_state = jax.tree.map_with_path(init_adapter, state)
@@ -376,11 +377,10 @@ def clear_lora_adapter(model: ModelForCausalLM, adapter_index: int):
 
     def clear_adapter(path, value):
         key = path[-2].key
-        if key == "lora_ranks":
-            return value.at[adapter_index].set(0)
-        if key in ("lora_scaling", "lora_A", "lora_B"):
-            return value.at[adapter_index].set(0.0)
-        return value
+        if key not in ("lora_ranks", "lora_scaling", "lora_A", "lora_B"):
+            return value
+        idx = get_adapter_idx(path, adapter_index)
+        return value.at[idx].set(0 if key == "lora_ranks" else 0.0)
 
     updated_state = jax.tree.map_with_path(clear_adapter, state)
     nnx.update(model, updated_state)
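Because get_adapter_idx is used on the .at[...] update side here too, a single tuple index can address every layer of a stacked param at once. A small self-contained check of that pattern (the shapes and the stacked index value are illustrative assumptions, not values from this PR):

    import jax.numpy as jnp

    # Stacked lora_ranks: (layers, adapters); a plain param would be (adapters,).
    ranks = jnp.zeros((2, 4), dtype=jnp.int32)
    idx = (slice(None), 1)  # assumed stacked-path index: all layers, adapter 1
    ranks = ranks.at[idx].set(8)  # sets adapter 1's rank across every layer
    assert (ranks[:, 1] == 8).all() and (ranks[:, 0] == 0).all()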