
Commit 4f8682d

Enable FP8 full finetune distributed
TODO: write this. Based on #2404 by @nathan-az.
1 parent f1ecdd6 commit 4f8682d

File tree: 3 files changed, +102 -29 lines

  recipes/full_finetune_distributed.py
  torchtune/models/llama3/_parallelism.py
  torchtune/training/quantization.py

recipes/full_finetune_distributed.py

Lines changed: 14 additions & 0 deletions
@@ -19,6 +19,7 @@
 from torch.distributed._tensor import DTensor
 from torch.distributed.tensor.parallel import parallelize_module
 from torch.optim import Optimizer
+from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp
 from torchdata.stateful_dataloader import StatefulDataLoader
 from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
 from torchtune import config, modules, training, utils
@@ -184,6 +185,8 @@ def __init__(self, cfg: DictConfig) -> None:
         self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)
         self._clip_grad_norm = cfg.get("clip_grad_norm", None)
         self._checkpoint_client = CheckpointClient(cfg)
+        self._enable_fp8_training = cfg.get("enable_fp8_training", False)
+        self._fp8_recipe_name = cfg.get("fp8_recipe_name", None)

         # Optimizer in backward is not compatible with gradient accumulation or gradient clipping
         if self._optimizer_in_bwd:
@@ -545,6 +548,11 @@ def _setup_model(
         if self._compile:
             training.compile_model(model, verbose=self._is_rank_zero)

+        if self._enable_fp8_training:
+            # TODO: gate on nightlies?
+            # TODO: validate self.tp_plan, if any, based on config
+            model = convert_to_float8_training(model, self._fp8_recipe_name)
+
         # Apply tensor parallelism to the model
         if self.parallel_dims.tp_enabled:
             if not self.parallel_dims.dp_enabled and self.fsdp_cpu_offload:
@@ -846,6 +854,12 @@ def train(self) -> None:
                 if self._lr_scheduler is not None:
                     self._lr_scheduler.step()

+                # If float8 training is enabled, perform a single all-reduce to compute the
+                # scale for all float8 parameters efficiently instead of doing many small
+                # all-reduces for each parameter
+                if self._enable_fp8_training and self.dp_degree > 1:
+                    precompute_float8_dynamic_scale_for_fsdp(self._model)
+
                 loss_to_log = running_loss.item() / num_tokens
                 pbar.update(1)
                 pbar.set_description(
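The recipe changes hinge on two new config keys: enable_fp8_training toggles the swap to torchao's Float8Linear inside _setup_model (before TP/FSDP are applied), and fp8_recipe_name selects a torchao scaling recipe. Below is a minimal, hedged sketch of that path outside the recipe; the toy nn.Sequential model, the assumption of a float8-capable GPU (e.g. H100) for the forward pass, and importing convert_to_float8_training from torchtune.training.quantization (where this commit defines it; the recipe-side import is not visible in these hunks) are all illustrative assumptions, not part of the commit.

# Hedged sketch (not recipe code): how the two new flags gate the float8 conversion.
import torch
import torch.nn as nn

from torchtune.training.quantization import convert_to_float8_training

enable_fp8_training = True  # mirrors cfg.get("enable_fp8_training", False)
fp8_recipe_name = None      # mirrors cfg.get("fp8_recipe_name", None) -> "tensorwise" default

model = nn.Sequential(nn.Linear(256, 256), nn.Linear(256, 64)).to("cuda")
if enable_fp8_training:
    # Same call the recipe makes in _setup_model(), before TP/FSDP wrapping.
    model = convert_to_float8_training(model, fp8_recipe_name)

out = model(torch.randn(16, 256, device="cuda"))
print(type(model[0]).__name__)  # expected: Float8Linear

During training, the hunk in train() additionally calls precompute_float8_dynamic_scale_for_fsdp(self._model) once per step when dp_degree > 1, so the dynamic scales for all float8 parameters are computed with a single all-reduce instead of one small all-reduce per parameter.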

torchtune/models/llama3/_parallelism.py

Lines changed: 51 additions & 27 deletions
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Dict
+from typing import Dict, Type

 from torch.distributed.tensor import Replicate, Shard
 from torch.distributed.tensor.parallel import (
@@ -16,31 +16,40 @@
 from torch.distributed.tensor.parallel.style import ParallelStyle


-# Define the Tensor Parallel plan for Llama3 model, which will also be shared with 3.1, 3.2, and 3.3 models
-BASE_LLAMA_TP_PLAN = {
-    "tok_embeddings": RowwiseParallel(
-        input_layouts=Replicate(), output_layouts=Shard(1)
-    ),
-    "norm": SequenceParallel(),
-    "output": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate()),
-    "layers.*.attn": PrepareModuleInput(
-        input_layouts=(Shard(1), None),
-        desired_input_layouts=(Replicate(), None),
-    ),
-    "layers.*.mlp": PrepareModuleInput(
-        input_layouts=(Shard(1),),
-        desired_input_layouts=(Replicate(),),
-    ),
-    "layers.*.sa_norm": SequenceParallel(),
-    "layers.*.mlp_norm": SequenceParallel(),
-    "layers.*.attn.q_proj": ColwiseParallel(),
-    "layers.*.attn.k_proj": ColwiseParallel(),
-    "layers.*.attn.v_proj": ColwiseParallel(),
-    "layers.*.attn.output_proj": RowwiseParallel(output_layouts=Shard(1)),
-    "layers.*.mlp.w1": ColwiseParallel(),
-    "layers.*.mlp.w2": RowwiseParallel(output_layouts=Shard(1)),
-    "layers.*.mlp.w3": ColwiseParallel(),
-}
+def _get_base_llama_tp_plan(
+    _sequence_parallel_cls: Type[ParallelStyle] = SequenceParallel,
+    _colwise_parallel_cls: Type[ParallelStyle] = ColwiseParallel,
+    _rowwise_parallel_cls: Type[ParallelStyle] = RowwiseParallel,
+) -> Dict[str, ParallelStyle]:
+    """
+    Define the Tensor Parallel plan for Llama3 model, which will also be shared with 3.1, 3.2, and 3.3 models.
+    """
+    return {
+        "tok_embeddings": _rowwise_parallel_cls(
+            input_layouts=Replicate(), output_layouts=Shard(1)
+        ),
+        "norm": _sequence_parallel_cls(),
+        "output": _colwise_parallel_cls(
+            input_layouts=Shard(1), output_layouts=Replicate()
+        ),
+        "layers.*.attn": PrepareModuleInput(
+            input_layouts=(Shard(1), None),
+            desired_input_layouts=(Replicate(), None),
+        ),
+        "layers.*.mlp": PrepareModuleInput(
+            input_layouts=(Shard(1),),
+            desired_input_layouts=(Replicate(),),
+        ),
+        "layers.*.sa_norm": _sequence_parallel_cls(),
+        "layers.*.mlp_norm": _sequence_parallel_cls(),
+        "layers.*.attn.q_proj": _colwise_parallel_cls(),
+        "layers.*.attn.k_proj": _colwise_parallel_cls(),
+        "layers.*.attn.v_proj": _colwise_parallel_cls(),
+        "layers.*.attn.output_proj": _rowwise_parallel_cls(output_layouts=Shard(1)),
+        "layers.*.mlp.w1": _colwise_parallel_cls(),
+        "layers.*.mlp.w2": _rowwise_parallel_cls(output_layouts=Shard(1)),
+        "layers.*.mlp.w3": _colwise_parallel_cls(),
+    }


 def base_llama_tp_plan() -> Dict[str, ParallelStyle]:
@@ -50,4 +59,19 @@ def base_llama_tp_plan() -> Dict[str, ParallelStyle]:
     Returns:
         Dict[str, Any]: The tensor parallel plan for Llama3 model.
     """
-    return BASE_LLAMA_TP_PLAN
+    return _get_base_llama_tp_plan()
+
+
+def fp8_llama_tp_plan() -> Dict[str, ParallelStyle]:
+    """
+    Return the tensor parallel plan for the Llama3 model that uses float8 all-gather for both
+    rowwise and colwise computation. It is currently only compatible with float8 fine-tuning
+    with "tensorwise" scaling. This tensor parallel plan is shared between 3.1, 3.2, and 3.3 models.
+
+    Returns:
+        Dict[str, Any]: The float8-enabled tensor parallel plan for Llama3 model.
+    """
+    return _get_base_llama_tp_plan(
+        _colwise_parallel_cls=Float8ColwiseParallel,
+        _rowwise_parallel_cls=Float8RowwiseParallel,
+    )
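For illustration, here is a hedged sketch of how one of these plans might be consumed. It assumes a 2-way tensor-parallel mesh launched via torchrun, that Float8ColwiseParallel and Float8RowwiseParallel resolve from torchao.float8.float8_tensor_parallel (their import is not shown in these hunks), that wildcard plan keys like "layers.*.attn" are supported by the installed parallelize_module, and that nn.Linear layers are swapped to Float8Linear before sharding (as the recipe does); the llama3_8b builder is only a stand-in.

# Hedged sketch; run with: torchrun --nproc_per_node=2 this_script.py
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

from torchtune.models.llama3 import llama3_8b
from torchtune.models.llama3._parallelism import base_llama_tp_plan, fp8_llama_tp_plan
from torchtune.training.quantization import convert_to_float8_training

tp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))

enable_fp8_training = True
model = llama3_8b()
if enable_fp8_training:
    # Swap nn.Linear -> Float8Linear first, then shard with the float8-aware plan
    # ("tensorwise" scaling only, per the docstring above).
    model = convert_to_float8_training(model)
    tp_plan = fp8_llama_tp_plan()
else:
    tp_plan = base_llama_tp_plan()

parallelize_module(model, tp_mesh["tp"], parallelize_plan=tp_plan)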

torchtune/training/quantization.py

Lines changed: 37 additions & 2 deletions
@@ -9,13 +9,15 @@
 from torch import nn

 from torchao.dtypes import TensorCoreTiledLayout
-
+from torchao.float8 import (
+    convert_to_float8_training as _convert_to_float8_training_torchao,
+    Float8LinearConfig,
+)
 from torchao.quantization import (
     int4_weight_only,
     int8_dynamic_activation_int4_weight,
     quantize_,
 )
-
 from torchao.quantization.qat import (
     Int4WeightOnlyQATQuantizer,
     Int8DynActInt4WeightQATQuantizer,
@@ -26,6 +28,7 @@
     enable_4w_fake_quant,
     enable_8da4w_fake_quant,
 )
+
 from torchtune.modules.peft.lora import LoRALinear, QATLoRALinear

@@ -219,3 +222,35 @@ def swap_lora_linear_with_qat(
         activation_qat_config,
         weight_qat_config,
     )
+
+
+def convert_to_float8_training(
+    model: nn.Module,
+    fp8_recipe_name: Optional[str] = None,
+) -> nn.Module:
+    """
+    Prepare the model for float8 training by swapping all `nn.Linear` with `Float8Linear`.
+
+    Args:
+        model (nn.Module): The model to swap linear layers on.
+        fp8_recipe_name (Optional[str]): Name of one of the pre-made recipes: "tensorwise",
+            "rowwise", or "rowwise_with_gw_hp". If not specified, defaults to "tensorwise"
+            with "enable_fsdp_float8_all_gather=True". See
+            https://github.com/pytorch/ao/blob/v0.9.0/torchao/float8/config.py#L150
+            for more details.
+
+    Returns:
+        (nn.Module) The new model with `Float8Linear`.
+    """
+    print(
+        f"Enabling float8 training, fp8_recipe_name = {fp8_recipe_name or 'tensorwise (default)'}"
+    )
+    if fp8_recipe_name is not None:
+        fp8_config = Float8LinearConfig.from_recipe_name(fp8_recipe_name)
+    else:
+        fp8_config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)
+    return _convert_to_float8_training_torchao(
+        model,
+        config=fp8_config,
+        module_filter_fn=lambda mod, fqn: fqn != "output",
+    )
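A brief usage sketch of the new wrapper follows. The tiny module (with a submodule deliberately named "output" to hit the module_filter_fn exclusion) is hypothetical, and running a forward pass through the converted model would additionally require float8-capable hardware.

# Hedged sketch: exercising torchtune.training.quantization.convert_to_float8_training.
import torch.nn as nn
from torchao.float8 import Float8Linear

from torchtune.training.quantization import convert_to_float8_training


class TinyHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(128, 128)
        self.output = nn.Linear(128, 32)  # FQN "output" is skipped by the filter

    def forward(self, x):
        return self.output(self.proj(x))


model = TinyHead()
# None -> default Float8LinearConfig(enable_fsdp_float8_all_gather=True), i.e. "tensorwise";
# "rowwise" or "rowwise_with_gw_hp" would select the corresponding pre-made recipe.
model = convert_to_float8_training(model, fp8_recipe_name=None)

assert isinstance(model.proj, Float8Linear)        # swapped
assert not isinstance(model.output, Float8Linear)  # left as nn.Linear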
