make dtensor shared test util more generic #2416

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 1 commit into base: gh/vkuzo/88/head
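In short, the shared MLP tensor-parallelism test helper moves to torchao/testing/training/dtensor_utils.py and takes a Float8LinearConfig plus an allgather_in_lowp flag instead of a float8-specific rowwise flag, which is what lets other low-precision recipes reuse the same TP/SP numerics checks. A minimal sketch of the resulting call pattern, using only names that appear in the diff below (the wrapper function name is hypothetical):

```python
# Sketch only -- not part of this PR. `run_tensorwise_tp_test` is a hypothetical
# wrapper; the helper and config names come from the diff below.
from torch.distributed.device_mesh import DeviceMesh

from torchao.float8 import Float8LinearConfig
from torchao.testing.training.dtensor_utils import (
    _test_lowp_mlp_tensor_parallelism_base,
)


def run_tensorwise_tp_test(mesh: DeviceMesh) -> None:
    # The old float8-specific entry point was:
    #   _test_fp8_mlp_tensor_parallelism_base(mesh, size=16, compile=False, rowwise=False)
    # The generic helper takes the config and the all-gather precision explicitly.
    config = Float8LinearConfig(emulate=True)
    _test_lowp_mlp_tensor_parallelism_base(
        mesh, config, size=16, compile=False, allgather_in_lowp=True
    )
```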
160 changes: 23 additions & 137 deletions test/float8/test_dtensor.py
@@ -10,7 +10,6 @@
TODO(future): make this run in CI
"""

import copy
import os

import pytest
@@ -23,12 +22,6 @@

from torch.distributed._tensor import DTensor, Replicate, Shard, distribute_tensor
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.tensor.parallel import (
ColwiseParallel,
PrepareModuleInput,
RowwiseParallel,
parallelize_module,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
ModelArgs,
Transformer,
@@ -50,14 +43,11 @@
LinearMMConfig,
hp_tensor_and_scale_to_float8,
)
from torchao.float8.float8_tensor_parallel import (
Float8ColwiseParallel,
Float8RowwiseParallel,
PrepareFloat8ModuleInput,
)
from torchao.float8.float8_utils import tensor_to_scale
from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
from torchao.testing.training.dtensor_utils import ToyModel
from torchao.testing.training.dtensor_utils import (
_test_lowp_mlp_tensor_parallelism_base,
)

torch.set_float32_matmul_precision("high")

@@ -193,140 +183,36 @@ def _test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
loss.backward()


def _test_fp8_mlp_tensor_parallelism_base(
mesh: DeviceMesh, size=16, compile: bool = False, rowwise: bool = False
):
device = mesh.device_type

if rowwise:
config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE)
# hack around config being frozen
# TODO(future PR): we should make this nicer at the config level
object.__setattr__(config, "emulate", True)
else:
config = Float8LinearConfig(emulate=True)

toy_model = ToyModel().to(device)
toy_model_fp8 = convert_to_float8_training(toy_model, config=config)

tp_model = copy.deepcopy(toy_model)
tp_model = convert_to_float8_training(tp_model, config=config)
sp_model = copy.deepcopy(toy_model)
sp_model = convert_to_float8_training(sp_model, config=config)

# For tensorwise scaling, enable float8 all_gather.
# For rowwise scaling, keep high precision all_gather. Motivation for
# not doing float8 all-gather for rowwise: tensors need to be scaled both ways,
# so for float8 all-gather we'd need to send two float8 copies per tensor,
# which is a similar number of bytes over the wire as just doing bfloat16 all-gather.
if rowwise:
colwise_parallel_cls = ColwiseParallel
rowwise_parallel_cls = RowwiseParallel
prepare_input_cls = PrepareModuleInput
else:
colwise_parallel_cls = Float8ColwiseParallel
rowwise_parallel_cls = Float8RowwiseParallel
prepare_input_cls = PrepareFloat8ModuleInput

# vanilla TP
tp_model = parallelize_module(
tp_model,
mesh,
{
"ffn.w1": colwise_parallel_cls(),
"ffn.w2": colwise_parallel_cls(),
"ffn.out_proj": rowwise_parallel_cls(),
},
def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
tensorwise_config = Float8LinearConfig(emulate=True)
_test_lowp_mlp_tensor_parallelism_base(
mesh, tensorwise_config, size, compile=False, allgather_in_lowp=True
)

# "sequence parallel" mlp computation
sp_model = parallelize_module(
sp_model,
mesh,
{
"ffn": prepare_input_cls(
input_layouts=Shard(1), desired_input_layouts=Replicate()
),
"ffn.w1": colwise_parallel_cls(),
"ffn.w2": colwise_parallel_cls(),
"ffn.out_proj": rowwise_parallel_cls(
output_layouts=Shard(1), use_local_output=False
),
},
rowwise_config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE)
# hack around config being frozen
# TODO(future PR): we should make this nicer at the config level
object.__setattr__(rowwise_config, "emulate", True)
_test_lowp_mlp_tensor_parallelism_base(
mesh, rowwise_config, size, compile=False, allgather_in_lowp=False
)

# prepare_input_cls with specific submodule fqn
sp_model2 = copy.deepcopy(toy_model)
sp_model2 = convert_to_float8_training(sp_model2, config=config)

if rowwise:
prepare_input = prepare_input_cls(
input_layouts=Shard(1),
desired_input_layouts=Replicate(),
)
else:
prepare_input = prepare_input_cls(
input_layouts=Shard(1),
desired_input_layouts=Replicate(),
fwd_config_submodule_fqn="w2",
)

sp_model2 = parallelize_module(
sp_model2,
mesh,
{
"ffn": prepare_input,
"ffn.w1": colwise_parallel_cls(),
"ffn.w2": colwise_parallel_cls(),
"ffn.out_proj": rowwise_parallel_cls(
output_layouts=Shard(1), use_local_output=False
),
},
)

if compile:
tp_model = torch.compile(tp_model)
sp_model = torch.compile(sp_model)
sp_model2 = torch.compile(sp_model2)

x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
x_fp32_tp_input = x_fp32.clone()
x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)])

tp_out = tp_model(x_fp32_tp_input)
tp_out.sum().backward()
sp_out = sp_model(x_fp32_sp_input)
sp_out.sum().backward()
global_out = toy_model_fp8(x_fp32)
global_out.sum().backward()
torch.testing.assert_close(tp_out, global_out)
torch.testing.assert_close(sp_out.full_tensor(), global_out)
torch.testing.assert_close(tp_model.ffn.w1.weight.grad, sp_model.ffn.w1.weight.grad)
torch.testing.assert_close(
tp_model.ffn.out_proj.weight.grad, sp_model.ffn.out_proj.weight.grad
def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
tensorwise_config = Float8LinearConfig(emulate=True)
_test_lowp_mlp_tensor_parallelism_base(
mesh, tensorwise_config, size, compile=True, allgather_in_lowp=True
)

sp_out2 = sp_model2(x_fp32_sp_input)
sp_out2.sum().backward()
torch.testing.assert_close(sp_out2.full_tensor(), global_out)
torch.testing.assert_close(
tp_model.ffn.w1.weight.grad, sp_model2.ffn.w1.weight.grad
)
torch.testing.assert_close(
tp_model.ffn.out_proj.weight.grad, sp_model2.ffn.out_proj.weight.grad
rowwise_config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE)
# hack around config being frozen
# TODO(future PR): we should make this nicer at the config level
object.__setattr__(rowwise_config, "emulate", True)
_test_lowp_mlp_tensor_parallelism_base(
mesh, rowwise_config, size, compile=True, allgather_in_lowp=False
)


def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
_test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False, rowwise=False)
_test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False, rowwise=True)


def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
_test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True, rowwise=False)
_test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True, rowwise=True)


def _test_distribute_fsdp_tensor_subclass(tp_mesh: DeviceMesh):
torch.manual_seed(42)
model = Transformer(ModelArgs(dropout_p=0.0, weight_tying=False)).cuda()
138 changes: 138 additions & 0 deletions torchao/testing/training/dtensor_utils.py
@@ -3,9 +3,27 @@
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._tensor import Replicate, Shard, distribute_tensor
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import (
ColwiseParallel,
PrepareModuleInput,
RowwiseParallel,
parallelize_module,
)

from torchao.float8 import Float8LinearConfig
from torchao.float8.float8_linear_utils import convert_to_float8_training
from torchao.float8.float8_tensor_parallel import (
Float8ColwiseParallel,
Float8RowwiseParallel,
PrepareFloat8ModuleInput,
)


class FeedForward(nn.Module):
@@ -28,3 +46,123 @@ def __init__(self):

def forward(self, x):
return self.ffn(x)


def _test_lowp_mlp_tensor_parallelism_base(
mesh: DeviceMesh,
config: Float8LinearConfig,
size=16,
compile: bool = False,
allgather_in_lowp: bool = False,
):
device = mesh.device_type

toy_model = ToyModel().to(device)
toy_model_fp8 = convert_to_float8_training(toy_model, config=config)

tp_model = copy.deepcopy(toy_model)
tp_model = convert_to_float8_training(tp_model, config=config)
sp_model = copy.deepcopy(toy_model)
sp_model = convert_to_float8_training(sp_model, config=config)

# For tensorwise scaling, enable float8 all_gather.
# For rowwise scaling, keep high precision all_gather. Motivation for
# not doing float8 all-gather for rowwise: tensors need to be scaled both ways,
# so for float8 all-gather we'd need to send two float8 copies per tensor,
# which is a similar number of bytes over the wire as just doing bfloat16 all-gather.
if not allgather_in_lowp:
colwise_parallel_cls = ColwiseParallel
rowwise_parallel_cls = RowwiseParallel
prepare_input_cls = PrepareModuleInput
else:
colwise_parallel_cls = Float8ColwiseParallel
rowwise_parallel_cls = Float8RowwiseParallel
prepare_input_cls = PrepareFloat8ModuleInput

# vanilla TP
tp_model = parallelize_module(
tp_model,
mesh,
{
"ffn.w1": colwise_parallel_cls(),
"ffn.w2": colwise_parallel_cls(),
"ffn.out_proj": rowwise_parallel_cls(),
},
)

# "sequence parallel" mlp computation
sp_model = parallelize_module(
sp_model,
mesh,
{
"ffn": prepare_input_cls(
input_layouts=Shard(1), desired_input_layouts=Replicate()
),
"ffn.w1": colwise_parallel_cls(),
"ffn.w2": colwise_parallel_cls(),
"ffn.out_proj": rowwise_parallel_cls(
output_layouts=Shard(1), use_local_output=False
),
},
)

# prepare_input_cls with specific submodule fqn
sp_model2 = copy.deepcopy(toy_model)
sp_model2 = convert_to_float8_training(sp_model2, config=config)

if not allgather_in_lowp:
prepare_input = prepare_input_cls(
input_layouts=Shard(1),
desired_input_layouts=Replicate(),
)
else:
prepare_input = prepare_input_cls(
input_layouts=Shard(1),
desired_input_layouts=Replicate(),
fwd_config_submodule_fqn="w2",
)

sp_model2 = parallelize_module(
sp_model2,
mesh,
{
"ffn": prepare_input,
"ffn.w1": colwise_parallel_cls(),
"ffn.w2": colwise_parallel_cls(),
"ffn.out_proj": rowwise_parallel_cls(
output_layouts=Shard(1), use_local_output=False
),
},
)

if compile:
tp_model = torch.compile(tp_model)
sp_model = torch.compile(sp_model)
sp_model2 = torch.compile(sp_model2)

x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
x_fp32_tp_input = x_fp32.clone()
x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)])

tp_out = tp_model(x_fp32_tp_input)
tp_out.sum().backward()
sp_out = sp_model(x_fp32_sp_input)
sp_out.sum().backward()
global_out = toy_model_fp8(x_fp32)
global_out.sum().backward()
torch.testing.assert_close(tp_out, global_out)
torch.testing.assert_close(sp_out.full_tensor(), global_out)
torch.testing.assert_close(tp_model.ffn.w1.weight.grad, sp_model.ffn.w1.weight.grad)
torch.testing.assert_close(
tp_model.ffn.out_proj.weight.grad, sp_model.ffn.out_proj.weight.grad
)

sp_out2 = sp_model2(x_fp32_sp_input)
sp_out2.sum().backward()
torch.testing.assert_close(sp_out2.full_tensor(), global_out)
torch.testing.assert_close(
tp_model.ffn.w1.weight.grad, sp_model2.ffn.w1.weight.grad
)
torch.testing.assert_close(
tp_model.ffn.out_proj.weight.grad, sp_model2.ffn.out_proj.weight.grad
)
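
For reference, a minimal sketch of a standalone driver for the new shared helper, assuming a torchrun launch; the file name, environment handling, and the choice of the rowwise recipe here are illustrative and not part of this diff:

```python
# Hypothetical standalone driver (e.g. `torchrun --nproc_per_node 2 run_tp_test.py`);
# not part of this PR, it only shows how the shared helper might be wired up.
import os

import torch
from torch.distributed.device_mesh import init_device_mesh

from torchao.float8 import Float8LinearConfig
from torchao.float8.config import Float8LinearRecipeName
from torchao.testing.training.dtensor_utils import (
    _test_lowp_mlp_tensor_parallelism_base,
)

if __name__ == "__main__":
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    mesh = init_device_mesh(device_type, (world_size,))

    # Rowwise scaling keeps the all-gather in high precision, matching
    # allgather_in_lowp=False in the helper above.
    config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE)
    # same frozen-config workaround used in the tests
    object.__setattr__(config, "emulate", True)
    _test_lowp_mlp_tensor_parallelism_base(
        mesh, config, size=16, compile=False, allgather_in_lowp=False
    )

    torch.distributed.destroy_process_group()
    print("rowwise float8 TP/SP parallelism test passed")
```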