[WIP][skyrl] Create new skyrl folder combining tx + train #1068

Merged: erictang000 merged 19 commits into NovaSky-AI:main from erictang000:skyrl_merge1 on Feb 11, 2026

Conversation

@erictang000 (Collaborator) commented Feb 11, 2026

This is a temporary state - the goal here is for skyrl-train and skyrl-tx to be absorbed into skyrl once we validate that everything is working.



erictang000 and others added 6 commits February 10, 2026 22:59
Update build_models in the SkyRLTrainBackend to only include the supported colocate-all + policy-model-only logic, for cleanliness.

cc: @pcmoritz 
This was from an earlier iteration and is no longer used; inference is now done through
https://github.com/NovaSky-AI/SkyRL/blob/main/skyrl-tx/tx/tinker/backends/skyrl_train.py
(NovaSky-AI#1067)

Change the training backends to the desired state of "jax", "fsdp",
"megatron".

Adds megatron deps to pyproject.toml via the skyrl-train[mcore] dependency, and adds the needed overrides to get installation working.

For FSDP + vLLM:

```bash
uv run --isolated --extra tinker -m tx.tinker.api --base-model "Qwen/Qwen3-0.6B" --backend fsdp
```

For Megatron + vLLM:

```bash
uv run --isolated --extra tinker -m tx.tinker.api --base-model "Qwen/Qwen3-0.6B" --backend megatron
```


@devin-ai-integration bot (Contributor) left a comment:

Devin Review found 2 potential issues. View 9 additional findings in Devin Review.

@gemini-code-assist bot (Contributor) left a comment:

Code Review

This pull request introduces a major refactoring by creating a new unified skyrl package, combining skyrl-tx and skyrl-train functionality, and includes a new FastAPI-based "Tinker API Mock". The new structure is well organized with clear separation of concerns, featuring a tinker engine, tx model implementations, run scripts, and an abstract backend interface for enhanced modularity.

However, a security audit identified several high- and medium-severity issues, primarily related to insecure handling of sensitive information in logs, potential denial of service via memory exhaustion during large file downloads, and broken access control due to guessable identifiers and a lack of authentication on sensitive endpoints.

Additionally, critical code issues include an accidentally committed log file and a non-reproducible git dependency in pyproject.toml. Further improvements are also needed in error handling, the use of internal APIs, and splitting the large api.py file into smaller modules for better maintainability.
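
One of the flagged issues, memory exhaustion during large file downloads, is typically mitigated by streaming the response to disk in bounded chunks. A minimal sketch of that mitigation, assuming httpx is available (hypothetical helper, not the actual api.py code):

```python
import httpx

def download_to_path(url: str, path: str, chunk_size: int = 1 << 20) -> None:
    """Stream a download to disk so memory use stays bounded by chunk_size."""
    with httpx.stream("GET", url) as resp:
        resp.raise_for_status()
        with open(path, "wb") as f:
            # Each iteration holds at most ~chunk_size bytes in memory,
            # regardless of how large the remote file is.
            for chunk in resp.iter_bytes(chunk_size):
                f.write(chunk)
```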

@erictang000 changed the title from "[skyrl] Create new skyrl folder combining tx + train" to "[WIP][skyrl] Create new skyrl folder combining tx + train" on Feb 11, 2026
@devin-ai-integration bot (Contributor) left a comment:

Devin Review found 2 new potential issues. View 16 additional findings in Devin Review.

```diff
@@ -0,0 +1,17 @@
import typer
```
A Collaborator commented on this hunk:

I suggest we just don't bring the whole run folder over (main.py and train.py). That is old code from before there was the tinker API, and I think it is just confusing to have it in the skyrl folder :)

If you just remove it from this PR, I think everything should still work normally, but do let me know if we need to make some modification to make it work; happy to make that in the original folder to reduce conflicts / problems :)

Otherwise, I think this transition is the right place to get rid of those files.

@pcmoritz (Collaborator) commented:

Make sure not to commit your database files to the main repo, or else everybody will have to pay the overhead going forward even if we remove them (there is a kind of painful way to purge them, but let's not let it come to that).

@devin-ai-integration bot (Contributor) left a comment:

Devin Review found 2 new potential issues. View 23 additional findings in Devin Review.

@devin-ai-integration bot (Contributor) left a comment:

Devin Review found 1 new potential issue. View 29 additional findings in Devin Review.

```python
    padding_tensor = torch.zeros(pad_size, *additional_dims, dtype=tensor.dtype, device=tensor.device)
else:
    # Clone real data so shapes/dtypes are valid for the model
    padding_tensor = tensor[:pad_size].clone()
```
@devin-ai-integration bot (Contributor) commented:
🔴 _pad_batch produces wrong padding size when pad_size > batch_size

When the batch size is smaller than pad_size (which can happen for small batches, e.g. batch_size=1 with dp_size=8), the padding logic in _pad_batch produces incorrect results because tensor[:pad_size].clone() returns fewer elements than pad_size when pad_size > batch_size. This causes inconsistent tensor sizes across batch keys.

Root Cause and Impact

At skyrl/skyrl/backends/skyrl_train.py:316, padding for non-loss_mask keys is done via:

```python
padding_tensor = tensor[:pad_size].clone()
```

When pad_size > tensor.shape[0] (i.e., pad_size > batch_size), PyTorch's slicing returns only batch_size elements, not pad_size. Meanwhile, loss_mask at line 313 correctly creates a torch.zeros(pad_size, ...) tensor of the right size.

For example, with batch_size=1 and dp_size=8:

  • pad_size = ceil(1/8)*8 - 1 = 7
  • loss_mask gets torch.zeros(7, ...) → final size = 1 + 7 = 8 ✓
  • sequences gets tensor[:7].clone() → only 1 element → final size = 1 + 1 = 2 ✗

This produces a TrainingInputBatch with mismatched tensor sizes (e.g., loss_mask has 8 rows but sequences has 2), which will crash when the dispatch layer tries to split the batch across DP workers. Any batch small enough that pad_size exceeds batch_size triggers this bug.

Suggested change:

```diff
-padding_tensor = tensor[:pad_size].clone()
+repeat_count = (pad_size + tensor.shape[0] - 1) // tensor.shape[0]
+padding_tensor = tensor.repeat((repeat_count,) + (1,) * (tensor.ndim - 1))[:pad_size].clone()
```
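
For concreteness, a minimal standalone repro of the under-padding, plus a check of the suggested fix (a bare tensor stands in for the batch entries here):

```python
import torch

# batch_size=1 with dp_size=8 gives pad_size = ceil(1/8)*8 - 1 = 7
tensor = torch.arange(2.0).reshape(1, 2)   # one row -> batch_size == 1
pad_size = 7

# Buggy padding: slicing caps at batch_size, so only 1 row comes back.
buggy = tensor[:pad_size].clone()
assert buggy.shape[0] == 1                 # expected 7 rows

# Suggested fix: tile the batch enough times, then slice to exactly pad_size.
repeat_count = (pad_size + tensor.shape[0] - 1) // tensor.shape[0]
fixed = tensor.repeat((repeat_count,) + (1,) * (tensor.ndim - 1))[:pad_size].clone()
assert fixed.shape[0] == pad_size          # 7 rows, as required

padded = torch.cat([tensor, fixed], dim=0)
assert padded.shape[0] == 8                # now divisible by dp_size
```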

@erictang000 (Collaborator, Author) commented Feb 11, 2026

Diff between skyrl-tx/ and skyrl/, (mostly) normalized for differences that are just the tx -> skyrl rename, generated with this command:

```bash
rm -rf /tmp/diff_tx /tmp/diff_skyrl && \
mkdir -p /tmp/diff_tx /tmp/diff_skyrl && \
rsync -a --exclude='__pycache__' --exclude='*.db*' --exclude='uv.lock' --exclude='.git' SkyRL/skyrl-tx/tx/ /tmp/diff_tx/ && \
rsync -a --exclude='__pycache__' --exclude='*.db*' --exclude='uv.lock' --exclude='.git' SkyRL/skyrl/skyrl/ /tmp/diff_skyrl/ && \
find /tmp/diff_tx -name '*.py' -exec sed -i 's/from tx\./from skyrl./g; s/import tx\./import skyrl./g; s/import tx$/import skyrl/g' {} + && \
diff -ru /tmp/diff_tx/ /tmp/diff_skyrl/ || true
```

Output:

```diff
Only in /tmp/diff_skyrl/: backends
Only in /tmp/diff_tx/: layers
Only in /tmp/diff_tx/: loaders
Only in /tmp/diff_tx/: models
Only in /tmp/diff_tx/: run
diff -ru /tmp/diff_tx/tinker/alembic/env.py /tmp/diff_skyrl/tinker/alembic/env.py
--- /tmp/diff_tx/tinker/alembic/env.py  2026-02-11 20:14:18.783597956 +0000
+++ /tmp/diff_skyrl/tinker/alembic/env.py       2026-02-11 20:08:43.698536673 +0000
@@ -7,7 +7,7 @@
 
 from alembic import context
 
-# Add parent directory to path so we can import tx modules
+# Add parent directory to path so we can import skyrl modules
 sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent))
 
 # Import SQLModel and database models
@@ -67,7 +67,7 @@
     from sqlalchemy import create_engine
 
     # Get database URL - ignore whatever is in config, use our helper
-    db_url = os.environ["TX_DATABASE_URL"]
+    db_url = os.environ["SKYRL_DATABASE_URL"]
     connectable = create_engine(db_url, poolclass=pool.NullPool)
 
     with connectable.connect() as connection:
diff -ru /tmp/diff_tx/tinker/api.py /tmp/diff_skyrl/tinker/api.py
--- /tmp/diff_tx/tinker/api.py  2026-02-11 20:14:18.783597956 +0000
+++ /tmp/diff_skyrl/tinker/api.py       2026-02-11 20:08:43.698536673 +0000
@@ -71,7 +71,7 @@
         "--extra",
         app.state.engine_config.backend,
         "-m",
-        "tx.tinker.engine",
+        "skyrl.tinker.engine",
     ]
     cmd.extend(config_to_argv(app.state.engine_config))
 
@@ -1177,7 +1177,7 @@
     import uvicorn
 
     # Parse command-line arguments
-    parser = argparse.ArgumentParser(description="SkyRL tx tinker API server")
+    parser = argparse.ArgumentParser(description="SkyRL tinker API server")
     add_model(parser, EngineConfig)
     parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
     parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
Only in /tmp/diff_tx/tinker: backends
diff -ru /tmp/diff_tx/tinker/config.py /tmp/diff_skyrl/tinker/config.py
--- /tmp/diff_tx/tinker/config.py       2026-02-11 20:14:18.783597956 +0000
+++ /tmp/diff_skyrl/tinker/config.py    2026-02-11 20:08:43.698536673 +0000
@@ -20,13 +20,13 @@
         json_schema_extra={"argparse_type": json.loads},
     )
     checkpoints_base: AnyPath = Field(
-        default=AnyPath("/tmp/tx_checkpoints"),
+        default=AnyPath("/tmp/skyrl_checkpoints"),
         description="Base path where checkpoints will be stored",
     )
     database_url: str = Field(
         default=f'sqlite:///{Path(__file__).parent / "tinker.db"}',
-        description="Database URL (e.g., postgresql://user:password@localhost:5432/tinker). If not set, uses TX_DATABASE_URL env var or defaults to SQLite",
-        json_schema_extra={"argparse_type": str, "env_var": "TX_DATABASE_URL"},
+        description="Database URL (e.g., postgresql://user:password@localhost:5432/tinker). If not set, uses SKYRL_DATABASE_URL env var or defaults to SQLite",
+        json_schema_extra={"argparse_type": str, "env_var": "SKYRL_DATABASE_URL"},
     )
     external_inference_url: str | None = Field(
         default=None,
diff -ru /tmp/diff_tx/tinker/engine.py /tmp/diff_skyrl/tinker/engine.py
--- /tmp/diff_tx/tinker/engine.py       2026-02-11 20:14:18.783597956 +0000
+++ /tmp/diff_skyrl/tinker/engine.py    2026-02-11 20:12:43.404465987 +0000
@@ -22,7 +22,7 @@
 )
 from skyrl.tinker import types
 from skyrl.tinker.config import EngineConfig, add_model
-from skyrl.tinker.backends.utils import log_timing
+from skyrl.backends.utils import log_timing
 from skyrl.utils.log import logger
 
 
@@ -160,21 +160,21 @@
 def get_backend_classes(backend_name: str):
     """Lazy import backends to avoid importing unused backend dependencies (e.g., JAX, Ray)."""
     if backend_name == "jax":
-        from skyrl.tinker.backends.jax import JaxBackend, JaxBackendConfig
+        from skyrl.backends.jax import JaxBackend, JaxBackendConfig
 
         return JaxBackend, JaxBackendConfig
     elif backend_name == "fsdp":
-        from skyrl.tinker.backends.skyrl_train import SkyRLTrainBackend, FSDPBackendConfig
+        from skyrl.backends.skyrl_train import SkyRLTrainBackend, FSDPBackendConfig
 
         return SkyRLTrainBackend, FSDPBackendConfig
     elif backend_name == "megatron":
-        from skyrl.tinker.backends.skyrl_train import SkyRLTrainBackend, MegatronBackendConfig
+        from skyrl.backends.skyrl_train import SkyRLTrainBackend, MegatronBackendConfig
 
         return SkyRLTrainBackend, MegatronBackendConfig
     else:
         raise ValueError(
             f"Unknown backend: {backend_name}. Available backends: jax, fsdp, megatron. "
-            f"Make sure the backend's dependencies are installed (e.g., pip install skyrl-tx[jax])"
+            f"Make sure the backend's dependencies are installed (e.g., pip install skyrl[jax])"
         )
 
 
@@ -463,9 +463,7 @@
 
         return unloaded_count
 
-    def process_optim_step(
-        self, model_id: str, request_data: types.OptimStepInput
-    ) -> types.OptimStepOutput | types.ErrorResponse:
+    def process_optim_step(self, model_id: str, request_data: types.OptimStepInput) -> types.OptimStepOutput | types.ErrorResponse:
         """Process an optim_step request and apply accumulated gradients."""
         if not self.backend.has_model(model_id):
             return _model_not_found_error(model_id)
@@ -487,9 +485,7 @@
         prepared = prepare_sample_batch(requests, self.config.checkpoints_base)
         return self.backend.sample(prepared)
 
-    def process_load_weights(
-        self, model_id: str, request_data: types.LoadWeightsInput
-    ) -> types.LoadWeightsOutput | types.ErrorResponse:
+    def process_load_weights(self, model_id: str, request_data: types.LoadWeightsInput) -> types.LoadWeightsOutput | types.ErrorResponse:
         """Loads a clean, trimmed training checkpoint."""
         if not self.backend.has_model(model_id):
             return _model_not_found_error(model_id)
@@ -502,9 +498,7 @@
 
         return types.LoadWeightsOutput(type="load_weights")
 
-    def process_save_weights(
-        self, model_id: str, request_data: types.SaveWeightsInput
-    ) -> types.SaveWeightsOutput | types.ErrorResponse:
+    def process_save_weights(self, model_id: str, request_data: types.SaveWeightsInput) -> types.SaveWeightsOutput | types.ErrorResponse:
         """
         Saves a clean training checkpoint by converting the trimmed NNX graph
         to a pure dictionary before serialization, following official Flax docs.
@@ -684,7 +678,7 @@
 def main():
     """Entry point for the background engine."""
     # Create argument parser and add Pydantic model fields
-    parser = argparse.ArgumentParser(description="SkyRL tx tinker engine for processing requests")
+    parser = argparse.ArgumentParser(description="SkyRL tinker engine for processing requests")
     add_model(parser, EngineConfig)
 
     # Parse command-line arguments
Only in /tmp/diff_skyrl/: tx
diff -ru /tmp/diff_tx/utils/generator.py /tmp/diff_skyrl/utils/generator.py
--- /tmp/diff_tx/utils/generator.py     2026-02-11 20:14:18.779597657 +0000
+++ /tmp/diff_skyrl/utils/generator.py  2026-02-11 20:08:43.702536973 +0000
@@ -303,7 +303,7 @@
         batch_size, prompt_length = input_ids.shape
         assert len(sampling_params) == batch_size
         max_new_tokens = max(sampling_param.max_tokens for sampling_param in sampling_params)
-        max_length = tx.utils.models.round_up_seq_len(prompt_length + max_new_tokens)
+        max_length = skyrl.utils.models.round_up_seq_len(prompt_length + max_new_tokens)
         temperatures = jnp.array([sampling_param.temperature for sampling_param in sampling_params])
         top_k_values = jnp.array([sampling_param.top_k for sampling_param in sampling_params], dtype=jnp.int32)
         top_p_values = jnp.array([sampling_param.top_p for sampling_param in sampling_params], dtype=jnp.float32)
diff -ru /tmp/diff_tx/utils/log.py /tmp/diff_skyrl/utils/log.py
--- /tmp/diff_tx/utils/log.py   2026-02-11 20:14:18.779597657 +0000
+++ /tmp/diff_skyrl/utils/log.py        2026-02-11 20:08:43.702536973 +0000
@@ -31,7 +31,7 @@
 
 
 def _setup_root_logger() -> None:
-    logger = logging.getLogger("tx")
+    logger = logging.getLogger("skyrl")
     logger.setLevel(logging.DEBUG)
     logger.propagate = False  # Prevent propagation to root logger
     logger.addHandler(_create_rich_handler())
@@ -66,7 +66,7 @@
 
 
 def add_file_handler(path: Path | str, level: int = logging.DEBUG, *, print_path: bool = True) -> None:
-    logger = logging.getLogger("tx")
+    logger = logging.getLogger("skyrl")
     handler = logging.FileHandler(path)
     handler.setLevel(level)
     formatter = logging.Formatter(LOG_FORMAT)
@@ -77,7 +77,7 @@
 
 
 _setup_root_logger()
-logger = logging.getLogger("tx")
+logger = logging.getLogger("skyrl")
 
 
 class ExperimentTracker(str, Enum):
diff -ru /tmp/diff_tx/utils/logits_processor.py /tmp/diff_skyrl/utils/logits_processor.py
--- /tmp/diff_tx/utils/logits_processor.py      2026-02-11 20:14:18.779597657 +0000
+++ /tmp/diff_skyrl/utils/logits_processor.py   2026-02-11 20:08:43.702536973 +0000
@@ -5,7 +5,7 @@
 
 import jax
 import jax.numpy as jnp
-from skyrl.models.configs import ModelConfig
+from skyrl.tx.models.configs import ModelConfig
 
 
 # lm_head: (hidden_states, adapter_indices) -> logits
diff -ru /tmp/diff_tx/utils/models.py /tmp/diff_skyrl/utils/models.py
--- /tmp/diff_tx/utils/models.py        2026-02-11 20:14:18.779597657 +0000
+++ /tmp/diff_skyrl/utils/models.py     2026-02-11 20:08:43.702536973 +0000
@@ -16,7 +16,7 @@
 from transformers import PretrainedConfig
 import peft
 
-from skyrl.models.configs import ModelConfig
+from skyrl.tx.models.configs import ModelConfig
 from skyrl.utils.log import logger
 from skyrl.utils.storage import download_and_unpack, pack_and_upload
 from skyrl.tinker.types import LoraConfig
@@ -61,17 +61,17 @@
 
 def get_model_class(config: PretrainedConfig) -> Callable[..., nnx.Module]:
     "Get the correct model class based on the config."
-    import skyrl.models.llama3
-    import skyrl.models.qwen3
-    import skyrl.models.deepseekv3
+    import skyrl.tx.models.llama3
+    import skyrl.tx.models.qwen3
+    import skyrl.tx.models.deepseekv3
 
     for architecture in config.architectures or []:
-        if hasattr(tx.models.llama3, architecture):
-            return getattr(tx.models.llama3, architecture)
-        if hasattr(tx.models.qwen3, architecture):
-            return getattr(tx.models.qwen3, architecture)
-        if hasattr(tx.models.deepseekv3, architecture):
-            return getattr(tx.models.deepseekv3, architecture)
+        if hasattr(skyrl.tx.models.llama3, architecture):
+            return getattr(skyrl.tx.models.llama3, architecture)
+        if hasattr(skyrl.tx.models.qwen3, architecture):
+            return getattr(skyrl.tx.models.qwen3, architecture)
+        if hasattr(skyrl.tx.models.deepseekv3, architecture):
+            return getattr(skyrl.tx.models.deepseekv3, architecture)
 
     raise ValueError(f"None of the architectures {config.architectures} is currently supported.")
```

@erictang000 (Collaborator, Author) commented Feb 11, 2026

Same as above, but comparing the backends folders.

skyrl/backends contains some extra batch padding logic that we need to add; starting to just make the changes there instead of merging into skyrl-tx first.
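
For intuition, a standalone sketch of the pad-then-trim round trip this padding logic implements (illustrative `pad_rows` helper on a bare tensor; the real code operates on a TrainingInputBatch, as shown in the `_pad_batch` hunk in the diff below):

```python
import math
import torch

def pad_rows(t: torch.Tensor, dp_size: int) -> tuple[torch.Tensor, int]:
    # Round the batch size up to the next multiple of dp_size.
    pad_size = math.ceil(t.shape[0] / dp_size) * dp_size - t.shape[0]
    if pad_size == 0:
        return t, 0
    # Clone the first pad_size rows, as _pad_batch does. (Note the earlier
    # Devin finding: this under-pads when pad_size > t.shape[0].)
    return torch.cat([t, t[:pad_size].clone()], dim=0), pad_size

batch = torch.randn(6, 16)                    # batch_size=6, dp_size=4
padded, pad_size = pad_rows(batch, dp_size=4)
assert (padded.shape[0], pad_size) == (8, 2)  # 6 -> 8, the next multiple of 4
trimmed = padded[: batch.shape[0]]            # callers trim results back down
assert trimmed.shape == batch.shape
```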

Command for the diff:

```bash
rm -rf /tmp/diff_tx_backends /tmp/diff_skyrl_backends && \
mkdir -p /tmp/diff_tx_backends /tmp/diff_skyrl_backends && \
rsync -a --exclude='__pycache__' SkyRL/skyrl-tx/tx/tinker/backends/ /tmp/diff_tx_backends/ && \
rsync -a --exclude='__pycache__' SkyRL/skyrl/skyrl/backends/ /tmp/diff_skyrl_backends/ && \
find /tmp/diff_tx_backends -name '*.py' -exec sed -i 's/from tx\./from skyrl./g; s/import tx\./import skyrl./g; s/import tx$/import skyrl/g; s/from skyrl\.tinker\.backends\./from skyrl.backends./g; s/skyrl\.tinker\.backends\./skyrl.backends./g' {} + && \
diff -ru /tmp/diff_tx_backends/ /tmp/diff_skyrl_backends/ || true
```

Output:

```diff
diff -ru /tmp/diff_tx_backends/jax.py /tmp/diff_skyrl_backends/jax.py
--- /tmp/diff_tx_backends/jax.py        2026-02-11 20:18:11.240984687 +0000
+++ /tmp/diff_skyrl_backends/jax.py     2026-02-11 20:08:43.698536673 +0000
@@ -6,18 +6,18 @@
 
 In multi-host mode, process 0 (coordinator) runs the engine with JaxBackend,
 which broadcasts commands to workers. Workers run separately using `run_worker()`
-or by running this module directly with `python -m tx.tinker.backends.jax`.
+or by running this module directly with `python -m skyrl.backends.jax`.
 
 Usage:
     # Coordinator (process 0) - runs the full engine:
-    uv run -m tx.tinker.engine --base-model Qwen/Qwen3-8B --backend-config '{
+    uv run -m skyrl.tinker.engine --base-model Qwen/Qwen3-8B --backend-config '{
         "coordinator_address": "localhost:7777",
         "num_processes": 2,
         ...
     }'
 
     # Workers (process 1+) - run only the worker loop (receives config from coordinator):
-    uv run -m tx.tinker.backends.jax --coordinator-address localhost:7777 --num-processes 2 --process-id 1
+    uv run -m skyrl.backends.jax --coordinator-address localhost:7777 --num-processes 2 --process-id 1
 """
 
 import time
@@ -36,8 +36,8 @@
 from pydantic import BaseModel, Field, TypeAdapter
 from transformers import AutoTokenizer, PretrainedConfig
 
-from skyrl.models.configs import Qwen3Config
-from skyrl.layers.lora import clear_lora_adapter, init_lora_adapter
+from skyrl.tx.models.configs import Qwen3Config
+from skyrl.tx.layers.lora import clear_lora_adapter, init_lora_adapter
 from skyrl.tinker import types
 from skyrl.backends.backend import AbstractBackend
 from skyrl.backends.utils import pad, pad_batch, pad_to_fsdp
@@ -474,8 +474,8 @@
 
         # Create optimizer
         with jax.set_mesh(self.mesh):
-            tx = optax.inject_hyperparams(optax.adamw)(learning_rate=0.0)
-            self.optimizers[model_id] = nnx.Optimizer(self.model, tx, wrt=self.model.is_lora_param)
+            optimizer = optax.inject_hyperparams(optax.adamw)(learning_rate=0.0)
+            self.optimizers[model_id] = nnx.Optimizer(self.model, optimizer, wrt=self.model.is_lora_param)
 
         # Configure adapter
         init_lora_adapter(self.model, adapter_index, lora_config)
@@ -1069,7 +1069,7 @@
     """Entry point for running as a worker process."""
     import argparse
 
-    parser = argparse.ArgumentParser(description="SkyRL tx tinker worker process")
+    parser = argparse.ArgumentParser(description="SkyRL tinker worker process")
     parser.add_argument(
         "--coordinator-address",
         required=True,
diff -ru /tmp/diff_tx_backends/skyrl_train.py /tmp/diff_skyrl_backends/skyrl_train.py
--- /tmp/diff_tx_backends/skyrl_train.py        2026-02-11 20:18:11.240984687 +0000
+++ /tmp/diff_skyrl_backends/skyrl_train.py     2026-02-11 20:08:43.698536673 +0000
@@ -5,6 +5,7 @@
 """
 
 import asyncio
+import math
 import os
 import tarfile
 import tempfile
@@ -287,6 +288,39 @@
         batch.metadata = {"response_length": max_response_len}
         return batch
 
+    def _pad_batch(self, batch: TrainingInputBatch) -> tuple[TrainingInputBatch, int]:
+        """Pad the batch so its size is divisible by dp_size.
+
+        The dispatch layer splits the batch evenly across DP workers, so the
+        batch size must be a multiple of dp_size.  We pad by cloning the first
+        N entries (with loss_mask zeroed) and record the pad count so callers
+        can trim the results.
+
+        Returns:
+            (padded_batch, pad_size)
+        """
+        dp_size = self._dispatch.get_lcm_dp_size()
+        pad_size = math.ceil(batch.batch_size / dp_size) * dp_size - batch.batch_size
+        if pad_size == 0:
+            return batch, 0
+
+        new_tensors = {}
+        for key, tensor in batch.items():
+            if tensor is not None:
+                if key == "loss_mask":
+                    # Padding entries must not contribute to the loss
+                    additional_dims = tuple(tensor.shape[1:]) if len(tensor.shape) > 1 else ()
+                    padding_tensor = torch.zeros(pad_size, *additional_dims, dtype=tensor.dtype, device=tensor.device)
+                else:
+                    # Clone real data so shapes/dtypes are valid for the model
+                    padding_tensor = tensor[:pad_size].clone()
+                new_tensors[key] = torch.cat([tensor, padding_tensor], dim=0)
+
+        padded = TrainingInputBatch(new_tensors)
+        padded.metadata = batch.metadata
+        logger.info(f"Padded batch from {batch.batch_size} to {batch.batch_size + pad_size} (dp_size={dp_size})")
+        return padded, pad_size
+
     def _extract_metrics(self, data: dict) -> dict[str, float]:
         """Extract training metrics from dispatch return dict.
 
@@ -318,6 +352,8 @@
             return {}
 
         batch = self._to_training_batch(prepared_batch)
+        batch, pad_size = self._pad_batch(batch)
+
         loss_fn = prepared_batch.all_loss_fns[0]
         if len(set(prepared_batch.all_loss_fns)) > 1:
             logger.warning(
@@ -333,6 +369,10 @@
             loss_fn_config=loss_fn_config,
         )
 
+        # Trim padding entries from loss_fn_outputs
+        if pad_size > 0 and "loss_fn_outputs" in data:
+            data["loss_fn_outputs"] = data["loss_fn_outputs"][:-pad_size]
+
         metrics = self._extract_metrics(data)
 
         results = {}
@@ -371,10 +411,13 @@
             return {}
 
         batch = self._to_training_batch(prepared_batch)
+        original_batch_size = batch.batch_size
+        batch, pad_size = self._pad_batch(batch)
         data = self._dispatch.forward("policy", batch)
 
         # dispatch.forward() returns TrainingOutputBatch({"output": tensor[batch, max_response_len]})
-        output_logprobs = data["output"]
+        # Trim padding entries from output
+        output_logprobs = data["output"][:original_batch_size]
 
         results = {}
         for request_id, _, start_idx, end_idx in prepared_batch.request_batch_slices:
```

erictang000 added a commit that referenced this pull request Feb 11, 2026
Add `_pad_batch` logic and a minor change in the jax backend to rename a variable from tx -> optimizer.

---------

Co-authored-by: devin-ai-integration[bot] <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@erictang000 (Collaborator, Author) commented:

Updated diff:

```diff
diff -ru /tmp/diff_tx/ /tmp/diff_skyrl/ || true
Only in /tmp/diff_skyrl/: backends
Only in /tmp/diff_tx/: layers
Only in /tmp/diff_tx/: loaders
Only in /tmp/diff_tx/: models
Only in /tmp/diff_tx/: run
diff -ru /tmp/diff_tx/tinker/alembic/env.py /tmp/diff_skyrl/tinker/alembic/env.py
--- /tmp/diff_tx/tinker/alembic/env.py  2026-02-11 21:57:02.780551919 +0000
+++ /tmp/diff_skyrl/tinker/alembic/env.py       2026-02-11 21:37:38.409502381 +0000
@@ -7,7 +7,7 @@
 
 from alembic import context
 
-# Add parent directory to path so we can import tx modules
+# Add parent directory to path so we can import skyrl modules
 sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent))
 
 # Import SQLModel and database models
@@ -67,7 +67,7 @@
     from sqlalchemy import create_engine
 
     # Get database URL - ignore whatever is in config, use our helper
-    db_url = os.environ["TX_DATABASE_URL"]
+    db_url = os.environ["SKYRL_DATABASE_URL"]
     connectable = create_engine(db_url, poolclass=pool.NullPool)
 
     with connectable.connect() as connection:
diff -ru /tmp/diff_tx/tinker/api.py /tmp/diff_skyrl/tinker/api.py
--- /tmp/diff_tx/tinker/api.py  2026-02-11 21:57:02.776551620 +0000
+++ /tmp/diff_skyrl/tinker/api.py       2026-02-11 21:37:38.409502381 +0000
@@ -71,7 +71,7 @@
         "--extra",
         app.state.engine_config.backend,
         "-m",
-        "tx.tinker.engine",
+        "skyrl.tinker.engine",
     ]
     cmd.extend(config_to_argv(app.state.engine_config))
 
@@ -1177,7 +1177,7 @@
     import uvicorn
 
     # Parse command-line arguments
-    parser = argparse.ArgumentParser(description="SkyRL tx tinker API server")
+    parser = argparse.ArgumentParser(description="SkyRL tinker API server")
     add_model(parser, EngineConfig)
     parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
     parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
Only in /tmp/diff_tx/tinker: backends
diff -ru /tmp/diff_tx/tinker/config.py /tmp/diff_skyrl/tinker/config.py
--- /tmp/diff_tx/tinker/config.py       2026-02-11 21:57:02.776551620 +0000
+++ /tmp/diff_skyrl/tinker/config.py    2026-02-11 21:37:38.409502381 +0000
@@ -20,13 +20,13 @@
         json_schema_extra={"argparse_type": json.loads},
     )
     checkpoints_base: AnyPath = Field(
-        default=AnyPath("/tmp/tx_checkpoints"),
+        default=AnyPath("/tmp/skyrl_checkpoints"),
         description="Base path where checkpoints will be stored",
     )
     database_url: str = Field(
         default=f'sqlite:///{Path(__file__).parent / "tinker.db"}',
-        description="Database URL (e.g., postgresql://user:password@localhost:5432/tinker). If not set, uses TX_DATABASE_URL env var or defaults to SQLite",
-        json_schema_extra={"argparse_type": str, "env_var": "TX_DATABASE_URL"},
+        description="Database URL (e.g., postgresql://user:password@localhost:5432/tinker). If not set, uses SKYRL_DATABASE_URL env var or defaults to SQLite",
+        json_schema_extra={"argparse_type": str, "env_var": "SKYRL_DATABASE_URL"},
     )
     external_inference_url: str | None = Field(
         default=None,
diff -ru /tmp/diff_tx/tinker/engine.py /tmp/diff_skyrl/tinker/engine.py
--- /tmp/diff_tx/tinker/engine.py       2026-02-11 21:57:02.776551620 +0000
+++ /tmp/diff_skyrl/tinker/engine.py    2026-02-11 21:37:44.589964976 +0000
@@ -22,7 +22,7 @@
 )
 from skyrl.tinker import types
 from skyrl.tinker.config import EngineConfig, add_model
-from skyrl.tinker.backends.utils import log_timing
+from skyrl.backends.utils import log_timing
 from skyrl.utils.log import logger
 
 
@@ -160,21 +160,21 @@
 def get_backend_classes(backend_name: str):
     """Lazy import backends to avoid importing unused backend dependencies (e.g., JAX, Ray)."""
     if backend_name == "jax":
-        from skyrl.tinker.backends.jax import JaxBackend, JaxBackendConfig
+        from skyrl.backends.jax import JaxBackend, JaxBackendConfig
 
         return JaxBackend, JaxBackendConfig
     elif backend_name == "fsdp":
-        from skyrl.tinker.backends.skyrl_train import SkyRLTrainBackend, FSDPBackendConfig
+        from skyrl.backends.skyrl_train import SkyRLTrainBackend, FSDPBackendConfig
 
         return SkyRLTrainBackend, FSDPBackendConfig
     elif backend_name == "megatron":
-        from skyrl.tinker.backends.skyrl_train import SkyRLTrainBackend, MegatronBackendConfig
+        from skyrl.backends.skyrl_train import SkyRLTrainBackend, MegatronBackendConfig
 
         return SkyRLTrainBackend, MegatronBackendConfig
     else:
         raise ValueError(
             f"Unknown backend: {backend_name}. Available backends: jax, fsdp, megatron. "
-            f"Make sure the backend's dependencies are installed (e.g., pip install skyrl-tx[jax])"
+            f"Make sure the backend's dependencies are installed (e.g., pip install skyrl[jax])"
         )
 
 
@@ -684,7 +684,7 @@
 def main():
     """Entry point for the background engine."""
     # Create argument parser and add Pydantic model fields
-    parser = argparse.ArgumentParser(description="SkyRL tx tinker engine for processing requests")
+    parser = argparse.ArgumentParser(description="SkyRL tinker engine for processing requests")
     add_model(parser, EngineConfig)
 
     # Parse command-line arguments
Only in /tmp/diff_skyrl/: tx
diff -ru /tmp/diff_tx/utils/generator.py /tmp/diff_skyrl/utils/generator.py
--- /tmp/diff_tx/utils/generator.py     2026-02-11 21:57:02.772551321 +0000
+++ /tmp/diff_skyrl/utils/generator.py  2026-02-11 21:37:38.413502681 +0000
@@ -303,7 +303,7 @@
         batch_size, prompt_length = input_ids.shape
         assert len(sampling_params) == batch_size
         max_new_tokens = max(sampling_param.max_tokens for sampling_param in sampling_params)
-        max_length = tx.utils.models.round_up_seq_len(prompt_length + max_new_tokens)
+        max_length = skyrl.utils.models.round_up_seq_len(prompt_length + max_new_tokens)
         temperatures = jnp.array([sampling_param.temperature for sampling_param in sampling_params])
         top_k_values = jnp.array([sampling_param.top_k for sampling_param in sampling_params], dtype=jnp.int32)
         top_p_values = jnp.array([sampling_param.top_p for sampling_param in sampling_params], dtype=jnp.float32)
diff -ru /tmp/diff_tx/utils/log.py /tmp/diff_skyrl/utils/log.py
--- /tmp/diff_tx/utils/log.py   2026-02-11 21:57:02.772551321 +0000
+++ /tmp/diff_skyrl/utils/log.py        2026-02-11 21:37:38.413502681 +0000
@@ -31,7 +31,7 @@
 
 
 def _setup_root_logger() -> None:
-    logger = logging.getLogger("tx")
+    logger = logging.getLogger("skyrl")
     logger.setLevel(logging.DEBUG)
     logger.propagate = False  # Prevent propagation to root logger
     logger.addHandler(_create_rich_handler())
@@ -66,7 +66,7 @@
 
 
 def add_file_handler(path: Path | str, level: int = logging.DEBUG, *, print_path: bool = True) -> None:
-    logger = logging.getLogger("tx")
+    logger = logging.getLogger("skyrl")
     handler = logging.FileHandler(path)
     handler.setLevel(level)
     formatter = logging.Formatter(LOG_FORMAT)
@@ -77,7 +77,7 @@
 
 
 _setup_root_logger()
-logger = logging.getLogger("tx")
+logger = logging.getLogger("skyrl")
 
 
 class ExperimentTracker(str, Enum):
diff -ru /tmp/diff_tx/utils/logits_processor.py /tmp/diff_skyrl/utils/logits_processor.py
--- /tmp/diff_tx/utils/logits_processor.py      2026-02-11 21:57:02.772551321 +0000
+++ /tmp/diff_skyrl/utils/logits_processor.py   2026-02-11 21:37:38.413502681 +0000
@@ -5,7 +5,7 @@
 
 import jax
 import jax.numpy as jnp
-from skyrl.models.configs import ModelConfig
+from skyrl.tx.models.configs import ModelConfig
 
 
 # lm_head: (hidden_states, adapter_indices) -> logits
diff -ru /tmp/diff_tx/utils/models.py /tmp/diff_skyrl/utils/models.py
--- /tmp/diff_tx/utils/models.py        2026-02-11 21:57:02.772551321 +0000
+++ /tmp/diff_skyrl/utils/models.py     2026-02-11 21:37:38.413502681 +0000
@@ -16,7 +16,7 @@
 from transformers import PretrainedConfig
 import peft
 
-from skyrl.models.configs import ModelConfig
+from skyrl.tx.models.configs import ModelConfig
 from skyrl.utils.log import logger
 from skyrl.utils.storage import download_and_unpack, pack_and_upload
 from skyrl.tinker.types import LoraConfig
@@ -61,17 +61,17 @@
 
 def get_model_class(config: PretrainedConfig) -> Callable[..., nnx.Module]:
     "Get the correct model class based on the config."
-    import skyrl.models.llama3
-    import skyrl.models.qwen3
-    import skyrl.models.deepseekv3
+    import skyrl.tx.models.llama3
+    import skyrl.tx.models.qwen3
+    import skyrl.tx.models.deepseekv3
 
     for architecture in config.architectures or []:
-        if hasattr(tx.models.llama3, architecture):
-            return getattr(tx.models.llama3, architecture)
-        if hasattr(tx.models.qwen3, architecture):
-            return getattr(tx.models.qwen3, architecture)
-        if hasattr(tx.models.deepseekv3, architecture):
-            return getattr(tx.models.deepseekv3, architecture)
+        if hasattr(skyrl.tx.models.llama3, architecture):
+            return getattr(skyrl.tx.models.llama3, architecture)
+        if hasattr(skyrl.tx.models.qwen3, architecture):
+            return getattr(skyrl.tx.models.qwen3, architecture)
+        if hasattr(skyrl.tx.models.deepseekv3, architecture):
+            return getattr(skyrl.tx.models.deepseekv3, architecture)
 
     raise ValueError(f"None of the architectures {config.architectures} is currently supported.")
```

@erictang000 (Collaborator, Author) commented:

Updated backends diff:

```diff
diff -ru /tmp/diff_tx_backends/jax.py /tmp/diff_skyrl_backends/jax.py
--- /tmp/diff_tx_backends/jax.py        2026-02-11 21:57:34.214902589 +0000
+++ /tmp/diff_skyrl_backends/jax.py     2026-02-11 21:37:38.409502381 +0000
@@ -6,18 +6,18 @@
 
 In multi-host mode, process 0 (coordinator) runs the engine with JaxBackend,
 which broadcasts commands to workers. Workers run separately using `run_worker()`
-or by running this module directly with `python -m tx.tinker.backends.jax`.
+or by running this module directly with `python -m skyrl.backends.jax`.
 
 Usage:
     # Coordinator (process 0) - runs the full engine:
-    uv run -m tx.tinker.engine --base-model Qwen/Qwen3-8B --backend-config '{
+    uv run -m skyrl.tinker.engine --base-model Qwen/Qwen3-8B --backend-config '{
         "coordinator_address": "localhost:7777",
         "num_processes": 2,
         ...
     }'
 
     # Workers (process 1+) - run only the worker loop (receives config from coordinator):
-    uv run -m tx.tinker.backends.jax --coordinator-address localhost:7777 --num-processes 2 --process-id 1
+    uv run -m skyrl.backends.jax --coordinator-address localhost:7777 --num-processes 2 --process-id 1
 """
 
 import time
@@ -36,8 +36,8 @@
 from pydantic import BaseModel, Field, TypeAdapter
 from transformers import AutoTokenizer, PretrainedConfig
 
-from skyrl.models.configs import Qwen3Config
-from skyrl.layers.lora import clear_lora_adapter, init_lora_adapter
+from skyrl.tx.models.configs import Qwen3Config
+from skyrl.tx.layers.lora import clear_lora_adapter, init_lora_adapter
 from skyrl.tinker import types
 from skyrl.backends.backend import AbstractBackend
 from skyrl.backends.utils import pad, pad_batch, pad_to_fsdp
@@ -1069,7 +1069,7 @@
     """Entry point for running as a worker process."""
     import argparse
 
-    parser = argparse.ArgumentParser(description="SkyRL tx tinker worker process")
+    parser = argparse.ArgumentParser(description="SkyRL tinker worker process")
     parser.add_argument(
         "--coordinator-address",
         required=True,
```

erictang000 merged commit 7c7820e into NovaSky-AI:main on Feb 11, 2026 (3 of 5 checks passed).
erictang000 deleted the skyrl_merge1 branch on February 11, 2026 at 21:58.
@devin-ai-integration bot (Contributor) left a comment:

Devin Review found 1 potential issue. View 7 additional findings in Devin Review.

tyler-griggs pushed a commit that referenced this pull request Feb 12, 2026
This gets the CI to work after moving the files in #1068.