@@ -113,32 +113,32 @@ def training_log(
113113 if is_last_rank ():
114114
115115 for key in loss_dict :
116- iter_dict [f"{ name } / { key } " ] = loss_dict [key ]
117- consumed_train_samples_dict [f" { name } /" + key + " vs samples" ] = loss_dict [
116+ iter_dict [f"{ key } " ] = loss_dict [key ]
117+ consumed_train_samples_dict [key + " vs samples" ] = loss_dict [
118118 key
119119 ]
120120
121121 if grad_norm is not None :
122- iter_dict [f" { name } /" + "grad_norm" ] = grad_norm
123- consumed_train_samples_dict [f" { name } /" + "grad-norm vs samples" ] = grad_norm
122+ iter_dict ["grad_norm" ] = grad_norm
123+ consumed_train_samples_dict ["grad-norm vs samples" ] = grad_norm
124124
125125 if more_grad_norm is not None :
126126 for k in more_grad_norm :
127- iter_dict [f"{ name } / { k } " + " grad_norm" ] = more_grad_norm [k ]
128- consumed_train_samples_dict [f"{ name } / { k } " + " grad-norm vs samples" ] = (
127+ iter_dict [f"{ k } " + " grad_norm" ] = more_grad_norm [k ]
128+ consumed_train_samples_dict [f"{ k } " + " grad-norm vs samples" ] = (
129129 more_grad_norm [k ]
130130 )
131131
132132 if params_norm is not None :
133- iter_dict [f" { name } /" + "params-norm" ] = params_norm
134- consumed_train_samples_dict [f" { name } /" + "params-norm vs samples" ] = (
133+ iter_dict ["params-norm" ] = params_norm
134+ consumed_train_samples_dict ["params-norm vs samples" ] = (
135135 params_norm
136136 )
137137
138138 elapsed_time = 0
139139 elapsed_time_per_iteration = elapsed_time / total_iterations
140140 if args .log_timers_to_tensorboard :
141- iter_dict [f" { name } /" + "iteration-time" ] = elapsed_time_per_iteration
141+ iter_dict ["iteration-time" ] = elapsed_time_per_iteration
142142
143143 log_string = " iteration {:8d}/infinity |" .format (iteration )
144144 log_string += " consumed samples: {:12d} |" .format (args .consumed_train_samples )
@@ -560,9 +560,11 @@ def forward_step(data_iterator, model, *, is_training: bool=False, is_packing: b
560560 'input_ids' : inputs ["all_tokens" ],
561561 'position_ids' : inputs ["all_token_position_ids" ],
562562 'labels' : inputs ["labels" ] if not is_training else None ,
563- 'packed_seq_params' : inputs ['packed_seq_params' ] if is_packing else None
564563 }
565564
565+ if is_packing :
566+ kwargs .update ({'packed_seq_params' : inputs ['packed_seq_params' ]})
567+
566568 if 'pixel_values' in inputs :
567569 kwargs .update ({
568570 'vision_data' : inputs ["pixel_values" ],
0 commit comments