Commit 6625a3d

JacoCheung and claude authored and committed
fix: update test_tp_ranking_gr to match progress() 3-value return
progress() now returns (loss, global_tokens, extras) instead of (loss, extras). Update all 3 call sites in the TP test.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
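The shape change described in the commit message can be sketched with a minimal mock. `MockPipeline` is a hypothetical stand-in for the real pipeline class (not the repo's actual implementation); it only illustrates how call sites must move from 2-value to 3-value unpacking.

```python
class MockPipeline:
    """Hypothetical stand-in for the training pipeline in test_tp_ranking_gr.

    Mimics the new progress() contract described in the commit message:
    it returns (loss, global_tokens, extras) rather than (loss, extras).
    """

    def progress(self, batches):
        loss = 0.5                              # scalar training loss
        global_tokens = 1024                    # NEW middle value: global token count
        extras = (loss, "logits", None, None)   # matches (losses, logits, _, _)
        return loss, global_tokens, extras


pipeline = MockPipeline()

# Old call sites unpacked two values:
#   _, (losses, logits, _, _) = pipeline.progress(batches)
# New call sites must unpack three, discarding global_tokens if unused:
_, _, (losses, logits, _, _) = pipeline.progress(["batch0"])
print(losses, logits)  # → 0.5 logits
```

Keeping the extra `_` placeholder rather than slicing the tuple preserves the pattern already used throughout the test, so the diff stays small.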
1 parent d9dbe30 commit 6625a3d

1 file changed

Lines changed: 7 additions & 3 deletions

examples/hstu/test/tensor_parallel/test_tp_ranking_gr.py

```diff
@@ -278,11 +278,15 @@ def test_tp_gr_ranking_forward_backward_update(
         tp_ranking_gr, debug_ranking_gr, debug_ranking_gr_fp32
     )
     for i, batch in enumerate(history_batches):
-        _, (losses, logits, _, _) = debug_pipeline.progress(debug_pipeline_batches)
-        _, (losses_fp32, logits_fp32, _, _) = debug_pipeline_fp32.progress(
+        _, _, (losses, logits, _, _) = debug_pipeline.progress(
+            debug_pipeline_batches
+        )
+        _, _, (losses_fp32, logits_fp32, _, _) = debug_pipeline_fp32.progress(
             debug_pipeline_batches_fp32
         )
-        _, (tp_losses, tp_logits, _, _) = tp_pipeline.progress(iter_history_batches)
+        _, _, (tp_losses, tp_logits, _, _) = tp_pipeline.progress(
+            iter_history_batches
+        )
         torch.distributed.barrier(device_ids=[torch.cuda.current_device()])
         compare_tpN_to_debug_weights(
             tp_ranking_gr, debug_ranking_gr, debug_ranking_gr_fp32
```
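When a return-shape change like this lands gradually across many call sites, a small adapter can keep old and new unpacking working side by side during the migration. The helper below is hypothetical (it is not part of this repo or this commit); it simply shows one way to tolerate both tuple shapes.

```python
def unpack_progress(result):
    """Hypothetical migration helper, not part of the repo.

    Accepts either the new 3-tuple (loss, global_tokens, extras) or the
    legacy 2-tuple (loss, extras) returned by progress(), and normalizes
    both to a 3-tuple so call sites can be updated one at a time.
    """
    if len(result) == 3:
        loss, global_tokens, extras = result
    else:
        # Legacy shape: no global token count was reported.
        loss, extras = result
        global_tokens = None
    return loss, global_tokens, extras


# Works for both shapes:
print(unpack_progress((0.5, 1024, ("extras",))))  # → (0.5, 1024, ('extras',))
print(unpack_progress((0.5, ("extras",))))        # → (0.5, None, ('extras',))
```

This commit instead updates all three call sites at once, which is simpler for a test file and avoids carrying compatibility code.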
