@@ -10,7 +10,6 @@
 TODO(future): make this run in CI
 """
 
-import copy
 import os
 
 import pytest
@@ -23,12 +22,6 @@
 
 from torch.distributed._tensor import DTensor, Replicate, Shard, distribute_tensor
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
-from torch.distributed.tensor.parallel import (
-    ColwiseParallel,
-    PrepareModuleInput,
-    RowwiseParallel,
-    parallelize_module,
-)
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     ModelArgs,
     Transformer,
@@ -50,14 +43,11 @@
     LinearMMConfig,
     hp_tensor_and_scale_to_float8,
 )
-from torchao.float8.float8_tensor_parallel import (
-    Float8ColwiseParallel,
-    Float8RowwiseParallel,
-    PrepareFloat8ModuleInput,
-)
 from torchao.float8.float8_utils import tensor_to_scale
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
-from torchao.testing.training.dtensor_utils import ToyModel
+from torchao.testing.training.dtensor_utils import (
+    _test_lowp_mlp_tensor_parallelism_base,
+)
 
 torch.set_float32_matmul_precision("high")
 
@@ -193,140 +183,36 @@ def _test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
     loss.backward()
 
 
-def _test_fp8_mlp_tensor_parallelism_base(
-    mesh: DeviceMesh, size=16, compile: bool = False, rowwise: bool = False
-):
-    device = mesh.device_type
-
-    if rowwise:
-        config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE)
-        # hack around config being frozen
-        # TODO(future PR): we should make this nicer at the config level
-        object.__setattr__(config, "emulate", True)
-    else:
-        config = Float8LinearConfig(emulate=True)
-
-    toy_model = ToyModel().to(device)
-    toy_model_fp8 = convert_to_float8_training(toy_model, config=config)
-
-    tp_model = copy.deepcopy(toy_model)
-    tp_model = convert_to_float8_training(tp_model, config=config)
-    sp_model = copy.deepcopy(toy_model)
-    sp_model = convert_to_float8_training(sp_model, config=config)
-
-    # For tensorwise scaling, enable float8 all_gather.
-    # For rowwise scaling, keep high precision all_gather. Motivation for
-    # not doing float8 all-gather for rowwise: tensors need to be scaled both ways,
-    # so for float8 all-gather we'd need to send two float8 copies per tensor,
-    # which is similar # bytes over the wire than just doing bfloat16 all-gather.
-    if rowwise:
-        colwise_parallel_cls = ColwiseParallel
-        rowwise_parallel_cls = RowwiseParallel
-        prepare_input_cls = PrepareModuleInput
-    else:
-        colwise_parallel_cls = Float8ColwiseParallel
-        rowwise_parallel_cls = Float8RowwiseParallel
-        prepare_input_cls = PrepareFloat8ModuleInput
-
-    # vanilla TP
-    tp_model = parallelize_module(
-        tp_model,
-        mesh,
-        {
-            "ffn.w1": colwise_parallel_cls(),
-            "ffn.w2": colwise_parallel_cls(),
-            "ffn.out_proj": rowwise_parallel_cls(),
-        },
+def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
+    tensorwise_config = Float8LinearConfig(emulate=True)
+    _test_lowp_mlp_tensor_parallelism_base(
+        mesh, tensorwise_config, size, compile=False, allgather_in_lowp=True
     )
 
-    # "sequence parallel" mlp computation
-    sp_model = parallelize_module(
-        sp_model,
-        mesh,
-        {
-            "ffn": prepare_input_cls(
-                input_layouts=Shard(1), desired_input_layouts=Replicate()
-            ),
-            "ffn.w1": colwise_parallel_cls(),
-            "ffn.w2": colwise_parallel_cls(),
-            "ffn.out_proj": rowwise_parallel_cls(
-                output_layouts=Shard(1), use_local_output=False
-            ),
-        },
+    rowwise_config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE)
+    # hack around config being frozen
+    # TODO(future PR): we should make this nicer at the config level
+    object.__setattr__(rowwise_config, "emulate", True)
+    _test_lowp_mlp_tensor_parallelism_base(
+        mesh, rowwise_config, size, compile=False, allgather_in_lowp=False
     )
 
-    # prepare_input_cls with specific submodule fqn
-    sp_model2 = copy.deepcopy(toy_model)
-    sp_model2 = convert_to_float8_training(sp_model2, config=config)
 
-    if rowwise:
-        prepare_input = prepare_input_cls(
-            input_layouts=Shard(1),
-            desired_input_layouts=Replicate(),
-        )
-    else:
-        prepare_input = prepare_input_cls(
-            input_layouts=Shard(1),
-            desired_input_layouts=Replicate(),
-            fwd_config_submodule_fqn="w2",
-        )
-
-    sp_model2 = parallelize_module(
-        sp_model2,
-        mesh,
-        {
-            "ffn": prepare_input,
-            "ffn.w1": colwise_parallel_cls(),
-            "ffn.w2": colwise_parallel_cls(),
-            "ffn.out_proj": rowwise_parallel_cls(
-                output_layouts=Shard(1), use_local_output=False
-            ),
-        },
-    )
-
-    if compile:
-        tp_model = torch.compile(tp_model)
-        sp_model = torch.compile(sp_model)
-        sp_model2 = torch.compile(sp_model2)
-
-    x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
-    x_fp32_tp_input = x_fp32.clone()
-    x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)])
-
-    tp_out = tp_model(x_fp32_tp_input)
-    tp_out.sum().backward()
-    sp_out = sp_model(x_fp32_sp_input)
-    sp_out.sum().backward()
-    global_out = toy_model_fp8(x_fp32)
-    global_out.sum().backward()
-    torch.testing.assert_close(tp_out, global_out)
-    torch.testing.assert_close(sp_out.full_tensor(), global_out)
-    torch.testing.assert_close(tp_model.ffn.w1.weight.grad, sp_model.ffn.w1.weight.grad)
-    torch.testing.assert_close(
-        tp_model.ffn.out_proj.weight.grad, sp_model.ffn.out_proj.weight.grad
+def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
+    tensorwise_config = Float8LinearConfig(emulate=True)
+    _test_lowp_mlp_tensor_parallelism_base(
+        mesh, tensorwise_config, size, compile=True, allgather_in_lowp=True
     )
 
-    sp_out2 = sp_model2(x_fp32_sp_input)
-    sp_out2.sum().backward()
-    torch.testing.assert_close(sp_out2.full_tensor(), global_out)
-    torch.testing.assert_close(
-        tp_model.ffn.w1.weight.grad, sp_model2.ffn.w1.weight.grad
-    )
-    torch.testing.assert_close(
-        tp_model.ffn.out_proj.weight.grad, sp_model2.ffn.out_proj.weight.grad
+    rowwise_config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE)
+    # hack around config being frozen
+    # TODO(future PR): we should make this nicer at the config level
+    object.__setattr__(rowwise_config, "emulate", True)
+    _test_lowp_mlp_tensor_parallelism_base(
+        mesh, rowwise_config, size, compile=True, allgather_in_lowp=False
     )
 
 
-def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
-    _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False, rowwise=False)
-    _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False, rowwise=True)
-
-
-def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
-    _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True, rowwise=False)
-    _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True, rowwise=True)
-
-
 def _test_distribute_fsdp_tensor_subclass(tp_mesh: DeviceMesh):
     torch.manual_seed(42)
     model = Transformer(ModelArgs(dropout_p=0.0, weight_tying=False)).cuda()
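Note on the refactor: the deleted `_test_fp8_mlp_tensor_parallelism_base` moves into the shared `torchao.testing.training.dtensor_utils` module, with the `rowwise: bool` flag generalized to `allgather_in_lowp: bool` so one helper can serve other low-precision recipes. The sketch below reconstructs the plain-TP core of that helper from the deleted lines, to show how the new keyword presumably maps onto the DTensor parallel styles. It is a minimal sketch under those assumptions: the sequence-parallel variants are elided, parameter names are taken from the new call sites, and the real implementation in `torchao/testing/training/dtensor_utils.py` is authoritative.

```python
# Hypothetical reconstruction of the shared helper; details may differ from
# the actual torchao/testing/training/dtensor_utils.py implementation.
import copy

import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

from torchao.float8 import Float8LinearConfig, convert_to_float8_training
from torchao.float8.float8_tensor_parallel import (
    Float8ColwiseParallel,
    Float8RowwiseParallel,
)
from torchao.testing.training.dtensor_utils import ToyModel


def _test_lowp_mlp_tensor_parallelism_base(
    mesh: DeviceMesh,
    config: Float8LinearConfig,
    size: int = 16,
    compile: bool = False,
    allgather_in_lowp: bool = False,
):
    device = mesh.device_type

    # single-device reference model and a TP copy, both converted to float8
    toy_model = ToyModel().to(device)
    toy_model_fp8 = convert_to_float8_training(toy_model, config=config)
    tp_model = copy.deepcopy(toy_model)
    tp_model = convert_to_float8_training(tp_model, config=config)

    # allgather_in_lowp=True selects the float8-aware styles (float8
    # all-gather); False keeps the stock styles (high-precision all-gather)
    if allgather_in_lowp:
        colwise_cls, rowwise_cls = Float8ColwiseParallel, Float8RowwiseParallel
    else:
        colwise_cls, rowwise_cls = ColwiseParallel, RowwiseParallel

    tp_model = parallelize_module(
        tp_model,
        mesh,
        {
            "ffn.w1": colwise_cls(),
            "ffn.w2": colwise_cls(),
            "ffn.out_proj": rowwise_cls(),
        },
    )
    if compile:
        tp_model = torch.compile(tp_model)

    # TP and single-device results must match in forward and backward
    x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
    tp_out = tp_model(x_fp32.clone())
    tp_out.sum().backward()
    global_out = toy_model_fp8(x_fp32)
    global_out.sum().backward()
    torch.testing.assert_close(tp_out, global_out)
```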
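The deleted motivation comment is worth a quick sanity check in numbers. With rowwise scales, the forward and backward consume the weight scaled along different dimensions, so a float8 all-gather would have to ship two float8 copies: for an N-element tensor that is 2 × N × 1 byte = 2N bytes, exactly what a single bfloat16 all-gather costs (N × 2 bytes). Since float8 saves no wire bytes there, the rowwise recipe keeps the all-gather in high precision (`allgather_in_lowp=False`), while the tensorwise recipe, which needs only one float8 copy (N bytes instead of 2N), gathers in float8 (`allgather_in_lowp=True`).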
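On the `object.__setattr__(rowwise_config, "emulate", True)` hack repeated in both new functions: per the "config being frozen" comment, `Float8LinearConfig` behaves like a frozen dataclass, so plain attribute assignment raises, while `object.__setattr__` bypasses the generated guard. A self-contained illustration of the pattern, using a hypothetical stand-in class rather than the real config:

```python
from dataclasses import FrozenInstanceError, dataclass


# stand-in for a frozen config such as Float8LinearConfig
@dataclass(frozen=True)
class FakeConfig:
    emulate: bool = False


cfg = FakeConfig()
try:
    cfg.emulate = True  # the dataclass-generated __setattr__ raises
except FrozenInstanceError:
    pass

# object.__setattr__ skips the generated guard and mutates the instance
object.__setattr__(cfg, "emulate", True)
assert cfg.emulate
```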