@@ -57,7 +57,8 @@ def __init__(self, variable_initial_values, run_single_process=False):
         )
         self.run_single_process = run_single_process

-    def forward(self, input1, input2):
+    def forward(self, input1, input2, extra_input1=None, extra_input2=None):
+        # extra_input1 and extra_input2 are only used to test non-tensor inputs in shard_dataloader
         x = input1 + input2
         # x: [bs, seq_len, hidden]
         # forward on mesh0
@@ -101,7 +102,7 @@ def __len__(self):
         return self.num_samples


-def create_dataloader():
+def create_dataloader(collate_fn=None):
     dataset = RandomDataset(SEQ_LEN, HIDDEN_SIZE)
     sampler = BatchSampler(
         dataset,
@@ -110,6 +111,7 @@ def create_dataloader():
     dataloader = DataLoader(
         dataset,
         batch_sampler=sampler,
+        collate_fn=collate_fn,
     )
     return dataloader

@@ -204,8 +206,48 @@ def test_basic(self):
             loss.numpy(), self.single_process_loss, rtol=1e-06, verbose=True
         )

+    def test_non_tensor_input(self):
+        model = MlpModel(variable_initial_values=self.variable_initial_values)
+        opt = paddle.optimizer.AdamW(
+            learning_rate=0.001, parameters=model.parameters()
+        )
+
+        def custom_collate_fn(batch):
+            collated_batch = {
+                "inputs": [
+                    paddle.to_tensor([item["inputs"][0] for item in batch]),
+                    paddle.to_tensor([item["inputs"][1] for item in batch]),
+                    12.0,
+                ],
+                "extra_input": 12,
+                "label": paddle.to_tensor([item["label"] for item in batch]),
+            }
+            return collated_batch
+
+        self.dataloader = create_dataloader(custom_collate_fn)
+
+        dist_dataloader = dist.shard_dataloader(
+            dataloader=self.dataloader,
+            meshes=[mesh0, mesh0, mesh1],
+            shard_dims="dp",
+            input_keys=["inputs", "extra_input", "label"],
+        )
+
+        dist_opt = dist.shard_optimizer(opt)
+        for step, data in enumerate(dist_dataloader()):
+            input1, input2, extra_input1 = data["inputs"]
+            extra_input2 = data["extra_input"]
+            logits = model(input1, input2, extra_input1, extra_input2)
+            label = data["label"]
+            loss = loss_fn(logits, label)
+            loss.backward()
+            dist_opt.step()
+            dist_opt.clear_grad()
+
     def run_test_case(self):
         self.test_basic()
+        if not self._run_static:
+            self.test_non_tensor_input()

210252
 if __name__ == '__main__':
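For reference, a minimal standalone sketch of the collate behaviour exercised by test_non_tensor_input, assuming only that paddle is installed and using two hand-written samples in place of RandomDataset items (the shapes below are illustrative, not taken from the test): the tensor fields are stacked into batched tensors, while the non-tensor entries (the float 12.0 inside "inputs" and the int 12 under "extra_input") stay plain Python values and are passed through to shard_dataloader unchanged.

import paddle

def custom_collate_fn(batch):
    # Same structure as the collate_fn in the test: two stacked tensors plus
    # a plain float inside "inputs", and a plain int under "extra_input".
    return {
        "inputs": [
            paddle.to_tensor([item["inputs"][0] for item in batch]),
            paddle.to_tensor([item["inputs"][1] for item in batch]),
            12.0,  # non-tensor entry, forwarded as-is
        ],
        "extra_input": 12,  # non-tensor entry, forwarded as-is
        "label": paddle.to_tensor([item["label"] for item in batch]),
    }

samples = [
    {"inputs": [[0.1, 0.2], [0.3, 0.4]], "label": [0]},
    {"inputs": [[0.5, 0.6], [0.7, 0.8]], "label": [1]},
]
out = custom_collate_fn(samples)
print(out["inputs"][0].shape)                # [2, 2] -> batched tensor
print(out["inputs"][2], out["extra_input"])  # 12.0 12 -> plain Python values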