Implement forward-only pass and populate metrics #1046
Conversation
```python
for sample_idx in range(request_data.num_samples):
for _ in range(request_data.num_samples):
    all_prompts.append(prompt_tokens)
    # Derive a unique seed per sample so that num_samples > 1 produces
```
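The comment in the diff hints at the idea behind the temporary hack: give each of the `num_samples` requests its own seed so they don't all produce identical completions. A minimal sketch of that idea (the names `base_seed` and `sample_idx` are illustrative, not from the PR):

```python
# Hypothetical sketch: derive a distinct seed per sample so that
# num_samples > 1 yields different completions for the same prompt.
base_seed = 1234
num_samples = 4
seeds = [base_seed + sample_idx for sample_idx in range(num_samples)]
print(seeds)  # [1234, 1235, 1236, 1237]
```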
This is an important fix; don't forget to revert the change before merging the PR.
Oh thank you, I was just hacking around a bit with Claude, it will be reverted
Force-pushed from 7acd89f to 98e7500
forward_backward() now returns total_loss, pg_loss, entropy_loss, and num_tokens from the dispatch worker. optim_step() returns grad_norm and learning_rate. These are consumed by tinker-cookbook scripts via OptimStepResponse.metrics. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Force-pushed from 98e7500 to 16fc973
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Code Review
This pull request implements the forward pass in the SkyRL-Train backend and adds metric reporting to forward_backward and optim_step. The changes are well-structured and align with the goals outlined in the description. I've suggested a minor refactoring in _extract_metrics to improve maintainability by using a dictionary for metric mapping, which will make it easier to add new metrics in the future. Overall, this is a solid contribution.
```python
def _extract_metrics(self, data: dict) -> dict[str, float]:
    """Extract training metrics from dispatch return dict.

    Workers return metrics like 'loss', 'policy_loss', 'policy_entropy', etc.
    We convert to Tinker's colon-suffixed format (e.g. 'total_loss:sum').
    """
    metrics: dict[str, float] = {}

    # SFT path returns 'loss'; RL path returns 'final_loss' / 'policy_loss'
    if "loss" in data:
        metrics["total_loss:sum"] = float(data["loss"])
    elif "final_loss" in data:
        metrics["total_loss:sum"] = float(data["final_loss"])

    if "policy_loss" in data:
        metrics["pg_loss:sum"] = float(data["policy_loss"])
    if "policy_entropy" in data:
        metrics["entropy_loss:sum"] = float(data["policy_entropy"])
    if "response_length" in data:
        metrics["num_tokens:sum"] = float(data["response_length"])

    return metrics
```
The current implementation of _extract_metrics uses a series of if statements to map metric names. While this works, it can become less maintainable as the number of metrics grows. Using a dictionary to define the mapping from source metric names to target metric names would make the code more scalable and easier to read.
```python
def _extract_metrics(self, data: dict) -> dict[str, float]:
    """Extract training metrics from dispatch return dict.

    Workers return metrics like 'loss', 'policy_loss', 'policy_entropy', etc.
    We convert to Tinker's colon-suffixed format (e.g. 'total_loss:sum').
    """
    metrics: dict[str, float] = {}
    metric_mapping = {
        "policy_loss": "pg_loss:sum",
        "policy_entropy": "entropy_loss:sum",
        "response_length": "num_tokens:sum",
    }
    # SFT path returns 'loss'; RL path returns 'final_loss' / 'policy_loss'
    if "loss" in data:
        metrics["total_loss:sum"] = float(data["loss"])
    elif "final_loss" in data:
        metrics["total_loss:sum"] = float(data["final_loss"])
    for source_key, target_key in metric_mapping.items():
        if source_key in data:
            metrics[target_key] = float(data[source_key])
    return metrics
```
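As a quick sanity check of the suggested refactor, here is the same mapping logic as a standalone function (a free function for illustration only; the input dict values are hypothetical), exercised on an RL-style worker return dict:

```python
def extract_metrics(data: dict) -> dict[str, float]:
    """Map dispatch worker metric names to Tinker's colon-suffixed format."""
    metrics: dict[str, float] = {}
    metric_mapping = {
        "policy_loss": "pg_loss:sum",
        "policy_entropy": "entropy_loss:sum",
        "response_length": "num_tokens:sum",
    }
    # SFT path returns 'loss'; RL path returns 'final_loss' / 'policy_loss'
    if "loss" in data:
        metrics["total_loss:sum"] = float(data["loss"])
    elif "final_loss" in data:
        metrics["total_loss:sum"] = float(data["final_loss"])
    for source_key, target_key in metric_mapping.items():
        if source_key in data:
            metrics[target_key] = float(data[source_key])
    return metrics

# RL-style dispatch return dict (values are made up for the example)
rl = extract_metrics({"final_loss": 0.5, "policy_loss": 0.3, "response_length": 128})
print(rl)  # {'total_loss:sum': 0.5, 'pg_loss:sum': 0.3, 'num_tokens:sum': 128.0}
```

Note that missing source keys are simply skipped, so the SFT and RL paths can share one function.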
Summary
- `forward()` implemented in the SkyRL-Train backend (was raising `NotImplementedError`)
- `forward_backward()` now returns training metrics (`total_loss:sum`, `pg_loss:sum`, `entropy_loss:sum`, `num_tokens:sum`) extracted from the dispatch worker return dict
- `optim_step()` now returns `grad_norm` and `learning_rate` in its metrics field
- Added a `metrics` field to the `OptimStepOutput` type
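The summary says these metrics are consumed by tinker-cookbook scripts via `OptimStepResponse.metrics`. A minimal consumer-side sketch of that shape, using a stand-in dataclass (the real `OptimStepResponse` type and its fields are assumptions here, and the metric values are made up):

```python
from dataclasses import dataclass, field


@dataclass
class OptimStepResponse:
    # Stand-in for the real response type; only the metrics field is sketched.
    metrics: dict[str, float] = field(default_factory=dict)


# Values a script might read after an optimizer step (hypothetical numbers)
resp = OptimStepResponse(metrics={"grad_norm": 1.2, "learning_rate": 3e-4})
print(f"grad_norm={resp.metrics['grad_norm']}, lr={resp.metrics['learning_rate']}")
```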