Commit 5eb2066

Update

[ghstack-poisoned]

1 parent ad2ce62 commit 5eb2066

File tree: 2 files changed (+161, -137 lines)

test/float8/test_dtensor.py
Lines changed: 23 additions & 137 deletions
@@ -10,7 +10,6 @@
 TODO(future): make this run in CI
 """

-import copy
 import os

 import pytest
@@ -23,12 +22,6 @@

 from torch.distributed._tensor import DTensor, Replicate, Shard, distribute_tensor
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
-from torch.distributed.tensor.parallel import (
-    ColwiseParallel,
-    PrepareModuleInput,
-    RowwiseParallel,
-    parallelize_module,
-)
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     ModelArgs,
     Transformer,
@@ -50,14 +43,11 @@
     LinearMMConfig,
     hp_tensor_and_scale_to_float8,
 )
-from torchao.float8.float8_tensor_parallel import (
-    Float8ColwiseParallel,
-    Float8RowwiseParallel,
-    PrepareFloat8ModuleInput,
-)
 from torchao.float8.float8_utils import tensor_to_scale
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
-from torchao.testing.training.dtensor_utils import ToyModel
+from torchao.testing.training.dtensor_utils import (
+    _test_lowp_mlp_tensor_parallelism_base,
+)

 torch.set_float32_matmul_precision("high")

@@ -193,140 +183,36 @@ def _test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
     loss.backward()


-def _test_fp8_mlp_tensor_parallelism_base(
-    mesh: DeviceMesh, size=16, compile: bool = False, rowwise: bool = False
-):
-    device = mesh.device_type
-
-    if rowwise:
-        config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE)
-        # hack around config being frozen
-        # TODO(future PR): we should make this nicer at the config level
-        object.__setattr__(config, "emulate", True)
-    else:
-        config = Float8LinearConfig(emulate=True)
-
-    toy_model = ToyModel().to(device)
-    toy_model_fp8 = convert_to_float8_training(toy_model, config=config)
-
-    tp_model = copy.deepcopy(toy_model)
-    tp_model = convert_to_float8_training(tp_model, config=config)
-    sp_model = copy.deepcopy(toy_model)
-    sp_model = convert_to_float8_training(sp_model, config=config)
-
-    # For tensorwise scaling, enable float8 all_gather.
-    # For rowwise scaling, keep high precision all_gather. Motivation for
-    # not doing float8 all-gather for rowwise: tensors need to be scaled both ways,
-    # so for float8 all-gather we'd need to send two float8 copies per tensor,
-    # which is similar # bytes over the wire than just doing bfloat16 all-gather.
-    if rowwise:
-        colwise_parallel_cls = ColwiseParallel
-        rowwise_parallel_cls = RowwiseParallel
-        prepare_input_cls = PrepareModuleInput
-    else:
-        colwise_parallel_cls = Float8ColwiseParallel
-        rowwise_parallel_cls = Float8RowwiseParallel
-        prepare_input_cls = PrepareFloat8ModuleInput
-
-    # vanilla TP
-    tp_model = parallelize_module(
-        tp_model,
-        mesh,
-        {
-            "ffn.w1": colwise_parallel_cls(),
-            "ffn.w2": colwise_parallel_cls(),
-            "ffn.out_proj": rowwise_parallel_cls(),
-        },
+def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
+    tensorwise_config = Float8LinearConfig(emulate=True)
+    _test_lowp_mlp_tensor_parallelism_base(
+        mesh, tensorwise_config, size, compile=False, allgather_in_lowp=True
     )

-    # "sequence parallel" mlp computation
-    sp_model = parallelize_module(
-        sp_model,
-        mesh,
-        {
-            "ffn": prepare_input_cls(
-                input_layouts=Shard(1), desired_input_layouts=Replicate()
-            ),
-            "ffn.w1": colwise_parallel_cls(),
-            "ffn.w2": colwise_parallel_cls(),
-            "ffn.out_proj": rowwise_parallel_cls(
-                output_layouts=Shard(1), use_local_output=False
-            ),
-        },
+    rowwise_config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE)
+    # hack around config being frozen
+    # TODO(future PR): we should make this nicer at the config level
+    object.__setattr__(rowwise_config, "emulate", True)
+    _test_lowp_mlp_tensor_parallelism_base(
+        mesh, rowwise_config, size, compile=False, allgather_in_lowp=False
     )

-    # prepare_input_cls with specific submodule fqn
-    sp_model2 = copy.deepcopy(toy_model)
-    sp_model2 = convert_to_float8_training(sp_model2, config=config)

-    if rowwise:
-        prepare_input = prepare_input_cls(
-            input_layouts=Shard(1),
-            desired_input_layouts=Replicate(),
-        )
-    else:
-        prepare_input = prepare_input_cls(
-            input_layouts=Shard(1),
-            desired_input_layouts=Replicate(),
-            fwd_config_submodule_fqn="w2",
-        )
-
-    sp_model2 = parallelize_module(
-        sp_model2,
-        mesh,
-        {
-            "ffn": prepare_input,
-            "ffn.w1": colwise_parallel_cls(),
-            "ffn.w2": colwise_parallel_cls(),
-            "ffn.out_proj": rowwise_parallel_cls(
-                output_layouts=Shard(1), use_local_output=False
-            ),
-        },
-    )
-
-    if compile:
-        tp_model = torch.compile(tp_model)
-        sp_model = torch.compile(sp_model)
-        sp_model2 = torch.compile(sp_model2)
-
-    x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
-    x_fp32_tp_input = x_fp32.clone()
-    x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)])
-
-    tp_out = tp_model(x_fp32_tp_input)
-    tp_out.sum().backward()
-    sp_out = sp_model(x_fp32_sp_input)
-    sp_out.sum().backward()
-    global_out = toy_model_fp8(x_fp32)
-    global_out.sum().backward()
-    torch.testing.assert_close(tp_out, global_out)
-    torch.testing.assert_close(sp_out.full_tensor(), global_out)
-    torch.testing.assert_close(tp_model.ffn.w1.weight.grad, sp_model.ffn.w1.weight.grad)
-    torch.testing.assert_close(
-        tp_model.ffn.out_proj.weight.grad, sp_model.ffn.out_proj.weight.grad
+def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
+    tensorwise_config = Float8LinearConfig(emulate=True)
+    _test_lowp_mlp_tensor_parallelism_base(
+        mesh, tensorwise_config, size, compile=True, allgather_in_lowp=True
     )

-    sp_out2 = sp_model2(x_fp32_sp_input)
-    sp_out2.sum().backward()
-    torch.testing.assert_close(sp_out2.full_tensor(), global_out)
-    torch.testing.assert_close(
-        tp_model.ffn.w1.weight.grad, sp_model2.ffn.w1.weight.grad
-    )
-    torch.testing.assert_close(
-        tp_model.ffn.out_proj.weight.grad, sp_model2.ffn.out_proj.weight.grad
+    rowwise_config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE)
+    # hack around config being frozen
+    # TODO(future PR): we should make this nicer at the config level
+    object.__setattr__(rowwise_config, "emulate", True)
+    _test_lowp_mlp_tensor_parallelism_base(
+        mesh, rowwise_config, size, compile=True, allgather_in_lowp=False
     )


-def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
-    _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False, rowwise=False)
-    _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False, rowwise=True)
-
-
-def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
-    _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True, rowwise=False)
-    _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True, rowwise=True)
-
-
 def _test_distribute_fsdp_tensor_subclass(tp_mesh: DeviceMesh):
     torch.manual_seed(42)
     model = Transformer(ModelArgs(dropout_p=0.0, weight_tying=False)).cuda()
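
A note on the `object.__setattr__(rowwise_config, "emulate", True)` hack in the two wrappers above: per the in-code comment, the config object is frozen, so plain attribute assignment is rejected and `object.__setattr__` is used to bypass the guard. A minimal sketch of the mechanism with a stand-in frozen dataclass (illustrative only, not torchao code):

from dataclasses import dataclass


@dataclass(frozen=True)
class Config:
    # stand-in for a frozen config such as Float8LinearConfig
    emulate: bool = False


cfg = Config()
# cfg.emulate = True  # would raise dataclasses.FrozenInstanceError
object.__setattr__(cfg, "emulate", True)  # sidesteps the frozen __setattr__ guard
assert cfg.emulate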

torchao/testing/training/dtensor_utils.py
Lines changed: 138 additions & 0 deletions
@@ -3,9 +3,27 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
+import copy

+import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.distributed._tensor import Replicate, Shard, distribute_tensor
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    PrepareModuleInput,
+    RowwiseParallel,
+    parallelize_module,
+)
+
+from torchao.float8 import Float8LinearConfig
+from torchao.float8.float8_linear_utils import convert_to_float8_training
+from torchao.float8.float8_tensor_parallel import (
+    Float8ColwiseParallel,
+    Float8RowwiseParallel,
+    PrepareFloat8ModuleInput,
+)


 class FeedForward(nn.Module):
@@ -28,3 +46,123 @@ def __init__(self):

     def forward(self, x):
         return self.ffn(x)
+
+
+def _test_lowp_mlp_tensor_parallelism_base(
+    mesh: DeviceMesh,
+    config: Float8LinearConfig,
+    size=16,
+    compile: bool = False,
+    allgather_in_lowp: bool = False,
+):
+    device = mesh.device_type
+
+    toy_model = ToyModel().to(device)
+    toy_model_fp8 = convert_to_float8_training(toy_model, config=config)
+
+    tp_model = copy.deepcopy(toy_model)
+    tp_model = convert_to_float8_training(tp_model, config=config)
+    sp_model = copy.deepcopy(toy_model)
+    sp_model = convert_to_float8_training(sp_model, config=config)
+
+    # For tensorwise scaling, enable float8 all_gather.
+    # For rowwise scaling, keep the all_gather in high precision. Motivation:
+    # rowwise tensors need to be scaled along both dimensions, so a float8
+    # all-gather would send two float8 copies per tensor, which is about the
+    # same number of bytes over the wire as a single bfloat16 all-gather.
+    if not allgather_in_lowp:
+        colwise_parallel_cls = ColwiseParallel
+        rowwise_parallel_cls = RowwiseParallel
+        prepare_input_cls = PrepareModuleInput
+    else:
+        colwise_parallel_cls = Float8ColwiseParallel
+        rowwise_parallel_cls = Float8RowwiseParallel
+        prepare_input_cls = PrepareFloat8ModuleInput
+
+    # vanilla TP
+    tp_model = parallelize_module(
+        tp_model,
+        mesh,
+        {
+            "ffn.w1": colwise_parallel_cls(),
+            "ffn.w2": colwise_parallel_cls(),
+            "ffn.out_proj": rowwise_parallel_cls(),
+        },
+    )
+
+    # "sequence parallel" mlp computation
+    sp_model = parallelize_module(
+        sp_model,
+        mesh,
+        {
+            "ffn": prepare_input_cls(
+                input_layouts=Shard(1), desired_input_layouts=Replicate()
+            ),
+            "ffn.w1": colwise_parallel_cls(),
+            "ffn.w2": colwise_parallel_cls(),
+            "ffn.out_proj": rowwise_parallel_cls(
+                output_layouts=Shard(1), use_local_output=False
+            ),
+        },
+    )
+
+    # prepare_input_cls with specific submodule fqn
+    sp_model2 = copy.deepcopy(toy_model)
+    sp_model2 = convert_to_float8_training(sp_model2, config=config)
+
+    if not allgather_in_lowp:
+        prepare_input = prepare_input_cls(
+            input_layouts=Shard(1),
+            desired_input_layouts=Replicate(),
+        )
+    else:
+        prepare_input = prepare_input_cls(
+            input_layouts=Shard(1),
+            desired_input_layouts=Replicate(),
+            fwd_config_submodule_fqn="w2",
+        )
+
+    sp_model2 = parallelize_module(
+        sp_model2,
+        mesh,
+        {
+            "ffn": prepare_input,
+            "ffn.w1": colwise_parallel_cls(),
+            "ffn.w2": colwise_parallel_cls(),
+            "ffn.out_proj": rowwise_parallel_cls(
+                output_layouts=Shard(1), use_local_output=False
+            ),
+        },
+    )
+
+    if compile:
+        tp_model = torch.compile(tp_model)
+        sp_model = torch.compile(sp_model)
+        sp_model2 = torch.compile(sp_model2)
+
+    x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
+    x_fp32_tp_input = x_fp32.clone()
+    x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)])
+
+    tp_out = tp_model(x_fp32_tp_input)
+    tp_out.sum().backward()
+    sp_out = sp_model(x_fp32_sp_input)
+    sp_out.sum().backward()
+    global_out = toy_model_fp8(x_fp32)
+    global_out.sum().backward()
+    torch.testing.assert_close(tp_out, global_out)
+    torch.testing.assert_close(sp_out.full_tensor(), global_out)
+    torch.testing.assert_close(tp_model.ffn.w1.weight.grad, sp_model.ffn.w1.weight.grad)
+    torch.testing.assert_close(
+        tp_model.ffn.out_proj.weight.grad, sp_model.ffn.out_proj.weight.grad
+    )
+
+    sp_out2 = sp_model2(x_fp32_sp_input)
+    sp_out2.sum().backward()
+    torch.testing.assert_close(sp_out2.full_tensor(), global_out)
+    torch.testing.assert_close(
+        tp_model.ffn.w1.weight.grad, sp_model2.ffn.w1.weight.grad
+    )
+    torch.testing.assert_close(
+        tp_model.ffn.out_proj.weight.grad, sp_model2.ffn.out_proj.weight.grad
+    )
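
To make the all-gather comment in `_test_lowp_mlp_tensor_parallelism_base` concrete, here is the back-of-envelope byte count for all-gathering one weight tensor of N elements (illustrative arithmetic only; the per-tensor scales add a negligible number of extra bytes):

# Bytes over the wire to all-gather one weight tensor of N elements.
N = 4096 * 4096

bf16_bytes = N * 2         # bfloat16: 2 bytes per element
fp8_rowwise_bytes = 2 * N  # rowwise scaling: two 1-byte float8 copies,
                           # one scaled along each dimension
fp8_tensorwise_bytes = N   # tensorwise scaling: a single float8 copy

assert fp8_rowwise_bytes == bf16_bytes      # no savings -> keep bf16 all-gather
assert fp8_tensorwise_bytes < bf16_bytes    # 2x savings -> float8 all-gather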

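Since the docstring kept at the top of test_dtensor.py still reads "TODO(future): make this run in CI", these tests are launched manually. A minimal sketch of such a launcher, assuming the usual torchrun environment variables (`WORLD_SIZE` etc.); the `setup_distributed` helper name and seed value are illustrative, not necessarily the file's actual main block:

# Hypothetical launcher sketch: torchrun --nproc_per_node=2 test_dtensor.py
import os

import torch
from torch.distributed.device_mesh import init_device_mesh


def setup_distributed():
    world_size = int(os.environ["WORLD_SIZE"])
    # one-dimensional mesh over all ranks; also initializes the process group
    device_mesh = init_device_mesh("cuda", (world_size,))
    torch.manual_seed(1)
    return device_mesh


if __name__ == "__main__":
    mesh = setup_distributed()
    _test_fp8_mlp_tensor_parallelism_eager(mesh)
    _test_fp8_mlp_tensor_parallelism_compile(mesh)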