Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from skyrl_train.generators.utils import get_rollout_metrics, get_response_ids_and_loss_mask_from_messages
from skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient
from skyrl_train.inference_engines.base import ConversationType
from skyrl_train.utils.rate_limiter import create_rate_limiter
from omegaconf import DictConfig, OmegaConf
from harbor.trial.trial import Trial
from harbor.models.trial.config import TrialConfig
Expand Down Expand Up @@ -88,6 +89,10 @@ def __init__(
else:
self.custom_chat_template_content = None

# Initialize rate limiter
rate_limit_config = terminal_bench_cfg.get("rate_limit", None)
self._rate_limiter = create_rate_limiter(rate_limit_config)
Comment on lines +92 to +94
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 SkyRL-specific rate_limit config leaks into Harbor's TrialConfig.model_validate(), potentially causing validation errors

When a user enables rate limiting by adding a rate_limit section to their terminal bench YAML config, this SkyRL-specific field is included in _harbor_config_template (created at terminal_bench_generator.py:63 via OmegaConf.to_container(terminal_bench_cfg, resolve=True)) and then passed to Harbor's TrialConfig.model_validate(config) at terminal_bench_generator.py:169.

Root Cause and Impact

At line 63, the entire terminal_bench_cfg is converted to a plain dict:

self._harbor_config_template = OmegaConf.to_container(terminal_bench_cfg, resolve=True)

At lines 92-94, rate_limit is read from the same config but NOT removed from the template:

rate_limit_config = terminal_bench_cfg.get("rate_limit", None)
self._rate_limiter = create_rate_limiter(rate_limit_config)

Later, in terminal_bench_agent_loop at line 166-169, the template (still containing rate_limit) is deep-copied and passed to Harbor:

config = deepcopy(self._harbor_config_template)
config["task"] = {"path": prompt}
config["agent"]["kwargs"]["session_id"] = uuid4().hex
trial_config = TrialConfig.model_validate(config)

If Harbor's TrialConfig Pydantic model is configured with extra='forbid' (a common setting for strict config validation), this will raise a ValidationError on every trial, causing all trajectories to fail. Even with extra='ignore', passing unrelated config to an external library's validator is unintended and fragile.

Impact: When a user enables rate limiting (the primary feature of this PR), every trial could fail with a Pydantic validation error, making the rate limiting feature unusable.

Suggested change
# Initialize rate limiter
rate_limit_config = terminal_bench_cfg.get("rate_limit", None)
self._rate_limiter = create_rate_limiter(rate_limit_config)
# Initialize rate limiter
rate_limit_config = terminal_bench_cfg.get("rate_limit", None)
self._rate_limiter = create_rate_limiter(rate_limit_config)
self._harbor_config_template.pop("rate_limit", None)
Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.


async def generate(self, input_batch: GeneratorInput) -> GeneratorOutput:
tasks = []
for i in range(len(input_batch["prompts"])):
Expand Down Expand Up @@ -174,7 +179,8 @@ async def terminal_bench_agent_loop(
prefix = f"Trajectory {trajectory_id} attempt {i+1}/{MAX_NUM_RETRIES_PER_TRIAL}"
results = None
try:
results = await trial.run()
async with self._rate_limiter:
results = await trial.run()
if not results.verifier_result:
logger.warning(f"{prefix} failed: Exception info: {results.exception_info}. Results: {results}")
continue
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,13 @@ verifier:
# Override the verifier timeout (seconds)
# override_timeout_sec: 300

# Rate and concurrency limiting configuration (SkyRL-specific, not passed to Harbor)
rate_limit:
# Enable rate/concurrency limiting for trajectory submissions
enabled: true
# Maximum trajectories per second (must be >= 1.0, fractional values like 1.5 are supported)
# Set to null or omit to disable rate limiting
trajectories_per_second: 5
# Maximum concurrent trial.run() calls allowed (must be >= 1)
# Set to null or omit to disable concurrency limiting
max_concurrency: 512
224 changes: 224 additions & 0 deletions skyrl-train/skyrl_train/utils/rate_limiter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
"""
Rate limiter for controlling trajectory submission rates and concurrency.

This module provides a token bucket rate limiter and concurrency limiter for async code,
allowing users to express "N trajectories per second" and "max M concurrent trajectories".

Note: Fractional rates >= 1.0 are supported (e.g., 1.5 trajectories/second).
Rates < 1.0 are not supported due to the token bucket implementation.
"""

import asyncio
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Union

from loguru import logger


@dataclass
class RateLimiterConfig:
    """User-facing limiter settings, validated at construction time.

    Attributes:
        enabled: Master switch. When False, no limiting is applied no matter
            what the other fields contain.
        trajectories_per_second: Token-bucket rate cap. Must be >= 1.0 when
            set; None disables rate limiting.
        max_concurrency: Upper bound on simultaneously running operations.
            Must be >= 1 when set; None disables concurrency limiting.
    """

    enabled: bool = False
    trajectories_per_second: Optional[float] = None
    max_concurrency: Optional[int] = None

    def __post_init__(self):
        # Fail fast on invalid values so misconfiguration surfaces at
        # construction rather than deep inside the limiter at runtime.
        tps = self.trajectories_per_second
        if tps is not None and tps < 1.0:
            raise ValueError(
                f"trajectories_per_second must be >= 1.0, got {tps}. "
                "Rates < 1.0 are not supported due to the token bucket implementation."
            )
        cap = self.max_concurrency
        if cap is not None and cap < 1:
            raise ValueError(f"max_concurrency must be >= 1, got {cap}")


class RateLimiterInterface(ABC):
    """Contract shared by all limiter implementations.

    Subclasses provide acquire()/release(); the async context-manager
    protocol is implemented here once so callers can simply write
    ``async with limiter: ...`` and get paired acquire/release semantics.
    """

    @abstractmethod
    async def acquire(self) -> None:
        """Wait until the caller is allowed to proceed."""
        pass

    @abstractmethod
    def release(self) -> None:
        """Return a concurrency slot once the guarded operation finishes."""
        pass

    async def __aenter__(self) -> "RateLimiterInterface":
        # Entering the block counts as one acquisition.
        await self.acquire()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        # Always release, even when the guarded operation raised.
        self.release()


class NoOpRateLimiter(RateLimiterInterface):
    """Pass-through limiter used when limiting is disabled; never blocks."""

    async def acquire(self) -> None:
        """Grant permission immediately without waiting."""
        return None

    def release(self) -> None:
        """Nothing to give back for the no-op limiter."""
        return None


class AsyncRateLimiter(RateLimiterInterface):
    """Combined rate limiter and concurrency limiter for async code.

    Rate limiting (token bucket algorithm):
        - Bucket holds tokens, max capacity = rate (trajectories_per_second)
        - Tokens refill at rate N per second
        - Each acquire() call consumes 1 token
        - If no tokens available, caller waits until refill

    Concurrency limiting (semaphore):
        - Limits how many operations can run simultaneously
        - acquire() waits if max_concurrency operations are already running
        - release() must be called when operation completes

    Note: Fractional rates >= 1.0 are supported (e.g., 1.5 means 1.5 ops/second).
    Rates < 1.0 are not supported because the bucket capacity equals the rate,
    so the bucket could never hold a full token to allow an acquire().
    """

    def __init__(
        self,
        rate: Optional[float] = None,
        max_concurrency: Optional[int] = None,
    ):
        """Initialize the rate limiter.

        Args:
            rate: Maximum operations per second (tokens per second).
                Must be >= 1.0 if provided. None disables rate limiting.
            max_concurrency: Maximum concurrent operations allowed.
                Must be >= 1 if provided. None disables concurrency limiting.

        Raises:
            ValueError: If rate < 1.0 or max_concurrency < 1.
        """
        if rate is not None and rate < 1.0:
            raise ValueError(
                f"Rate must be >= 1.0, got {rate}. "
                "Rates < 1.0 are not supported due to the token bucket implementation."
            )
        if max_concurrency is not None and max_concurrency < 1:
            raise ValueError(f"max_concurrency must be >= 1, got {max_concurrency}")

        # Rate limiting state (token bucket); only created when rate limiting
        # is enabled, and only touched behind `if self._rate is not None`.
        self._rate = rate
        if rate is not None:
            self._max_tokens = rate  # bucket capacity = rate
            self._tokens = rate  # start with a full bucket
            self._last_refill = time.monotonic()
            self._rate_lock = asyncio.Lock()

        # Concurrency limiting state (semaphore)
        self._max_concurrency = max_concurrency
        if max_concurrency is not None:
            self._semaphore = asyncio.Semaphore(max_concurrency)

    async def acquire(self) -> None:
        """Acquire permission to proceed, waiting if necessary.

        First applies rate limiting (controls how fast operations start),
        then acquires a concurrency slot (controls how many run simultaneously).
        """
        # Rate limit first (controls start rate)
        if self._rate is not None:
            await self._acquire_rate_token()

        # Then concurrency limit (controls concurrent execution)
        if self._max_concurrency is not None:
            await self._semaphore.acquire()

    def release(self) -> None:
        """Release a concurrency slot.

        Must be called after the operation completes to allow other
        operations to proceed. Safe to call even if concurrency limiting
        is disabled.
        """
        if self._max_concurrency is not None:
            self._semaphore.release()

    async def _acquire_rate_token(self) -> None:
        """Consume one rate-limit token, sleeping until one is available.

        The lock is held only while inspecting/consuming the bucket; the
        sleep happens OUTSIDE the `async with` block. The previous
        implementation manually released and re-acquired the lock inside
        the `async with`, which meant a task cancelled during the sleep
        (or during the re-acquire) would reach the context manager's
        __aexit__ without holding the lock, raising RuntimeError and
        masking the CancelledError. Looping around a short critical
        section is cancellation-safe and behaves equivalently.
        """
        while True:
            async with self._rate_lock:
                self._refill()
                if self._tokens >= 1.0:
                    self._tokens -= 1.0
                    return
                # Time until the bucket refills enough for one full token.
                wait_time = (1.0 - self._tokens) / self._rate
            # Sleep with the lock released so other waiters can check the
            # bucket and _refill's timestamps stay accurate.
            await asyncio.sleep(wait_time)

    def _refill(self) -> None:
        """Refill tokens based on elapsed time (caller must hold _rate_lock)."""
        now = time.monotonic()
        elapsed = now - self._last_refill
        self._tokens = min(self._max_tokens, self._tokens + elapsed * self._rate)
        self._last_refill = now


def create_rate_limiter(config: Union[RateLimiterConfig, dict, None]) -> RateLimiterInterface:
    """Factory function to create a rate limiter from config.

    Args:
        config: Rate limiter configuration. Can be:
            - RateLimiterConfig dataclass
            - any mapping with a dict-style ``.get`` (plain dict, OmegaConf
              DictConfig, ...) with 'enabled', 'trajectories_per_second',
              and/or 'max_concurrency' keys
            - None (returns NoOpRateLimiter)

    Returns:
        AsyncRateLimiter if enabled with at least one limit configured,
        NoOpRateLimiter otherwise.

    Raises:
        ValueError: If the config values fail RateLimiterConfig validation
            (rate < 1.0 or max_concurrency < 1).
    """
    if config is None:
        return NoOpRateLimiter()

    if not isinstance(config, RateLimiterConfig):
        # Normalize any mapping-like config through the dataclass so its
        # __post_init__ validation always runs. The previous
        # `isinstance(config, dict)` check silently skipped conversion (and
        # therefore validation) for OmegaConf DictConfig objects, which is
        # exactly what callers pass via `terminal_bench_cfg.get("rate_limit")`.
        config = RateLimiterConfig(
            enabled=config.get("enabled", False),
            trajectories_per_second=config.get("trajectories_per_second"),
            max_concurrency=config.get("max_concurrency"),
        )

    if not config.enabled:
        return NoOpRateLimiter()

    # Log what's enabled
    limits = []
    if config.trajectories_per_second is not None:
        limits.append(f"{config.trajectories_per_second} trajectories/second")
    if config.max_concurrency is not None:
        limits.append(f"max {config.max_concurrency} concurrent")
    if limits:
        logger.info(f"Rate limiter enabled: {', '.join(limits)}")
    else:
        # enabled=True but no limits configured, treat as no-op
        return NoOpRateLimiter()

    return AsyncRateLimiter(
        rate=config.trajectories_per_second,
        max_concurrency=config.max_concurrency,
    )
Loading