Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 46 additions & 30 deletions skyrl-tx/tests/models/lora_test_utils.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,81 @@
"""Shared test utilities for LoRA training tests."""

import jax
import jax.numpy as jnp

from tx.utils.models import get_adapter_idx

def get_adapter_params(params, adapter_idx: int):
    """Extract adapter params at a specific index.

    Decoder layer LoRA params have shape (num_layers, num_adapters, ...).
    Embed tokens LoRA params have shape (num_adapters, ...).
    """

    def _take(path, leaf):
        # get_adapter_idx builds the correct index tuple for both the
        # stacked (per-layer) and flat (embed tokens) param layouts.
        return leaf[get_adapter_idx(path, adapter_idx)].copy()

    return jax.tree.map_with_path(_take, params)

def _slice_out_of_rank(params, adapter_idx: int, get_rank):
    """Extract out-of-rank params using a rank function.

    Args:
        params: LoRA parameters tree.
        adapter_idx: Adapter index to extract.
        get_rank: Function (path) -> int returning effective rank for that path.
    """

    def _slice(path, leaf):
        name = str(path)
        is_lora_a = "lora_A" in name
        is_lora_b = "lora_B" in name
        if not (is_lora_a or is_lora_b):
            return leaf
        rank = get_rank(path)
        prefix = get_adapter_idx(path, adapter_idx)
        if is_lora_a:
            # lora_A: the rank dimension is last — slice everything past `rank`.
            return leaf[prefix + (..., slice(rank, None))].copy()
        # lora_B: the rank dimension is second-to-last.
        return leaf[prefix + (..., slice(rank, None), slice(None))].copy()

    return jax.tree.map_with_path(_slice, params)


def get_out_of_rank_params(params, adapter_idx: int, rank: int):
    """Extract out-of-rank params for an adapter using one fixed rank."""
    return _slice_out_of_rank(params, adapter_idx, lambda _path: rank)


def verify_params_unchanged(initial_params, final_params, error_msg_prefix: str):
"""Verify that params haven't changed between initial and final state."""
for (path, initial), (_, final) in zip(
jax.tree.leaves_with_path(initial_params), jax.tree.leaves_with_path(final_params)
):
assert jnp.allclose(initial, final), f"{error_msg_prefix} for {path}"


def _is_routed_expert_path(path) -> bool:
"""Disambiguate shared_experts and experts"""
"""Check if path is for routed experts (not shared_experts)."""
keys = []
for p in path:
if hasattr(p, "key"):
keys.append(str(p.key))
elif hasattr(p, "name"):
keys.append(str(p.name))

for i, key in enumerate(keys):
if key == "experts" and i > 0 and keys[i - 1] == "mlp":
return True
return False


