File tree Expand file tree Collapse file tree 2 files changed +5
-5
lines changed
Expand file tree Collapse file tree 2 files changed +5
-5
lines changed Original file line number Diff line number Diff line change @@ -313,9 +313,9 @@ def fit(self):
313313 data = tu .get_tensordict (tensor_dict = data , non_tensor_dict = meta_info )
314314 batch_seqlens = self ._get_batch_seqlens (data = data )
315315 # this is necessary. Otherwise, it is interpreted as NonTensorStack
316- batch_seqlens = NonTensorData (batch_seqlens )
316+ batch_seqlens_ntd = NonTensorData (batch_seqlens )
317317
318- tu .assign_non_tensor (data , update_lr_scheduler = True , global_token_num = batch_seqlens )
318+ tu .assign_non_tensor (data , update_lr_scheduler = True , global_token_num = batch_seqlens_ntd )
319319
320320 # start profile in SPMD mode
321321 if global_step == self .start_profile_step :
Original file line number Diff line number Diff line change @@ -291,11 +291,11 @@ def fit(self):
291291 global_step += 1
292292 # construct tensordict
293293 data = tu .get_tensordict (tensor_dict = data , non_tensor_dict = meta_info )
294- batch_seqlens = self ._get_batch_seqlens (data = data )
294+ batch_seqlens = self ._get_batch_seqlens (data = data ). tolist ()
295295 # this is necessary. Otherwise, it is interpreted as NonTensorStack
296- batch_seqlens = NonTensorData (batch_seqlens . tolist () )
296+ batch_seqlens_ntd = NonTensorData (batch_seqlens )
297297
298- tu .assign_non_tensor (data , update_lr_scheduler = True , global_token_num = batch_seqlens )
298+ tu .assign_non_tensor (data , update_lr_scheduler = True , global_token_num = batch_seqlens_ntd )
299299
300300 # start profile in SPMD mode
301301 if global_step == self .start_profile_step :
You can’t perform that action at this time.
0 commit comments