26 changes: 23 additions & 3 deletions skyrl-train/skyrl_train/workers/worker.py
@@ -832,11 +832,10 @@ def _forward_backward_micro(
                 else:
                     valid_len = action_log_probs.shape[1]
 
-                start = max(action_log_probs.shape[1] - valid_len, 0)
                 loss_fn_outputs.append(
                     {
-                        "logprobs": action_log_probs[i, start:].detach().cpu().tolist(),
-                        "elementwise_loss": elementwise_loss[i, start:].detach().cpu().tolist(),
+                        "logprobs": action_log_probs[i, :valid_len].detach().cpu().tolist(),
+                        "elementwise_loss": elementwise_loss[i, :valid_len].detach().cpu().tolist(),
                     }
                 )

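Note on the hunk above: the old `[i, start:]` slice kept the *last* `valid_len` positions of each row, so with right-padded responses it returned the padding rather than the valid tokens; the new `[i, :valid_len]` slice keeps the leading valid tokens. A toy illustration with made-up values (not code from the PR):

import torch

logprobs = torch.tensor([[-0.1, -0.2, -0.3, 0.0, 0.0, 0.0]])  # one right-padded row
valid_len = 3

start = max(logprobs.shape[1] - valid_len, 0)
old = logprobs[0, start:]      # tensor([0., 0., 0.]) -- the trailing padding
new = logprobs[0, :valid_len]  # tensor([-0.1000, -0.2000, -0.3000]) -- the valid tokens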
@@ -876,12 +875,33 @@ def _forward_backward_micro(
         loss = policy_loss + kl_loss_term - entropy_loss_term
         self.strategy.backward(loss, self.model, self.optimizer)
 
+        # Build per-sequence loss_fn_outputs with logprobs.
+        batch_size = action_log_probs.shape[0]
+        seq_len = action_log_probs.shape[1]
+
+        if action_mask is not None:
+            valid_lens = action_mask.sum(dim=1).int().tolist()
+        elif loss_mask is not None:
+            valid_lens = loss_mask.sum(dim=1).int().tolist()
+        else:
+            valid_lens = [seq_len] * batch_size
+
+        detached_log_probs = action_log_probs.detach().cpu()
+        loss_fn_outputs = []
+        for i, valid_len in enumerate(valid_lens):
+            loss_fn_outputs.append(
+                {
+                    "logprobs": detached_log_probs[i, :valid_len].tolist(),
+                }
+            )
+
         status = {
             "final_loss": loss.item(),
             "policy_loss": policy_loss.item(),
             "policy_entropy": entropy.item(),
             "response_length": num_actions,
             "policy_lr": self.scheduler.get_last_lr()[0],
+            "loss_fn_outputs": loss_fn_outputs,
         }
         for k, v in loss_metrics.items():
             status["loss_metrics/" + k] = v
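For reference, a minimal standalone sketch of the per-sequence trimming this hunk introduces (dummy tensors, not the worker's actual inputs; it assumes right-padded masks, i.e. 1s followed by 0s):

import torch

action_log_probs = torch.randn(2, 6)  # (batch_size, num_actions), dummy values
loss_mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                          [1, 1, 1, 1, 1, 0]])

# Summing a right-padded mask along the action dim gives per-sample valid lengths.
valid_lens = loss_mask.sum(dim=1).int().tolist()  # [3, 5]

# Trim each row to its own valid length before serializing.
detached = action_log_probs.detach().cpu()
loss_fn_outputs = [{"logprobs": detached[i, :n].tolist()} for i, n in enumerate(valid_lens)]

assert [len(o["logprobs"]) for o in loss_fn_outputs] == valid_lens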
56 changes: 56 additions & 0 deletions skyrl-train/tests/gpu/gpu_ci/test_training_step.py
@@ -78,13 +78,69 @@ async def test_policy_forward_backward_and_optim_step(ray_init_fixture, cfg, pac
assert "policy_loss" in result
assert "loss_metrics/clip_ratio" in result
assert "policy_entropy" in result
assert "loss_fn_outputs" in result, "RL path should return loss_fn_outputs"
loss_fn_outputs = result.pop("loss_fn_outputs")
assert isinstance(loss_fn_outputs, list)
for output in loss_fn_outputs:
assert "logprobs" in output, "Each output should have logprobs"
assert isinstance(output["logprobs"], list)
for k, v in result.items():
assert isinstance(v, (int, float)), f"{k} should be an int or float"

finally:
ray.shutdown()


@pytest.mark.asyncio
async def test_policy_loss_fn_outputs_variable_lengths(ray_init_fixture, cfg):
"""
Verify that loss_fn_outputs logprobs are trimmed to the correct per-sample
valid length when samples have different response lengths (right-padded masks).

Uses variable action_lengths so each sample has a different number of valid
tokens, then checks that each output's logprobs length matches exactly.
"""
cfg.trainer.use_sample_packing = False
cfg.trainer.strategy = "fsdp2"
validate_cfg(cfg)

num_actions = 6
# 4 samples total, 2 per DP rank. Each pair has different valid lengths.
action_lengths = [3, 6, 2, 5]

try:
actor_group = init_worker_with_type(
"policy",
shared_pg=None,
colocate_all=False,
num_gpus_per_node=cfg.trainer.placement.policy_num_gpus_per_node,
cfg=cfg,
)

dp_size = actor_group.actor_infos[0].rank.dp_size
batch_size = dp_size * 2
dummy_batch = make_dummy_training_batch(
batch_size=batch_size, num_actions=num_actions, action_lengths=action_lengths
)

results = ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", data=dummy_batch))

# Collect all loss_fn_outputs across DP ranks (returned in rank order)
all_outputs = []
for result in results:
assert "loss_fn_outputs" in result
all_outputs.extend(result["loss_fn_outputs"])

assert len(all_outputs) == batch_size
for i, output in enumerate(all_outputs):
expected_len = action_lengths[i]
actual_len = len(output["logprobs"])
assert actual_len == expected_len, f"Sample {i}: expected {expected_len} logprobs, got {actual_len}"

finally:
ray.shutdown()


@pytest.mark.asyncio
@pytest.mark.parametrize(
("packed", "strategy"),
23 changes: 19 additions & 4 deletions skyrl-train/tests/gpu/utils.py
@@ -59,11 +59,26 @@ def make_dummy_tensorbatch(seq_len=10, num_actions=4) -> TensorBatch:
     return data
 
 
-def make_dummy_training_batch(batch_size=2, seq_len=10, num_actions=4) -> TrainingInputBatch:
-    """Create a dummy TrainingInputBatch"""
+def make_dummy_training_batch(batch_size=2, seq_len=10, num_actions=4, action_lengths=None) -> TrainingInputBatch:
+    """Create a dummy TrainingInputBatch.
+
+    Args:
+        action_lengths: Optional list of per-sample valid action lengths.
+            If provided, loss_mask and response_mask will be right-padded per
+            sample (1s then 0s). Length must equal batch_size. Each value must
+            be <= num_actions.
+    """
 
     torch.manual_seed(42)
 
+    loss_mask = torch.ones((batch_size, num_actions), dtype=int, device="cpu")
+    response_mask = torch.ones((batch_size, num_actions), dtype=int, device="cpu")
+    if action_lengths is not None:
+        assert len(action_lengths) == batch_size
+        for i, valid_len in enumerate(action_lengths):
+            loss_mask[i, valid_len:] = 0
+            response_mask[i, valid_len:] = 0
+
     # Add all the required fields for training
     data = TrainingInputBatch(
         {
@@ -74,8 +89,8 @@ def make_dummy_training_batch(batch_size=2, seq_len=10, num_actions=4) -> Traini
             "values": 0.5 * torch.ones((batch_size, num_actions), device="cpu"),
             "returns": 0.5 * torch.ones((batch_size, num_actions), device="cpu"),
             "advantages": 0.6 * torch.ones((batch_size, num_actions), device="cpu"),
-            "loss_mask": torch.ones((batch_size, num_actions), dtype=int, device="cpu"),
-            "response_mask": torch.ones((batch_size, num_actions), dtype=int, device="cpu"),
+            "loss_mask": loss_mask,
+            "response_mask": response_mask,
             "rollout_logprobs": 0.2 * torch.ones((batch_size, num_actions), device="cpu"),
         }
     )
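A hypothetical call to the extended helper, showing the new parameter in use (the import path is assumed from the repo layout; dict-style access mirrors how `TrainingInputBatch` is constructed above):

from tests.gpu.utils import make_dummy_training_batch

batch = make_dummy_training_batch(batch_size=4, num_actions=6, action_lengths=[3, 6, 2, 5])

# Row i of each mask now holds action_lengths[i] ones followed by zeros.
assert batch["loss_mask"].sum(dim=1).tolist() == [3, 6, 2, 5]
assert batch["response_mask"].sum(dim=1).tolist() == [3, 6, 2, 5]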