def get_moe_out_of_rank_params(params, adapter_idx: int, rank: int, num_experts: int):
    """Extract out-of-rank params for MoE models.

    For routed experts, uses effective rank = max(1, rank // num_experts).
    """

    def _effective_rank(path):
        # Routed expert layers split the rank budget across experts.
        if _is_routed_expert_path(path):
            return max(1, rank // num_experts)
        return rank

    return _slice_out_of_rank(params, adapter_idx, _effective_rank)
65 changes: 52 additions & 13 deletions skyrl-tx/tests/models/test_models_common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import tempfile
from typing import Any

from flax import nnx
import jax
Expand All @@ -7,9 +8,10 @@
import pytest
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from tx.models.configs import Llama3Config, Qwen3Config
from tx.models.configs import Llama3Config, ModelConfig, Qwen3Config
from tx.models.llama3 import Llama3ForCausalLM
from tx.models.qwen3 import Qwen3ForCausalLM
from tx.models.types import ModelForCausalLM
from tx.utils.models import load_safetensors

MODEL_PARAMS = [
Expand All @@ -19,26 +21,57 @@
MODEL_IDS = ["llama3", "qwen3"]


def create_model(
    model_name: str,
    config_cls: type[ModelConfig],
    model_cls: type[ModelForCausalLM],
    mesh_axes: tuple[str, str],
    *,
    mesh_axis_types: tuple[jax.sharding.AxisType, ...] | None = None,
    seed: int = 0,
    **config_kwargs: Any,
) -> tuple[ModelForCausalLM, ModelConfig]:
    """Create model with random weights for testing."""
    base_config = AutoConfig.from_pretrained(model_name)
    config = config_cls(
        base_config,
        max_lora_adapters=1,
        max_lora_rank=1,
        shard_attention_heads=True,
        **config_kwargs,
    )
    # Default to Auto axis types to avoid sharding resolution errors.
    if mesh_axis_types is None:
        mesh_axis_types = (jax.sharding.AxisType.Auto,) * len(mesh_axes)
    mesh = jax.make_mesh((1, 1), mesh_axes, axis_types=mesh_axis_types)
    with jax.set_mesh(mesh):
        model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(seed))
    return model, config


def load_model(
    tmp_dir: str,
    model_name: str,
    config_cls: type[ModelConfig],
    model_cls: type[ModelForCausalLM],
    mesh_axes: tuple[str, str],
    *,
    loss_chunk_size: int = 0,
) -> ModelForCausalLM:
    """Load model from pre-saved weights directory."""
    # Build a fresh randomly-initialized model, then overwrite its weights
    # from the safetensors checkpoint saved under tmp_dir.
    model, config = create_model(
        model_name,
        config_cls,
        model_cls,
        mesh_axes,
        loss_chunk_size=loss_chunk_size,
        gradient_checkpointing=False,
    )
    load_safetensors(tmp_dir, config, model)
    return model


@pytest.mark.parametrize("model_name,config_cls,model_cls,mesh_axes", MODEL_PARAMS, ids=MODEL_IDS)
def test_compute_logits(model_name, config_cls, model_cls, mesh_axes):
def test_compute_logits(
model_name: str,
config_cls: type[ModelConfig],
model_cls: type[ModelForCausalLM],
mesh_axes: tuple[str, str],
) -> None:
"""Test that model.compute_logits matches HuggingFace logits."""
tokenizer = AutoTokenizer.from_pretrained(model_name)

Expand All @@ -65,7 +98,13 @@ def test_compute_logits(model_name, config_cls, model_cls, mesh_axes):

@pytest.mark.parametrize("model_name,config_cls,model_cls,mesh_axes", MODEL_PARAMS, ids=MODEL_IDS)
@pytest.mark.parametrize("chunk_size", [8, 16, 32])
def test_chunked_logprobs(model_name, config_cls, model_cls, mesh_axes, chunk_size):
def test_chunked_logprobs(
model_name: str,
config_cls: type[ModelConfig],
model_cls: type[ModelForCausalLM],
mesh_axes: tuple[str, str],
chunk_size: int,
) -> None:
"""Test that chunked and non-chunked compute_logprobs produce identical results."""
tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = ["The capital of France is", "Hello world"]
Expand Down
10 changes: 6 additions & 4 deletions skyrl-tx/tests/models/test_qwen3.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,16 +246,18 @@ def test_qwen3_lora():
)

