
Commit 797eb91

Enable FP8 full finetune distributed
Enable float8 (FP8) training in the full_finetune_distributed recipe: add enable_fp8_training and fp8_recipe_name config options, a convert_to_float8_training helper in torchtune.training.quantization that swaps nn.Linear modules for torchao Float8Linear (keeping the final output projection in higher precision), a float8-aware Llama3 tensor parallel plan (fp8_llama_tp_plan), and a per-step precompute_float8_dynamic_scale_for_fsdp call so FSDP computes float8 scales with a single all-reduce.

Based on #2404 by @nathan-az
1 parent f1ecdd6 commit 797eb91

File tree

3 files changed: +108 -29 lines changed

recipes/full_finetune_distributed.py
torchtune/models/llama3/_parallelism.py
torchtune/training/quantization.py

recipes/full_finetune_distributed.py

Lines changed: 15 additions & 0 deletions
@@ -19,6 +19,7 @@
 from torch.distributed._tensor import DTensor
 from torch.distributed.tensor.parallel import parallelize_module
 from torch.optim import Optimizer
+from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp
 from torchdata.stateful_dataloader import StatefulDataLoader
 from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
 from torchtune import config, modules, training, utils
@@ -33,6 +34,7 @@
     TrainingProgress,
 )
 from torchtune.training.lr_schedulers import get_lr
+from torchtune.training.quantization import convert_to_float8_training
 
 from tqdm import tqdm
 
@@ -184,6 +186,8 @@ def __init__(self, cfg: DictConfig) -> None:
         self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)
         self._clip_grad_norm = cfg.get("clip_grad_norm", None)
         self._checkpoint_client = CheckpointClient(cfg)
+        self._enable_fp8_training = cfg.get("enable_fp8_training", False)
+        self._fp8_recipe_name = cfg.get("fp8_recipe_name", None)
 
         # Optimizer in backward is not compatible with gradient accumulation or gradient clipping
         if self._optimizer_in_bwd:
@@ -545,6 +549,11 @@ def _setup_model(
         if self._compile:
             training.compile_model(model, verbose=self._is_rank_zero)
 
+        if self._enable_fp8_training:
+            # TODO: gate on nightlies?
+            # TODO: validate self.tp_plan, if any, based on config
+            model = convert_to_float8_training(model, self._fp8_recipe_name)
+
         # Apply tensor parallelism to the model
         if self.parallel_dims.tp_enabled:
             if not self.parallel_dims.dp_enabled and self.fsdp_cpu_offload:
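
A quick note on ordering: the conversion in the hunk above runs before the tensor parallel and FSDP wrapping that follows it, so those wrappers operate on the already-swapped Float8Linear modules. Below is a minimal sketch of that flow, assuming the recipe's meta-device initialization and a recent torchao build; llama3_8b merely stands in for whatever model builder the config instantiates.

import torch

from torchtune.models.llama3 import llama3_8b
from torchtune.training.quantization import convert_to_float8_training

# Build on the meta device, as the recipe does, so no real memory is allocated yet.
with torch.device("meta"):
    model = llama3_8b()

# Swap nn.Linear -> Float8Linear first (default recipe: tensorwise scaling with
# fp8 all-gather enabled); tensor parallelism and FSDP sharding come afterwards,
# so the sharding wrappers see the float8 modules.
model = convert_to_float8_training(model, fp8_recipe_name=None)
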
@@ -846,6 +855,12 @@ def train(self) -> None:
                     if self._lr_scheduler is not None:
                         self._lr_scheduler.step()
 
+                    # If float8 training is enabled, perform a single all-reduce to compute the
+                    # scale for all float8 parameters efficiently instead of doing many small
+                    # all-reduces for each parameter
+                    if self._enable_fp8_training and self.dp_degree > 1:
+                        precompute_float8_dynamic_scale_for_fsdp(self._model)
+
                     loss_to_log = running_loss.item() / num_tokens
                     pbar.update(1)
                     pbar.set_description(
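
For context on the comment above: with FSDP, dynamic float8 scaling would otherwise issue one small amax all-reduce per parameter after every optimizer step. A hedged sketch of where the call sits in a generic training step follows; the surrounding function, loss_fn, and batch keys are illustrative, not the recipe's own names.

from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp


def train_step(model, optimizer, loss_fn, batch, dp_degree):
    loss = loss_fn(model(batch["tokens"]), batch["labels"])
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    # After the optimizer updates the sharded float8 weights, recompute the
    # dynamic scales for all float8 parameters with a single all-reduce instead
    # of one small all-reduce per parameter; only useful when data parallel > 1.
    if dp_degree > 1:
        precompute_float8_dynamic_scale_for_fsdp(model)
    return loss.detach()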

torchtune/models/llama3/_parallelism.py

Lines changed: 56 additions & 27 deletions
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Dict
+from typing import Dict, Type
 
 from torch.distributed.tensor import Replicate, Shard
 from torch.distributed.tensor.parallel import (
@@ -15,32 +15,46 @@
 )
 from torch.distributed.tensor.parallel.style import ParallelStyle
 
+from torchao.float8.float8_tensor_parallel import (
+    Float8ColwiseParallel,
+    Float8RowwiseParallel,
+)
 
-# Define the Tensor Parallel plan for Llama3 model, which will also be shared with 3.1, 3.2, and 3.3 models
-BASE_LLAMA_TP_PLAN = {
-    "tok_embeddings": RowwiseParallel(
-        input_layouts=Replicate(), output_layouts=Shard(1)
-    ),
-    "norm": SequenceParallel(),
-    "output": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate()),
-    "layers.*.attn": PrepareModuleInput(
-        input_layouts=(Shard(1), None),
-        desired_input_layouts=(Replicate(), None),
-    ),
-    "layers.*.mlp": PrepareModuleInput(
-        input_layouts=(Shard(1),),
-        desired_input_layouts=(Replicate(),),
-    ),
-    "layers.*.sa_norm": SequenceParallel(),
-    "layers.*.mlp_norm": SequenceParallel(),
-    "layers.*.attn.q_proj": ColwiseParallel(),
-    "layers.*.attn.k_proj": ColwiseParallel(),
-    "layers.*.attn.v_proj": ColwiseParallel(),
-    "layers.*.attn.output_proj": RowwiseParallel(output_layouts=Shard(1)),
-    "layers.*.mlp.w1": ColwiseParallel(),
-    "layers.*.mlp.w2": RowwiseParallel(output_layouts=Shard(1)),
-    "layers.*.mlp.w3": ColwiseParallel(),
-}
+
+def _get_base_llama_tp_plan(
+    _sequence_parallel_cls: Type[ParallelStyle] = SequenceParallel,
+    _colwise_parallel_cls: Type[ParallelStyle] = ColwiseParallel,
+    _rowwise_parallel_cls: Type[ParallelStyle] = RowwiseParallel,
+) -> Dict[str, ParallelStyle]:
+    """
+    Define the Tensor Parallel plan for Llama3 model, which will also be shared with 3.1, 3.2, and 3.3 models.
+    """
+    return {
+        "tok_embeddings": _rowwise_parallel_cls(
+            input_layouts=Replicate(), output_layouts=Shard(1)
+        ),
+        "norm": _sequence_parallel_cls(),
+        "output": _colwise_parallel_cls(
+            input_layouts=Shard(1), output_layouts=Replicate()
+        ),
+        "layers.*.attn": PrepareModuleInput(
+            input_layouts=(Shard(1), None),
+            desired_input_layouts=(Replicate(), None),
+        ),
+        "layers.*.mlp": PrepareModuleInput(
+            input_layouts=(Shard(1),),
+            desired_input_layouts=(Replicate(),),
+        ),
+        "layers.*.sa_norm": _sequence_parallel_cls(),
+        "layers.*.mlp_norm": _sequence_parallel_cls(),
+        "layers.*.attn.q_proj": _colwise_parallel_cls(),
+        "layers.*.attn.k_proj": _colwise_parallel_cls(),
+        "layers.*.attn.v_proj": _colwise_parallel_cls(),
+        "layers.*.attn.output_proj": _rowwise_parallel_cls(output_layouts=Shard(1)),
+        "layers.*.mlp.w1": _colwise_parallel_cls(),
+        "layers.*.mlp.w2": _rowwise_parallel_cls(output_layouts=Shard(1)),
+        "layers.*.mlp.w3": _colwise_parallel_cls(),
+    }
 
 
 def base_llama_tp_plan() -> Dict[str, ParallelStyle]:
@@ -50,4 +64,19 @@ def base_llama_tp_plan() -> Dict[str, ParallelStyle]:
     Returns:
         Dict[str, Any]: The tensor parallel plan for Llama3 model.
     """
-    return BASE_LLAMA_TP_PLAN
+    return _get_base_llama_tp_plan()
+
+
+def fp8_llama_tp_plan() -> Dict[str, ParallelStyle]:
+    """
+    Return the tensor parallel plan for Llama3 model that uses float8 for all-gather for both
+    rowwise and colwise computation, currently only compatible with float8 fine-tuning with
+    "tensorwise" scaling. This tensor parallel plan is shared between 3.1, 3.2, and 3.3 models.
+
+    Returns:
+        Dict[str, Any]: The float8-enabled tensor parallel plan for Llama3 model.
+    """
+    return _get_base_llama_tp_plan(
+        _colwise_parallel_cls=Float8ColwiseParallel,
+        _rowwise_parallel_cls=Float8RowwiseParallel,
+    )
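
To show how a plan like this is consumed, here is a hedged sketch that applies fp8_llama_tp_plan with parallelize_module on an 8-way tensor parallel mesh. It assumes a torchrun launch with 8 ranks and imports the new function from the private _parallelism module shown above; the recipe itself wires the plan in through its config and parallel_dims logic.

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

from torchtune.models.llama3 import llama3_8b
from torchtune.models.llama3._parallelism import fp8_llama_tp_plan

# Assumes `torchrun --nproc_per_node 8 ...` so that 8 ranks form the TP group.
tp_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("tp",))

with torch.device("meta"):
    model = llama3_8b()

# Float8 conversion would happen before this point, so the Float8ColwiseParallel /
# Float8RowwiseParallel styles act on Float8Linear modules (tensorwise scaling only).
parallelize_module(model, tp_mesh, fp8_llama_tp_plan())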

torchtune/training/quantization.py

Lines changed: 37 additions & 2 deletions
@@ -9,13 +9,15 @@
 from torch import nn
 
 from torchao.dtypes import TensorCoreTiledLayout
-
+from torchao.float8 import (
+    convert_to_float8_training as _convert_to_float8_training_torchao,
+    Float8LinearConfig,
+)
 from torchao.quantization import (
     int4_weight_only,
     int8_dynamic_activation_int4_weight,
     quantize_,
 )
-
 from torchao.quantization.qat import (
     Int4WeightOnlyQATQuantizer,
     Int8DynActInt4WeightQATQuantizer,
@@ -26,6 +28,7 @@
     enable_4w_fake_quant,
     enable_8da4w_fake_quant,
 )
+
 from torchtune.modules.peft.lora import LoRALinear, QATLoRALinear
 
 
@@ -219,3 +222,35 @@ def swap_lora_linear_with_qat(
         activation_qat_config,
         weight_qat_config,
     )
+
+
+def convert_to_float8_training(
+    model: nn.Module,
+    fp8_recipe_name: Optional[str] = None,
+) -> nn.Module:
+    """
+    Prepare the model for float8 training by swapping all `nn.Linear` with `Float8Linear`.
+
+    Args:
+        model (nn.Module): The model to swap linear layers on.
+        fp8_recipe_name (Optional[str]): name identifying one of the pre-made recipes,
+            one of "tensorwise", "rowwise", and "rowwise_with_gw_hp". If not specified,
+            defaults to "tensorwise" with `enable_fsdp_float8_all_gather=True`. See
+            https://github.com/pytorch/ao/blob/v0.9.0/torchao/float8/config.py#L150
+            for more details.
+
+    Returns:
+        nn.Module: The new model with `Float8Linear`.
+    """
+    print(
+        f"doing fp8 quantized training, fp8_recipe_name = {fp8_recipe_name or 'N/A'}"
+    )
+    if fp8_recipe_name is not None:
+        fp8_config = Float8LinearConfig.from_recipe_name(fp8_recipe_name)
+    else:
+        fp8_config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)
+    return _convert_to_float8_training_torchao(
+        model,
+        config=fp8_config,
+        module_filter_fn=lambda mod, fqn: fqn != "output",
+    )
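
A hedged usage sketch for the new helper on a throwaway toy module (not part of torchtune), assuming a recent torchao build where the default tensorwise conversion succeeds on the current device: every nn.Linear is swapped except the one whose fully qualified name is exactly "output", matching the module_filter_fn above.

import torch.nn as nn
from torchao.float8.float8_linear import Float8Linear

from torchtune.training.quantization import convert_to_float8_training


class ToyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(128, 256)
        self.output = nn.Linear(256, 32)


# Default path: "tensorwise" scaling with enable_fsdp_float8_all_gather=True.
model = convert_to_float8_training(ToyModel(), fp8_recipe_name=None)
assert isinstance(model.proj, Float8Linear)        # swapped to float8
assert not isinstance(model.output, Float8Linear)  # final projection left as nn.Linear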

0 commit comments
