[WIP][skyrl] Create new skyrl folder combining tx + train #1068
erictang000 merged 19 commits into NovaSky-AI:main
Conversation
update build_models in the SkyRLTrainBackend to only include the supported colocate all + policy model only logic for cleanliness. cc: @pcmoritz
This was from an earlier iteration and is not used anymore; the inference is now done through https://github.com/NovaSky-AI/SkyRL/blob/main/skyrl-tx/tx/tinker/backends/skyrl_train.py
(NovaSky-AI#1067) Change the training backends to the desired state of "jax", "fsdp", "megatron". Adds megatron deps to pyproject.toml via the skyrl-train[mcore] dependency, and adds the overrides needed to get installation working.

For FSDP + vLLM:

```bash
uv run --isolated --extra tinker -m tx.tinker.api --base-model "Qwen/Qwen3-0.6B" --backend fsdp
```

For Megatron + vLLM:

```bash
uv run --isolated --extra tinker -m tx.tinker.api --base-model "Qwen/Qwen3-0.6B" --backend megatron
```
Code Review
This pull request introduces a major refactoring by creating a new unified skyrl package, combining skyrl-tx and skyrl-train functionality, and includes a new FastAPI-based 'Tinker API Mock'. The new structure is well-organized with clear separation of concerns, featuring a tinker engine, tx model implementations, run scripts, and an abstract backend interface for enhanced modularity. However, a security audit identified several high and medium severity issues, primarily related to insecure handling of sensitive information in logs, potential Denial of Service via memory exhaustion during large file downloads, and broken access control due to guessable identifiers and lack of authentication on sensitive endpoints. Additionally, critical code issues include an accidentally committed log file and a non-reproducible git dependency in pyproject.toml. Further improvements are also needed in error handling, the use of internal APIs, and splitting the large api.py file into smaller modules for better maintainability.
skyrl/skyrl/run/main.py
Outdated
@@ -0,0 +1,17 @@
import typer
I suggest we just don't bring the whole run folder over (main.py and train.py). That is old code before there was the tinker API, and I think it is just confusing to have it in the skyrl folder :)
If you just remove it from this PR, I think everything should still work normally, but do let me know if we need to make some modification to make it work, happy to make that in the original folder to reduce conflicts / problems :)
Otherwise, I think this transition is the right place to get rid of those files.
Make sure to not commit your database files to the main repo or else everybody will have to pay the overhead going forward even if we remove them (there is kind of a painful way to purge them but let's not let it come to that).
Force-pushed from eaf9e92 to 8367759
skyrl/skyrl/backends/skyrl_train.py
Outdated
    padding_tensor = torch.zeros(pad_size, *additional_dims, dtype=tensor.dtype, device=tensor.device)
else:
    # Clone real data so shapes/dtypes are valid for the model
    padding_tensor = tensor[:pad_size].clone()
🔴 _pad_batch produces wrong padding size when pad_size > batch_size
When the batch size is smaller than dp_size - 1, the padding logic in _pad_batch produces incorrect results because tensor[:pad_size].clone() returns fewer elements than pad_size when pad_size > batch_size. This causes inconsistent tensor sizes across batch keys.
Root Cause and Impact
At skyrl/skyrl/backends/skyrl_train.py:316, padding for non-loss_mask keys is done via:
padding_tensor = tensor[:pad_size].clone()

When pad_size > tensor.shape[0] (i.e., pad_size > batch_size), PyTorch's slicing returns only batch_size elements, not pad_size. Meanwhile, loss_mask at line 313 correctly creates a torch.zeros(pad_size, ...) tensor of the right size.
For example, with batch_size=1 and dp_size=8:
- pad_size = ceil(1/8)*8 - 1 = 7
- loss_mask gets torch.zeros(7, ...) → final size = 1 + 7 = 8 ✓
- sequences gets tensor[:7].clone() → only 1 element → final size = 1 + 1 = 2 ✗
This produces a TrainingInputBatch with mismatched tensor sizes (e.g., loss_mask has 8 rows but sequences has 2), which will crash when the dispatch layer tries to split the batch across DP workers. Any batch with fewer examples than dp_size - 1 triggers this bug.
Suggested change:

- padding_tensor = tensor[:pad_size].clone()
+ repeat_count = (pad_size + tensor.shape[0] - 1) // tensor.shape[0]
+ padding_tensor = tensor.repeat((repeat_count,) + (1,) * (tensor.ndim - 1))[:pad_size].clone()
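A minimal standalone repro of the mismatch and of the suggested fix, as a sketch assuming plain torch tensors (batch_size=1, dp_size=8, as in the example above):

```python
import math
import torch

# Repro of the size mismatch described above (batch_size=1, dp_size=8).
batch_size, dp_size = 1, 8
pad_size = math.ceil(batch_size / dp_size) * dp_size - batch_size  # 7
sequences = torch.arange(batch_size * 4).reshape(batch_size, 4)

# Buggy: slicing past the end of a tensor silently returns only
# batch_size rows, not pad_size rows.
buggy_pad = sequences[:pad_size].clone()
assert buggy_pad.shape[0] == 1  # wanted 7, got 1

# Suggested fix: repeat the batch enough times, then slice to exactly pad_size rows.
repeat_count = (pad_size + sequences.shape[0] - 1) // sequences.shape[0]  # ceil-div = 7
fixed_pad = sequences.repeat((repeat_count,) + (1,) * (sequences.ndim - 1))[:pad_size].clone()
assert fixed_pad.shape[0] == pad_size  # 7
```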
Diff between skyrl-tx/ and skyrl/, (mostly) normalized for differences that are just the tx -> skyrl rename, using the command:

rm -rf /tmp/diff_tx /tmp/diff_skyrl && \
mkdir -p /tmp/diff_tx /tmp/diff_skyrl && \
rsync -a --exclude='__pycache__' --exclude='*.db*' --exclude='uv.lock' --exclude='.git' SkyRL/skyrl-tx/tx/ /tmp/diff_tx/ && \
rsync -a --exclude='__pycache__' --exclude='*.db*' --exclude='uv.lock' --exclude='.git' SkyRL/skyrl/skyrl/ /tmp/diff_skyrl/ && \
find /tmp/diff_tx -name '*.py' -exec sed -i 's/from tx\./from skyrl./g; s/import tx\./import skyrl./g; s/import tx$/import skyrl/g' {} + && \
diff -ru /tmp/diff_tx/ /tmp/diff_skyrl/ || true

diff -ru /tmp/diff_tx/ /tmp/diff_skyrl/ || true
Only in /tmp/diff_skyrl/: backends
Only in /tmp/diff_tx/: layers
Only in /tmp/diff_tx/: loaders
Only in /tmp/diff_tx/: models
Only in /tmp/diff_tx/: run
diff -ru /tmp/diff_tx/tinker/alembic/env.py /tmp/diff_skyrl/tinker/alembic/env.py
--- /tmp/diff_tx/tinker/alembic/env.py 2026-02-11 20:14:18.783597956 +0000
+++ /tmp/diff_skyrl/tinker/alembic/env.py 2026-02-11 20:08:43.698536673 +0000
@@ -7,7 +7,7 @@
from alembic import context
-# Add parent directory to path so we can import tx modules
+# Add parent directory to path so we can import skyrl modules
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent))
# Import SQLModel and database models
@@ -67,7 +67,7 @@
from sqlalchemy import create_engine
# Get database URL - ignore whatever is in config, use our helper
- db_url = os.environ["TX_DATABASE_URL"]
+ db_url = os.environ["SKYRL_DATABASE_URL"]
connectable = create_engine(db_url, poolclass=pool.NullPool)
with connectable.connect() as connection:
diff -ru /tmp/diff_tx/tinker/api.py /tmp/diff_skyrl/tinker/api.py
--- /tmp/diff_tx/tinker/api.py 2026-02-11 20:14:18.783597956 +0000
+++ /tmp/diff_skyrl/tinker/api.py 2026-02-11 20:08:43.698536673 +0000
@@ -71,7 +71,7 @@
"--extra",
app.state.engine_config.backend,
"-m",
- "tx.tinker.engine",
+ "skyrl.tinker.engine",
]
cmd.extend(config_to_argv(app.state.engine_config))
@@ -1177,7 +1177,7 @@
import uvicorn
# Parse command-line arguments
- parser = argparse.ArgumentParser(description="SkyRL tx tinker API server")
+ parser = argparse.ArgumentParser(description="SkyRL tinker API server")
add_model(parser, EngineConfig)
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
Only in /tmp/diff_tx/tinker: backends
diff -ru /tmp/diff_tx/tinker/config.py /tmp/diff_skyrl/tinker/config.py
--- /tmp/diff_tx/tinker/config.py 2026-02-11 20:14:18.783597956 +0000
+++ /tmp/diff_skyrl/tinker/config.py 2026-02-11 20:08:43.698536673 +0000
@@ -20,13 +20,13 @@
json_schema_extra={"argparse_type": json.loads},
)
checkpoints_base: AnyPath = Field(
- default=AnyPath("/tmp/tx_checkpoints"),
+ default=AnyPath("/tmp/skyrl_checkpoints"),
description="Base path where checkpoints will be stored",
)
database_url: str = Field(
default=f'sqlite:///{Path(__file__).parent / "tinker.db"}',
- description="Database URL (e.g., postgresql://user:password@localhost:5432/tinker). If not set, uses TX_DATABASE_URL env var or defaults to SQLite",
- json_schema_extra={"argparse_type": str, "env_var": "TX_DATABASE_URL"},
+ description="Database URL (e.g., postgresql://user:password@localhost:5432/tinker). If not set, uses SKYRL_DATABASE_URL env var or defaults to SQLite",
+ json_schema_extra={"argparse_type": str, "env_var": "SKYRL_DATABASE_URL"},
)
external_inference_url: str | None = Field(
default=None,
diff -ru /tmp/diff_tx/tinker/engine.py /tmp/diff_skyrl/tinker/engine.py
--- /tmp/diff_tx/tinker/engine.py 2026-02-11 20:14:18.783597956 +0000
+++ /tmp/diff_skyrl/tinker/engine.py 2026-02-11 20:12:43.404465987 +0000
@@ -22,7 +22,7 @@
)
from skyrl.tinker import types
from skyrl.tinker.config import EngineConfig, add_model
-from skyrl.tinker.backends.utils import log_timing
+from skyrl.backends.utils import log_timing
from skyrl.utils.log import logger
@@ -160,21 +160,21 @@
def get_backend_classes(backend_name: str):
"""Lazy import backends to avoid importing unused backend dependencies (e.g., JAX, Ray)."""
if backend_name == "jax":
- from skyrl.tinker.backends.jax import JaxBackend, JaxBackendConfig
+ from skyrl.backends.jax import JaxBackend, JaxBackendConfig
return JaxBackend, JaxBackendConfig
elif backend_name == "fsdp":
- from skyrl.tinker.backends.skyrl_train import SkyRLTrainBackend, FSDPBackendConfig
+ from skyrl.backends.skyrl_train import SkyRLTrainBackend, FSDPBackendConfig
return SkyRLTrainBackend, FSDPBackendConfig
elif backend_name == "megatron":
- from skyrl.tinker.backends.skyrl_train import SkyRLTrainBackend, MegatronBackendConfig
+ from skyrl.backends.skyrl_train import SkyRLTrainBackend, MegatronBackendConfig
return SkyRLTrainBackend, MegatronBackendConfig
else:
raise ValueError(
f"Unknown backend: {backend_name}. Available backends: jax, fsdp, megatron. "
- f"Make sure the backend's dependencies are installed (e.g., pip install skyrl-tx[jax])"
+ f"Make sure the backend's dependencies are installed (e.g., pip install skyrl[jax])"
)
@@ -463,9 +463,7 @@
return unloaded_count
- def process_optim_step(
- self, model_id: str, request_data: types.OptimStepInput
- ) -> types.OptimStepOutput | types.ErrorResponse:
+ def process_optim_step(self, model_id: str, request_data: types.OptimStepInput) -> types.OptimStepOutput | types.ErrorResponse:
"""Process an optim_step request and apply accumulated gradients."""
if not self.backend.has_model(model_id):
return _model_not_found_error(model_id)
@@ -487,9 +485,7 @@
prepared = prepare_sample_batch(requests, self.config.checkpoints_base)
return self.backend.sample(prepared)
- def process_load_weights(
- self, model_id: str, request_data: types.LoadWeightsInput
- ) -> types.LoadWeightsOutput | types.ErrorResponse:
+ def process_load_weights(self, model_id: str, request_data: types.LoadWeightsInput) -> types.LoadWeightsOutput | types.ErrorResponse:
"""Loads a clean, trimmed training checkpoint."""
if not self.backend.has_model(model_id):
return _model_not_found_error(model_id)
@@ -502,9 +498,7 @@
return types.LoadWeightsOutput(type="load_weights")
- def process_save_weights(
- self, model_id: str, request_data: types.SaveWeightsInput
- ) -> types.SaveWeightsOutput | types.ErrorResponse:
+ def process_save_weights(self, model_id: str, request_data: types.SaveWeightsInput) -> types.SaveWeightsOutput | types.ErrorResponse:
"""
Saves a clean training checkpoint by converting the trimmed NNX graph
to a pure dictionary before serialization, following official Flax docs.
@@ -684,7 +678,7 @@
def main():
"""Entry point for the background engine."""
# Create argument parser and add Pydantic model fields
- parser = argparse.ArgumentParser(description="SkyRL tx tinker engine for processing requests")
+ parser = argparse.ArgumentParser(description="SkyRL tinker engine for processing requests")
add_model(parser, EngineConfig)
# Parse command-line arguments
Only in /tmp/diff_skyrl/: tx
diff -ru /tmp/diff_tx/utils/generator.py /tmp/diff_skyrl/utils/generator.py
--- /tmp/diff_tx/utils/generator.py 2026-02-11 20:14:18.779597657 +0000
+++ /tmp/diff_skyrl/utils/generator.py 2026-02-11 20:08:43.702536973 +0000
@@ -303,7 +303,7 @@
batch_size, prompt_length = input_ids.shape
assert len(sampling_params) == batch_size
max_new_tokens = max(sampling_param.max_tokens for sampling_param in sampling_params)
- max_length = tx.utils.models.round_up_seq_len(prompt_length + max_new_tokens)
+ max_length = skyrl.utils.models.round_up_seq_len(prompt_length + max_new_tokens)
temperatures = jnp.array([sampling_param.temperature for sampling_param in sampling_params])
top_k_values = jnp.array([sampling_param.top_k for sampling_param in sampling_params], dtype=jnp.int32)
top_p_values = jnp.array([sampling_param.top_p for sampling_param in sampling_params], dtype=jnp.float32)
diff -ru /tmp/diff_tx/utils/log.py /tmp/diff_skyrl/utils/log.py
--- /tmp/diff_tx/utils/log.py 2026-02-11 20:14:18.779597657 +0000
+++ /tmp/diff_skyrl/utils/log.py 2026-02-11 20:08:43.702536973 +0000
@@ -31,7 +31,7 @@
def _setup_root_logger() -> None:
- logger = logging.getLogger("tx")
+ logger = logging.getLogger("skyrl")
logger.setLevel(logging.DEBUG)
logger.propagate = False # Prevent propagation to root logger
logger.addHandler(_create_rich_handler())
@@ -66,7 +66,7 @@
def add_file_handler(path: Path | str, level: int = logging.DEBUG, *, print_path: bool = True) -> None:
- logger = logging.getLogger("tx")
+ logger = logging.getLogger("skyrl")
handler = logging.FileHandler(path)
handler.setLevel(level)
formatter = logging.Formatter(LOG_FORMAT)
@@ -77,7 +77,7 @@
_setup_root_logger()
-logger = logging.getLogger("tx")
+logger = logging.getLogger("skyrl")
class ExperimentTracker(str, Enum):
diff -ru /tmp/diff_tx/utils/logits_processor.py /tmp/diff_skyrl/utils/logits_processor.py
--- /tmp/diff_tx/utils/logits_processor.py 2026-02-11 20:14:18.779597657 +0000
+++ /tmp/diff_skyrl/utils/logits_processor.py 2026-02-11 20:08:43.702536973 +0000
@@ -5,7 +5,7 @@
import jax
import jax.numpy as jnp
-from skyrl.models.configs import ModelConfig
+from skyrl.tx.models.configs import ModelConfig
# lm_head: (hidden_states, adapter_indices) -> logits
diff -ru /tmp/diff_tx/utils/models.py /tmp/diff_skyrl/utils/models.py
--- /tmp/diff_tx/utils/models.py 2026-02-11 20:14:18.779597657 +0000
+++ /tmp/diff_skyrl/utils/models.py 2026-02-11 20:08:43.702536973 +0000
@@ -16,7 +16,7 @@
from transformers import PretrainedConfig
import peft
-from skyrl.models.configs import ModelConfig
+from skyrl.tx.models.configs import ModelConfig
from skyrl.utils.log import logger
from skyrl.utils.storage import download_and_unpack, pack_and_upload
from skyrl.tinker.types import LoraConfig
@@ -61,17 +61,17 @@
def get_model_class(config: PretrainedConfig) -> Callable[..., nnx.Module]:
"Get the correct model class based on the config."
- import skyrl.models.llama3
- import skyrl.models.qwen3
- import skyrl.models.deepseekv3
+ import skyrl.tx.models.llama3
+ import skyrl.tx.models.qwen3
+ import skyrl.tx.models.deepseekv3
for architecture in config.architectures or []:
- if hasattr(tx.models.llama3, architecture):
- return getattr(tx.models.llama3, architecture)
- if hasattr(tx.models.qwen3, architecture):
- return getattr(tx.models.qwen3, architecture)
- if hasattr(tx.models.deepseekv3, architecture):
- return getattr(tx.models.deepseekv3, architecture)
+ if hasattr(skyrl.tx.models.llama3, architecture):
+ return getattr(skyrl.tx.models.llama3, architecture)
+ if hasattr(skyrl.tx.models.qwen3, architecture):
+ return getattr(skyrl.tx.models.qwen3, architecture)
+ if hasattr(skyrl.tx.models.deepseekv3, architecture):
+ return getattr(skyrl.tx.models.deepseekv3, architecture)
raise ValueError(f"None of the architectures {config.architectures} is currently supported.") |
Same as above, but a comparison for the backends folders. Command for the diff:

rm -rf /tmp/diff_tx_backends /tmp/diff_skyrl_backends && \
mkdir -p /tmp/diff_tx_backends /tmp/diff_skyrl_backends && \
rsync -a --exclude='__pycache__' SkyRL/skyrl-tx/tx/tinker/backends/ /tmp/diff_tx_backends/ && \
rsync -a --exclude='__pycache__' SkyRL/skyrl/skyrl/backends/ /tmp/diff_skyrl_backends/ && \
find /tmp/diff_tx_backends -name '*.py' -exec sed -i 's/from tx\./from skyrl./g; s/import tx\./import skyrl./g; s/import tx$/import skyrl/g; s/from skyrl\.tinker\.backends\./from skyrl.backends./g; s/skyrl\.tinker\.backends\./skyrl.backends./g' {} + && \
diff -ru /tmp/diff_tx_backends/ /tmp/diff_skyrl_backends/ || true

diff -ru /tmp/diff_tx_backends/ /tmp/diff_skyrl_backends/ || true
diff -ru /tmp/diff_tx_backends/jax.py /tmp/diff_skyrl_backends/jax.py
--- /tmp/diff_tx_backends/jax.py 2026-02-11 20:18:11.240984687 +0000
+++ /tmp/diff_skyrl_backends/jax.py 2026-02-11 20:08:43.698536673 +0000
@@ -6,18 +6,18 @@
In multi-host mode, process 0 (coordinator) runs the engine with JaxBackend,
which broadcasts commands to workers. Workers run separately using `run_worker()`
-or by running this module directly with `python -m tx.tinker.backends.jax`.
+or by running this module directly with `python -m skyrl.backends.jax`.
Usage:
# Coordinator (process 0) - runs the full engine:
- uv run -m tx.tinker.engine --base-model Qwen/Qwen3-8B --backend-config '{
+ uv run -m skyrl.tinker.engine --base-model Qwen/Qwen3-8B --backend-config '{
"coordinator_address": "localhost:7777",
"num_processes": 2,
...
}'
# Workers (process 1+) - run only the worker loop (receives config from coordinator):
- uv run -m tx.tinker.backends.jax --coordinator-address localhost:7777 --num-processes 2 --process-id 1
+ uv run -m skyrl.backends.jax --coordinator-address localhost:7777 --num-processes 2 --process-id 1
"""
import time
@@ -36,8 +36,8 @@
from pydantic import BaseModel, Field, TypeAdapter
from transformers import AutoTokenizer, PretrainedConfig
-from skyrl.models.configs import Qwen3Config
-from skyrl.layers.lora import clear_lora_adapter, init_lora_adapter
+from skyrl.tx.models.configs import Qwen3Config
+from skyrl.tx.layers.lora import clear_lora_adapter, init_lora_adapter
from skyrl.tinker import types
from skyrl.backends.backend import AbstractBackend
from skyrl.backends.utils import pad, pad_batch, pad_to_fsdp
@@ -474,8 +474,8 @@
# Create optimizer
with jax.set_mesh(self.mesh):
- tx = optax.inject_hyperparams(optax.adamw)(learning_rate=0.0)
- self.optimizers[model_id] = nnx.Optimizer(self.model, tx, wrt=self.model.is_lora_param)
+ optimizer = optax.inject_hyperparams(optax.adamw)(learning_rate=0.0)
+ self.optimizers[model_id] = nnx.Optimizer(self.model, optimizer, wrt=self.model.is_lora_param)
# Configure adapter
init_lora_adapter(self.model, adapter_index, lora_config)
@@ -1069,7 +1069,7 @@
"""Entry point for running as a worker process."""
import argparse
- parser = argparse.ArgumentParser(description="SkyRL tx tinker worker process")
+ parser = argparse.ArgumentParser(description="SkyRL tinker worker process")
parser.add_argument(
"--coordinator-address",
required=True,
diff -ru /tmp/diff_tx_backends/skyrl_train.py /tmp/diff_skyrl_backends/skyrl_train.py
--- /tmp/diff_tx_backends/skyrl_train.py 2026-02-11 20:18:11.240984687 +0000
+++ /tmp/diff_skyrl_backends/skyrl_train.py 2026-02-11 20:08:43.698536673 +0000
@@ -5,6 +5,7 @@
"""
import asyncio
+import math
import os
import tarfile
import tempfile
@@ -287,6 +288,39 @@
batch.metadata = {"response_length": max_response_len}
return batch
+ def _pad_batch(self, batch: TrainingInputBatch) -> tuple[TrainingInputBatch, int]:
+ """Pad the batch so its size is divisible by dp_size.
+
+ The dispatch layer splits the batch evenly across DP workers, so the
+ batch size must be a multiple of dp_size. We pad by cloning the first
+ N entries (with loss_mask zeroed) and record the pad count so callers
+ can trim the results.
+
+ Returns:
+ (padded_batch, pad_size)
+ """
+ dp_size = self._dispatch.get_lcm_dp_size()
+ pad_size = math.ceil(batch.batch_size / dp_size) * dp_size - batch.batch_size
+ if pad_size == 0:
+ return batch, 0
+
+ new_tensors = {}
+ for key, tensor in batch.items():
+ if tensor is not None:
+ if key == "loss_mask":
+ # Padding entries must not contribute to the loss
+ additional_dims = tuple(tensor.shape[1:]) if len(tensor.shape) > 1 else ()
+ padding_tensor = torch.zeros(pad_size, *additional_dims, dtype=tensor.dtype, device=tensor.device)
+ else:
+ # Clone real data so shapes/dtypes are valid for the model
+ padding_tensor = tensor[:pad_size].clone()
+ new_tensors[key] = torch.cat([tensor, padding_tensor], dim=0)
+
+ padded = TrainingInputBatch(new_tensors)
+ padded.metadata = batch.metadata
+ logger.info(f"Padded batch from {batch.batch_size} to {batch.batch_size + pad_size} (dp_size={dp_size})")
+ return padded, pad_size
+
def _extract_metrics(self, data: dict) -> dict[str, float]:
"""Extract training metrics from dispatch return dict.
@@ -318,6 +352,8 @@
return {}
batch = self._to_training_batch(prepared_batch)
+ batch, pad_size = self._pad_batch(batch)
+
loss_fn = prepared_batch.all_loss_fns[0]
if len(set(prepared_batch.all_loss_fns)) > 1:
logger.warning(
@@ -333,6 +369,10 @@
loss_fn_config=loss_fn_config,
)
+ # Trim padding entries from loss_fn_outputs
+ if pad_size > 0 and "loss_fn_outputs" in data:
+ data["loss_fn_outputs"] = data["loss_fn_outputs"][:-pad_size]
+
metrics = self._extract_metrics(data)
results = {}
@@ -371,10 +411,13 @@
return {}
batch = self._to_training_batch(prepared_batch)
+ original_batch_size = batch.batch_size
+ batch, pad_size = self._pad_batch(batch)
data = self._dispatch.forward("policy", batch)
# dispatch.forward() returns TrainingOutputBatch({"output": tensor[batch, max_response_len]})
- output_logprobs = data["output"]
+ # Trim padding entries from output
+ output_logprobs = data["output"][:original_batch_size]
results = {}
for request_id, _, start_idx, end_idx in prepared_batch.request_batch_slices:
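For illustration, a self-contained sketch of the pad-then-trim flow the diff above implements; the dict-of-tensors pad_batch helper here is a hypothetical stand-in for the actual _pad_batch/TrainingInputBatch API:

```python
import math
import torch

def pad_batch(tensors: dict[str, torch.Tensor], dp_size: int) -> tuple[dict[str, torch.Tensor], int]:
    """Pad every tensor so the batch size is a multiple of dp_size.

    loss_mask is zero-padded so padding rows never contribute to the loss;
    other keys are padded by repeating real rows so shapes/dtypes stay valid.
    """
    batch_size = next(iter(tensors.values())).shape[0]
    pad_size = math.ceil(batch_size / dp_size) * dp_size - batch_size
    if pad_size == 0:
        return tensors, 0
    padded = {}
    for key, t in tensors.items():
        if key == "loss_mask":
            pad = torch.zeros(pad_size, *t.shape[1:], dtype=t.dtype, device=t.device)
        else:
            repeats = (pad_size + t.shape[0] - 1) // t.shape[0]
            pad = t.repeat((repeats,) + (1,) * (t.ndim - 1))[:pad_size]
        padded[key] = torch.cat([t, pad], dim=0)
    return padded, pad_size

# Usage: pad before dispatch so the batch splits evenly across DP workers,
# then trim the padded rows from whatever the dispatch layer returns.
batch = {"sequences": torch.ones(3, 8, dtype=torch.long), "loss_mask": torch.ones(3, 8)}
padded, pad_size = pad_batch(batch, dp_size=4)  # pads 3 -> 4 rows
outputs = padded["sequences"]                   # stand-in for dispatch output
trimmed = outputs[:-pad_size] if pad_size else outputs
assert trimmed.shape[0] == 3
```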
Add `_pad_batch` logic, plus a minor change in the jax backend to rename a var from tx -> optimizer.

Co-authored-by: devin-ai-integration[bot] <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
updated diff:

diff -ru /tmp/diff_tx/ /tmp/diff_skyrl/ || true
Only in /tmp/diff_skyrl/: backends
Only in /tmp/diff_tx/: layers
Only in /tmp/diff_tx/: loaders
Only in /tmp/diff_tx/: models
Only in /tmp/diff_tx/: run
diff -ru /tmp/diff_tx/tinker/alembic/env.py /tmp/diff_skyrl/tinker/alembic/env.py
--- /tmp/diff_tx/tinker/alembic/env.py 2026-02-11 21:57:02.780551919 +0000
+++ /tmp/diff_skyrl/tinker/alembic/env.py 2026-02-11 21:37:38.409502381 +0000
@@ -7,7 +7,7 @@
from alembic import context
-# Add parent directory to path so we can import tx modules
+# Add parent directory to path so we can import skyrl modules
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent))
# Import SQLModel and database models
@@ -67,7 +67,7 @@
from sqlalchemy import create_engine
# Get database URL - ignore whatever is in config, use our helper
- db_url = os.environ["TX_DATABASE_URL"]
+ db_url = os.environ["SKYRL_DATABASE_URL"]
connectable = create_engine(db_url, poolclass=pool.NullPool)
with connectable.connect() as connection:
diff -ru /tmp/diff_tx/tinker/api.py /tmp/diff_skyrl/tinker/api.py
--- /tmp/diff_tx/tinker/api.py 2026-02-11 21:57:02.776551620 +0000
+++ /tmp/diff_skyrl/tinker/api.py 2026-02-11 21:37:38.409502381 +0000
@@ -71,7 +71,7 @@
"--extra",
app.state.engine_config.backend,
"-m",
- "tx.tinker.engine",
+ "skyrl.tinker.engine",
]
cmd.extend(config_to_argv(app.state.engine_config))
@@ -1177,7 +1177,7 @@
import uvicorn
# Parse command-line arguments
- parser = argparse.ArgumentParser(description="SkyRL tx tinker API server")
+ parser = argparse.ArgumentParser(description="SkyRL tinker API server")
add_model(parser, EngineConfig)
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
Only in /tmp/diff_tx/tinker: backends
diff -ru /tmp/diff_tx/tinker/config.py /tmp/diff_skyrl/tinker/config.py
--- /tmp/diff_tx/tinker/config.py 2026-02-11 21:57:02.776551620 +0000
+++ /tmp/diff_skyrl/tinker/config.py 2026-02-11 21:37:38.409502381 +0000
@@ -20,13 +20,13 @@
json_schema_extra={"argparse_type": json.loads},
)
checkpoints_base: AnyPath = Field(
- default=AnyPath("/tmp/tx_checkpoints"),
+ default=AnyPath("/tmp/skyrl_checkpoints"),
description="Base path where checkpoints will be stored",
)
database_url: str = Field(
default=f'sqlite:///{Path(__file__).parent / "tinker.db"}',
- description="Database URL (e.g., postgresql://user:password@localhost:5432/tinker). If not set, uses TX_DATABASE_URL env var or defaults to SQLite",
- json_schema_extra={"argparse_type": str, "env_var": "TX_DATABASE_URL"},
+ description="Database URL (e.g., postgresql://user:password@localhost:5432/tinker). If not set, uses SKYRL_DATABASE_URL env var or defaults to SQLite",
+ json_schema_extra={"argparse_type": str, "env_var": "SKYRL_DATABASE_URL"},
)
external_inference_url: str | None = Field(
default=None,
diff -ru /tmp/diff_tx/tinker/engine.py /tmp/diff_skyrl/tinker/engine.py
--- /tmp/diff_tx/tinker/engine.py 2026-02-11 21:57:02.776551620 +0000
+++ /tmp/diff_skyrl/tinker/engine.py 2026-02-11 21:37:44.589964976 +0000
@@ -22,7 +22,7 @@
)
from skyrl.tinker import types
from skyrl.tinker.config import EngineConfig, add_model
-from skyrl.tinker.backends.utils import log_timing
+from skyrl.backends.utils import log_timing
from skyrl.utils.log import logger
@@ -160,21 +160,21 @@
def get_backend_classes(backend_name: str):
"""Lazy import backends to avoid importing unused backend dependencies (e.g., JAX, Ray)."""
if backend_name == "jax":
- from skyrl.tinker.backends.jax import JaxBackend, JaxBackendConfig
+ from skyrl.backends.jax import JaxBackend, JaxBackendConfig
return JaxBackend, JaxBackendConfig
elif backend_name == "fsdp":
- from skyrl.tinker.backends.skyrl_train import SkyRLTrainBackend, FSDPBackendConfig
+ from skyrl.backends.skyrl_train import SkyRLTrainBackend, FSDPBackendConfig
return SkyRLTrainBackend, FSDPBackendConfig
elif backend_name == "megatron":
- from skyrl.tinker.backends.skyrl_train import SkyRLTrainBackend, MegatronBackendConfig
+ from skyrl.backends.skyrl_train import SkyRLTrainBackend, MegatronBackendConfig
return SkyRLTrainBackend, MegatronBackendConfig
else:
raise ValueError(
f"Unknown backend: {backend_name}. Available backends: jax, fsdp, megatron. "
- f"Make sure the backend's dependencies are installed (e.g., pip install skyrl-tx[jax])"
+ f"Make sure the backend's dependencies are installed (e.g., pip install skyrl[jax])"
)
@@ -684,7 +684,7 @@
def main():
"""Entry point for the background engine."""
# Create argument parser and add Pydantic model fields
- parser = argparse.ArgumentParser(description="SkyRL tx tinker engine for processing requests")
+ parser = argparse.ArgumentParser(description="SkyRL tinker engine for processing requests")
add_model(parser, EngineConfig)
# Parse command-line arguments
Only in /tmp/diff_skyrl/: tx
diff -ru /tmp/diff_tx/utils/generator.py /tmp/diff_skyrl/utils/generator.py
--- /tmp/diff_tx/utils/generator.py 2026-02-11 21:57:02.772551321 +0000
+++ /tmp/diff_skyrl/utils/generator.py 2026-02-11 21:37:38.413502681 +0000
@@ -303,7 +303,7 @@
batch_size, prompt_length = input_ids.shape
assert len(sampling_params) == batch_size
max_new_tokens = max(sampling_param.max_tokens for sampling_param in sampling_params)
- max_length = tx.utils.models.round_up_seq_len(prompt_length + max_new_tokens)
+ max_length = skyrl.utils.models.round_up_seq_len(prompt_length + max_new_tokens)
temperatures = jnp.array([sampling_param.temperature for sampling_param in sampling_params])
top_k_values = jnp.array([sampling_param.top_k for sampling_param in sampling_params], dtype=jnp.int32)
top_p_values = jnp.array([sampling_param.top_p for sampling_param in sampling_params], dtype=jnp.float32)
diff -ru /tmp/diff_tx/utils/log.py /tmp/diff_skyrl/utils/log.py
--- /tmp/diff_tx/utils/log.py 2026-02-11 21:57:02.772551321 +0000
+++ /tmp/diff_skyrl/utils/log.py 2026-02-11 21:37:38.413502681 +0000
@@ -31,7 +31,7 @@
def _setup_root_logger() -> None:
- logger = logging.getLogger("tx")
+ logger = logging.getLogger("skyrl")
logger.setLevel(logging.DEBUG)
logger.propagate = False # Prevent propagation to root logger
logger.addHandler(_create_rich_handler())
@@ -66,7 +66,7 @@
def add_file_handler(path: Path | str, level: int = logging.DEBUG, *, print_path: bool = True) -> None:
- logger = logging.getLogger("tx")
+ logger = logging.getLogger("skyrl")
handler = logging.FileHandler(path)
handler.setLevel(level)
formatter = logging.Formatter(LOG_FORMAT)
@@ -77,7 +77,7 @@
_setup_root_logger()
-logger = logging.getLogger("tx")
+logger = logging.getLogger("skyrl")
class ExperimentTracker(str, Enum):
diff -ru /tmp/diff_tx/utils/logits_processor.py /tmp/diff_skyrl/utils/logits_processor.py
--- /tmp/diff_tx/utils/logits_processor.py 2026-02-11 21:57:02.772551321 +0000
+++ /tmp/diff_skyrl/utils/logits_processor.py 2026-02-11 21:37:38.413502681 +0000
@@ -5,7 +5,7 @@
import jax
import jax.numpy as jnp
-from skyrl.models.configs import ModelConfig
+from skyrl.tx.models.configs import ModelConfig
# lm_head: (hidden_states, adapter_indices) -> logits
diff -ru /tmp/diff_tx/utils/models.py /tmp/diff_skyrl/utils/models.py
--- /tmp/diff_tx/utils/models.py 2026-02-11 21:57:02.772551321 +0000
+++ /tmp/diff_skyrl/utils/models.py 2026-02-11 21:37:38.413502681 +0000
@@ -16,7 +16,7 @@
from transformers import PretrainedConfig
import peft
-from skyrl.models.configs import ModelConfig
+from skyrl.tx.models.configs import ModelConfig
from skyrl.utils.log import logger
from skyrl.utils.storage import download_and_unpack, pack_and_upload
from skyrl.tinker.types import LoraConfig
@@ -61,17 +61,17 @@
def get_model_class(config: PretrainedConfig) -> Callable[..., nnx.Module]:
"Get the correct model class based on the config."
- import skyrl.models.llama3
- import skyrl.models.qwen3
- import skyrl.models.deepseekv3
+ import skyrl.tx.models.llama3
+ import skyrl.tx.models.qwen3
+ import skyrl.tx.models.deepseekv3
for architecture in config.architectures or []:
- if hasattr(tx.models.llama3, architecture):
- return getattr(tx.models.llama3, architecture)
- if hasattr(tx.models.qwen3, architecture):
- return getattr(tx.models.qwen3, architecture)
- if hasattr(tx.models.deepseekv3, architecture):
- return getattr(tx.models.deepseekv3, architecture)
+ if hasattr(skyrl.tx.models.llama3, architecture):
+ return getattr(skyrl.tx.models.llama3, architecture)
+ if hasattr(skyrl.tx.models.qwen3, architecture):
+ return getattr(skyrl.tx.models.qwen3, architecture)
+ if hasattr(skyrl.tx.models.deepseekv3, architecture):
+ return getattr(skyrl.tx.models.deepseekv3, architecture)
raise ValueError(f"None of the architectures {config.architectures} is currently supported.") |
updated backends diff:

diff -ru /tmp/diff_tx_backends/jax.py /tmp/diff_skyrl_backends/jax.py
--- /tmp/diff_tx_backends/jax.py 2026-02-11 21:57:34.214902589 +0000
+++ /tmp/diff_skyrl_backends/jax.py 2026-02-11 21:37:38.409502381 +0000
@@ -6,18 +6,18 @@
In multi-host mode, process 0 (coordinator) runs the engine with JaxBackend,
which broadcasts commands to workers. Workers run separately using `run_worker()`
-or by running this module directly with `python -m tx.tinker.backends.jax`.
+or by running this module directly with `python -m skyrl.backends.jax`.
Usage:
# Coordinator (process 0) - runs the full engine:
- uv run -m tx.tinker.engine --base-model Qwen/Qwen3-8B --backend-config '{
+ uv run -m skyrl.tinker.engine --base-model Qwen/Qwen3-8B --backend-config '{
"coordinator_address": "localhost:7777",
"num_processes": 2,
...
}'
# Workers (process 1+) - run only the worker loop (receives config from coordinator):
- uv run -m tx.tinker.backends.jax --coordinator-address localhost:7777 --num-processes 2 --process-id 1
+ uv run -m skyrl.backends.jax --coordinator-address localhost:7777 --num-processes 2 --process-id 1
"""
import time
@@ -36,8 +36,8 @@
from pydantic import BaseModel, Field, TypeAdapter
from transformers import AutoTokenizer, PretrainedConfig
-from skyrl.models.configs import Qwen3Config
-from skyrl.layers.lora import clear_lora_adapter, init_lora_adapter
+from skyrl.tx.models.configs import Qwen3Config
+from skyrl.tx.layers.lora import clear_lora_adapter, init_lora_adapter
from skyrl.tinker import types
from skyrl.backends.backend import AbstractBackend
from skyrl.backends.utils import pad, pad_batch, pad_to_fsdp
@@ -1069,7 +1069,7 @@
"""Entry point for running as a worker process."""
import argparse
- parser = argparse.ArgumentParser(description="SkyRL tx tinker worker process")
+ parser = argparse.ArgumentParser(description="SkyRL tinker worker process")
parser.add_argument(
"--coordinator-address",
    required=True,
This gets the CI working after moving the files in #1068.
This is a temporary state - the goal here is for `skyrl-train` and `skyrl-tx` to be absorbed into `skyrl` once we validate that everything is working.