Commit 62c1317

[Auto-Parallel] fix shard_dataloader with non-tensor input (#75252) (#75906)
1 parent 100cd2c commit 62c1317

File tree

2 files changed: +52 −7 lines

python/paddle/distributed/auto_parallel/api.py
Lines changed: 8 additions & 5 deletions

@@ -3954,6 +3954,8 @@ def __len__(self):
         return len(self._dataloader)

     def __iter__(self):
+        # Reset iterator state to allow restarting iteration
+        self.iter = None
         return self

     def _get_mesh_and_placement(self, index):
@@ -4007,7 +4009,9 @@ def _dtensors_from_list_input(
     ):
         dist_data = []
         for j in range(len(list_tensors)):
-            if dense_tensor_idx is not None and j in dense_tensor_idx:
+            if (
+                dense_tensor_idx is not None and j in dense_tensor_idx
+            ) or not isinstance(list_tensors[j], paddle.Tensor):
                 dist_data.append(list_tensors[j])
             else:
                 dist_data.append(
@@ -4095,9 +4099,7 @@ def _get_batch(self, batch_data):
                         batch_data[key], mesh, placements
                     )
                 else:
-                    raise ValueError(
-                        f"Unsupported input_data type {type(input_data)}"
-                    )
+                    dist_batch_data[key] = input_data
             return dist_batch_data
         elif isinstance(batch_data, paddle.Tensor):
             mesh, placements = self._get_mesh_and_placement(0)
@@ -4112,7 +4114,8 @@ def __next__(self):
         return self._get_batch(batch_data)

     def __call__(self):
-        self.iter = self._dataloader.__iter__()
+        # Reset iterator state to allow restarting iteration
+        self.iter = None
         return self


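The functional change in _dtensors_from_list_input and _get_batch is that non-tensor entries (plain Python numbers, strings, and so on) are now passed through unchanged instead of being rejected with a ValueError. A minimal standalone sketch of that pass-through rule, where the function name shard_list_input and the to_dist callback are illustrative stand-ins (not part of the Paddle API); only dense_tensor_idx comes from the patched code:

import paddle

def shard_list_input(list_tensors, to_dist, dense_tensor_idx=None):
    """Illustrative sketch mirroring the patched pass-through rule."""
    dist_data = []
    for j, item in enumerate(list_tensors):
        keep_dense = dense_tensor_idx is not None and j in dense_tensor_idx
        if keep_dense or not isinstance(item, paddle.Tensor):
            # Dense-listed tensors and non-tensor entries (e.g. 12.0) are kept as-is
            dist_data.append(item)
        else:
            # Only real tensors are converted to dist tensors
            dist_data.append(to_dist(item))
    return dist_data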
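The other change, in __iter__ and __call__, resets self.iter to None so that each new pass over the data starts from a fresh underlying iterator rather than an exhausted one. A small self-contained sketch of that lazy-reset pattern; the ResettableLoader class is illustrative, and it assumes (as the existing code implies) that __next__ recreates the iterator on demand:

class ResettableLoader:
    """Illustrative stand-in for the ShardDataloader iteration logic."""

    def __init__(self, dataloader):
        self._dataloader = dataloader
        self.iter = None

    def __iter__(self):
        # Reset so a new for-loop starts at the beginning of the data
        self.iter = None
        return self

    def __call__(self):
        self.iter = None
        return self

    def __next__(self):
        # Lazily (re)create the underlying iterator after a reset
        if self.iter is None:
            self.iter = iter(self._dataloader)
        return next(self.iter)

# Usage: the same wrapper can be iterated for multiple epochs.
loader = ResettableLoader([1, 2, 3])
assert list(loader()) == [1, 2, 3]
assert list(loader()) == [1, 2, 3]  # second pass works because state was reset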
test/auto_parallel/hybrid_strategy/semi_auto_parallel_multi_inputs.py
Lines changed: 44 additions & 2 deletions

@@ -57,7 +57,8 @@ def __init__(self, variable_initial_values, run_single_process=False):
         )
         self.run_single_process = run_single_process

-    def forward(self, input1, input2):
+    def forward(self, input1, input2, extra_input1=None, extra_input2=None):
+        # extra_input1 and extra_input2 are only used to test non-tensor input in shard_dataloader
         x = input1 + input2
         # x: [bs, seq_len, hidden]
         # forward on mesh0
@@ -101,7 +102,7 @@ def __len__(self):
         return self.num_samples


-def create_dataloader():
+def create_dataloader(collate_fn=None):
     dataset = RandomDataset(SEQ_LEN, HIDDEN_SIZE)
     sampler = BatchSampler(
         dataset,
@@ -110,6 +111,7 @@ def create_dataloader():
     dataloader = DataLoader(
         dataset,
         batch_sampler=sampler,
+        collate_fn=collate_fn,
     )
     return dataloader

@@ -204,8 +206,48 @@ def test_basic(self):
             loss.numpy(), self.single_process_loss, rtol=1e-06, verbose=True
         )

+    def test_non_tensor_input(self):
+        model = MlpModel(variable_initial_values=self.variable_initial_values)
+        opt = paddle.optimizer.AdamW(
+            learning_rate=0.001, parameters=model.parameters()
+        )
+
+        def custom_collate_fn(batch):
+            collated_batch = {
+                "inputs": [
+                    paddle.to_tensor([item["inputs"][0] for item in batch]),
+                    paddle.to_tensor([item["inputs"][1] for item in batch]),
+                    12.0,
+                ],
+                "extra_input": 12,
+                "label": paddle.to_tensor([item["label"] for item in batch]),
+            }
+            return collated_batch
+
+        self.dataloader = create_dataloader(custom_collate_fn)
+
+        dist_dataloader = dist.shard_dataloader(
+            dataloader=self.dataloader,
+            meshes=[mesh0, mesh0, mesh1],
+            shard_dims="dp",
+            input_keys=["inputs", "extra_input", "label"],
+        )
+
+        dist_opt = dist.shard_optimizer(opt)
+        for step, data in enumerate(dist_dataloader()):
+            input1, input2, extra_input1 = data["inputs"]
+            extra_input2 = data["extra_input"]
+            logits = model(input1, input2, extra_input1, extra_input2)
+            label = data["label"]
+            loss = loss_fn(logits, label)
+            loss.backward()
+            dist_opt.step()
+            dist_opt.clear_grad()
+
     def run_test_case(self):
         self.test_basic()
+        if not self._run_static:
+            self.test_non_tensor_input()


 if __name__ == '__main__':
