Implement forward-only pass and populate metrics #1046

Merged
tyler-griggs merged 2 commits into main from tyler/populate-metrics-clean
Feb 7, 2026

Implement forward-only pass and populate metrics#1046
tyler-griggs merged 2 commits intomainfrom
tyler/populate-metrics-clean

Conversation

@tyler-griggs
Member

@tyler-griggs tyler-griggs commented Feb 7, 2026

Summary

  • Implement forward() in SkyRL-Train backend (was raising NotImplementedError)
  • forward_backward() now returns training metrics (total_loss:sum, pg_loss:sum, entropy_loss:sum, num_tokens:sum) extracted from the dispatch worker return dict
  • optim_step() now returns grad_norm and learning_rate in its metrics field
  • Added metrics field to OptimStepOutput type
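The shapes described above can be sketched as follows. This is an illustrative mockup, not SkyRL-Train's actual types: only the field name `metrics`, the colon-suffixed keys, and `grad_norm`/`learning_rate` come from the PR description; everything else is assumed.

```python
# Hypothetical sketch of the shapes described in the summary.
from dataclasses import dataclass, field


@dataclass
class OptimStepOutput:
    # New in this PR: optimizer metrics such as grad_norm and learning_rate.
    metrics: dict[str, float] = field(default_factory=dict)


# forward_backward() reports per-batch sums in Tinker's colon-suffixed format
# (values here are made up):
fb_metrics = {
    "total_loss:sum": 1.25,
    "pg_loss:sum": 0.98,
    "entropy_loss:sum": 0.27,
    "num_tokens:sum": 4096.0,
}

step = OptimStepOutput(metrics={"grad_norm": 0.73, "learning_rate": 3e-4})
print(sorted(step.metrics))  # ['grad_norm', 'learning_rate']
```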

    for sample_idx in range(request_data.num_samples):
    for _ in range(request_data.num_samples):
        all_prompts.append(prompt_tokens)
        # Derive a unique seed per sample so that num_samples > 1 produces
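[Editor's note: the truncated comment above gestures at deriving a distinct seed per sample. A minimal sketch of that idea follows; the names `base_seed`, `num_samples`, and the offset scheme are illustrative, not the PR's actual code.]

```python
# Illustrative only: derive a distinct seed per sample so that
# num_samples > 1 yields different generations for the same prompt.
base_seed = 1234
num_samples = 3
seeds = [base_seed + sample_idx for sample_idx in range(num_samples)]
print(seeds)  # [1234, 1235, 1236]
```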
Collaborator

This is an important fix; don't forget to revert the change before merging the PR.

Member Author

Oh, thank you. I was just hacking around a bit with Claude; it will be reverted.

forward_backward() now returns total_loss, pg_loss, entropy_loss, and
num_tokens from the dispatch worker.  optim_step() returns grad_norm
and learning_rate.  These are consumed by tinker-cookbook scripts via
OptimStepResponse.metrics.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@tyler-griggs tyler-griggs changed the title from "Populate metrics in forward_backward and optim_step" to "Implement forward-only pass and populate metrics" Feb 7, 2026
@tyler-griggs tyler-griggs marked this pull request as ready for review February 7, 2026 20:06
@tyler-griggs tyler-griggs merged commit 590685f into main Feb 7, 2026
6 checks passed
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request implements the forward pass in the SkyRL-Train backend and adds metric reporting to forward_backward and optim_step. The changes are well-structured and align with the goals outlined in the description. I've suggested a minor refactoring in _extract_metrics to improve maintainability by using a dictionary for metric mapping, which will make it easier to add new metrics in the future. Overall, this is a solid contribution.

Comment on lines +231 to +252
    def _extract_metrics(self, data: dict) -> dict[str, float]:
        """Extract training metrics from dispatch return dict.

        Workers return metrics like 'loss', 'policy_loss', 'policy_entropy', etc.
        We convert to Tinker's colon-suffixed format (e.g. 'total_loss:sum').
        """
        metrics: dict[str, float] = {}

        # SFT path returns 'loss'; RL path returns 'final_loss' / 'policy_loss'
        if "loss" in data:
            metrics["total_loss:sum"] = float(data["loss"])
        elif "final_loss" in data:
            metrics["total_loss:sum"] = float(data["final_loss"])

        if "policy_loss" in data:
            metrics["pg_loss:sum"] = float(data["policy_loss"])
        if "policy_entropy" in data:
            metrics["entropy_loss:sum"] = float(data["policy_entropy"])
        if "response_length" in data:
            metrics["num_tokens:sum"] = float(data["response_length"])

        return metrics
Contributor

medium

The current implementation of _extract_metrics uses a series of if statements to map metric names. While this works, it can become less maintainable as the number of metrics grows. Using a dictionary to define the mapping from source metric names to target metric names would make the code more scalable and easier to read.

    def _extract_metrics(self, data: dict) -> dict[str, float]:
        """Extract training metrics from dispatch return dict.

        Workers return metrics like 'loss', 'policy_loss', 'policy_entropy', etc.
        We convert to Tinker's colon-suffixed format (e.g. 'total_loss:sum').
        """
        metrics: dict[str, float] = {}
        metric_mapping = {
            "policy_loss": "pg_loss:sum",
            "policy_entropy": "entropy_loss:sum",
            "response_length": "num_tokens:sum",
        }

        # SFT path returns 'loss'; RL path returns 'final_loss' / 'policy_loss'
        if "loss" in data:
            metrics["total_loss:sum"] = float(data["loss"])
        elif "final_loss" in data:
            metrics["total_loss:sum"] = float(data["final_loss"])

        for source_key, target_key in metric_mapping.items():
            if source_key in data:
                metrics[target_key] = float(data[source_key])

        return metrics
