Skip to content

Commit 66d2b3a

Browse files
authored
fix(sft_trainer): Fix global_tokens and total_tokens metrics always showing 0.0 (#4854)
1 parent 56b7f63 commit 66d2b3a

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

verl/trainer/sft_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -313,9 +313,9 @@ def fit(self):
313313
data = tu.get_tensordict(tensor_dict=data, non_tensor_dict=meta_info)
314314
batch_seqlens = self._get_batch_seqlens(data=data)
315315
# this is necessary. Otherwise, it is interpreted as NonTensorStack
316-
batch_seqlens = NonTensorData(batch_seqlens)
316+
batch_seqlens_ntd = NonTensorData(batch_seqlens)
317317

318-
tu.assign_non_tensor(data, update_lr_scheduler=True, global_token_num=batch_seqlens)
318+
tu.assign_non_tensor(data, update_lr_scheduler=True, global_token_num=batch_seqlens_ntd)
319319

320320
# start profile in SPMD mode
321321
if global_step == self.start_profile_step:

verl/trainer/sft_trainer_ray.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -291,11 +291,11 @@ def fit(self):
291291
global_step += 1
292292
# construct tensordict
293293
data = tu.get_tensordict(tensor_dict=data, non_tensor_dict=meta_info)
294-
batch_seqlens = self._get_batch_seqlens(data=data)
294+
batch_seqlens = self._get_batch_seqlens(data=data).tolist()
295295
# this is necessary. Otherwise, it is interpreted as NonTensorStack
296-
batch_seqlens = NonTensorData(batch_seqlens.tolist())
296+
batch_seqlens_ntd = NonTensorData(batch_seqlens)
297297

298-
tu.assign_non_tensor(data, update_lr_scheduler=True, global_token_num=batch_seqlens)
298+
tu.assign_non_tensor(data, update_lr_scheduler=True, global_token_num=batch_seqlens_ntd)
299299

300300
# start profile in SPMD mode
301301
if global_step == self.start_profile_step:

0 commit comments

Comments
 (0)