# Load layer LoRA weights
for i, layer in enumerate(model.model.layers):
for i in range(config.num_hidden_layers):
hf_layer = hf_model.base_model.model.model.layers[i]
for module, projections in [
jax_layer = model.model.layers[i]
for module_name, projections in [
("mlp", ["gate_proj", "up_proj", "down_proj"]),
("self_attn", ["q_proj", "k_proj", "v_proj", "o_proj"]),
]:
for proj_name in projections:
hf_proj = getattr(getattr(hf_layer, module), proj_name)
hf_proj = getattr(getattr(hf_layer, module_name), proj_name)
jax_proj = getattr(getattr(jax_layer, module_name), proj_name)
load_lora_weights(
getattr(getattr(layer, module), proj_name),
jax_proj,
adapter_idx=adapter_idx,
lora_A_weights=hf_proj.lora_A["default"].weight.detach().numpy().T,
lora_B_weights=hf_proj.lora_B["default"].weight.detach().numpy().T,
Expand Down
101 changes: 101 additions & 0 deletions skyrl-tx/tests/utils/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@
from peft import PeftModel
from transformers import AutoConfig, AutoModelForCausalLM

from jax.tree_util import DictKey

from tx.layers.lora import init_lora_adapter
from tx.models.configs import Qwen3Config
from tx.models.qwen3 import Qwen3ForCausalLM
from tx.tinker.types import LoraConfig
from tx.utils import models
from tx.utils.models import extract_adapter_state, insert_adapter_state, is_stacked_path
from tx.utils.storage import download_and_unpack


Expand Down Expand Up @@ -82,3 +85,101 @@ def test_save_load_lora_checkpoint(storage_type: str, monkeypatch, tmp_path: Pat

assert torch.allclose(lora_A, torch.from_numpy(expected_lora_A), atol=1e-6)
assert torch.allclose(lora_B, torch.from_numpy(expected_lora_B), atol=1e-6)


@pytest.mark.parametrize(
    "path,expected",
    [
        # Stacked paths (DictKey) — real NNX paths include _stacked
        (
            (
                DictKey(key="model"),
                DictKey(key="layers"),
                DictKey(key="_stacked"),
                DictKey(key="self_attn"),
                DictKey(key="lora_A"),
            ),
            True,
        ),
        (
            (
                DictKey(key="model"),
                DictKey(key="layers"),
                DictKey(key="layer_groups"),
                DictKey(key="_stacked"),
                DictKey(key="self_attn"),
                DictKey(key="lora_A"),
            ),
            True,
        ),
        # Non-stacked paths (DictKey)
        ((DictKey(key="model"), DictKey(key="embed_tokens"), DictKey(key="lora_A")), False),
        ((DictKey(key="lm_head"), DictKey(key="lora_A")), False),
        # String paths
        (("model", "layers", "_stacked", "self_attn", "lora_A"), True),
        (("model", "embed_tokens", "lora_A"), False),
    ],
    ids=["stacked_layers", "multi_stacked_layers", "embed_tokens", "lm_head", "str_stacked", "str_embed"],
)
def test_is_stacked_path(path, expected):
    """Test is_stacked_path correctly identifies stacked vs non-stacked paths."""
    # `is` comparison also pins that the helper returns a real bool, not just a truthy value.
    assert is_stacked_path(path) is expected


def test_extract_insert_adapter_state_roundtrip():
    """Test that extract_adapter_state and insert_adapter_state are inverses."""
    base_model_name = "Qwen/Qwen3-0.6B"
    rank, alpha, adapter_index = 8, 16, 2
    _, _, model = create_test_model(base_model_name, rank, alpha, adapter_index)

    # Set LoRA weights to random values so the roundtrip is non-trivial
    # (freshly initialized lora_B is typically zeros).
    q_proj = model.model.layers[0].self_attn.q_proj
    rng1, rng2 = jax.random.split(jax.random.PRNGKey(123))
    q_proj.lora_A[...] = jax.random.normal(rng1, q_proj.lora_A[...].shape)
    q_proj.lora_B[...] = jax.random.normal(rng2, q_proj.lora_B[...].shape)

    # Split model to get lora_params
    _, lora_params, _ = nnx.split(model, model.is_lora_param, ...)

    # Store original values for comparison (only the in-rank slice of this adapter)
    original_lora_A = np.array(q_proj.lora_A[...][adapter_index, :, :rank])
    original_lora_B = np.array(q_proj.lora_B[...][adapter_index, :rank, :])

    # Extract adapter state
    extracted = extract_adapter_state(adapter_index, lora_params, rank)

    # Verify extracted shape is correct (no adapter dimension)
    for path, leaf in jax.tree.leaves_with_path(extracted):
        key = path[-2].key if hasattr(path[-2], "key") else str(path[-2])
        if key in {"lora_A", "lora_B"}:
            # Stacked: should have (num_layers, ...) not (num_layers, num_adapters, ...)
            if is_stacked_path(path):
                assert leaf.shape[0] == 1  # num_layers
                assert leaf.ndim == 3  # (layers, in_dim, rank) or (layers, rank, out_dim)

    # Zero out the adapter's weights so restoration is observable
    q_proj.lora_A[...] = q_proj.lora_A[...].at[adapter_index].set(0)
    q_proj.lora_B[...] = q_proj.lora_B[...].at[adapter_index].set(0)

    # Verify weights are zeroed
    assert np.allclose(q_proj.lora_A[...][adapter_index], 0)
    assert np.allclose(q_proj.lora_B[...][adapter_index], 0)

    # Re-split to get updated lora_params
    _, lora_params, _ = nnx.split(model, model.is_lora_param, ...)

    # Insert extracted state back (modifies lora_params in-place via nnx.update)
    insert_adapter_state(adapter_index, lora_params, extracted, rank)

    # Verify weights are restored by checking lora_params directly
    for path, leaf in jax.tree.leaves_with_path(lora_params):
        key = path[-2].key if hasattr(path[-2], "key") else str(path[-2])
        # leaf is a state wrapper with .value, or can be an array directly
        arr = leaf.value if hasattr(leaf, "value") else leaf
        if "q_proj" in str(path) and key == "lora_A":
            restored_lora_A = np.array(arr[0, adapter_index, :, :rank])
        elif "q_proj" in str(path) and key == "lora_B":
            restored_lora_B = np.array(arr[0, adapter_index, :rank, :])

    # NOTE(review): if the loop above finds no q_proj lora leaves, these names
    # are unbound and the asserts fail with NameError rather than a clean message.
    assert np.allclose(original_lora_A, restored_lora_A), "lora_A not restored correctly"
    assert np.allclose(original_lora_B, restored_lora_B), "lora_B not restored correctly"
Loading
Loading