Commit df526dd

Enable FP8 full finetune distributed
Adds float8 (FP8) training support to the full_finetune_distributed recipe behind the new enable_fp8_training and fp8_recipe_name config flags, along with a float8-aware Llama3 tensor parallel plan and torchao-based conversion helpers. Based on #2404 by @nathan-az
1 parent f1ecdd6 commit df526dd

File tree

4 files changed: +157 -31 lines changed

  recipes/full_finetune_distributed.py
  torchtune/models/llama3/__init__.py
  torchtune/models/llama3/_parallelism.py
  torchtune/training/quantization.py
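The recipe wires float8 training to two new config keys, enable_fp8_training and fp8_recipe_name, read via cfg.get(...) in __init__ (first diff below). A minimal sketch of those keys as a config object follows; the OmegaConf usage and surrounding values are illustrative only, not taken from this commit:

# Illustrative sketch only: shows the two config keys the recipe reads; everything
# else about a real full_finetune_distributed config is omitted.
from omegaconf import DictConfig, OmegaConf

cfg: DictConfig = OmegaConf.create(
    {
        # Swap nn.Linear modules for Float8Linear during _setup_model.
        "enable_fp8_training": True,
        # One of "tensorwise", "rowwise", "rowwise_with_gw_hp"; omit (None) to get the
        # default tensorwise config with FSDP float8 all-gather enabled.
        "fp8_recipe_name": "tensorwise",
    }
)

# Mirrors how the recipe reads the flags in __init__.
enable_fp8_training = cfg.get("enable_fp8_training", False)
fp8_recipe_name = cfg.get("fp8_recipe_name", None)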

recipes/full_finetune_distributed.py

Lines changed: 27 additions & 0 deletions
@@ -19,6 +19,7 @@
 from torch.distributed._tensor import DTensor
 from torch.distributed.tensor.parallel import parallelize_module
 from torch.optim import Optimizer
+from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp
 from torchdata.stateful_dataloader import StatefulDataLoader
 from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
 from torchtune import config, modules, training, utils
@@ -33,6 +34,11 @@
     TrainingProgress,
 )
 from torchtune.training.lr_schedulers import get_lr
+from torchtune.training.quantization import (
+    convert_to_float8_training,
+    is_fp8_tensorwise_scaling,
+    validate_float8_tp_plan,
+)
 
 from tqdm import tqdm
 
@@ -184,6 +190,8 @@ def __init__(self, cfg: DictConfig) -> None:
         self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)
         self._clip_grad_norm = cfg.get("clip_grad_norm", None)
         self._checkpoint_client = CheckpointClient(cfg)
+        self._enable_fp8_training = cfg.get("enable_fp8_training", False)
+        self._fp8_recipe_name = cfg.get("fp8_recipe_name", None)
 
         # Optimizer in backward is not compatible with gradient accumulation or gradient clipping
         if self._optimizer_in_bwd:
@@ -545,6 +553,15 @@ def _setup_model(
         if self._compile:
             training.compile_model(model, verbose=self._is_rank_zero)
 
+        if self._enable_fp8_training:
+            # Requires https://github.com/pytorch/pytorch/pull/148922
+            if torch.__version__ < "2.8.0.dev20250318":
+                raise RuntimeError(
+                    "Float8 fine-tuning requires PyTorch 2.8.0.dev20250318 or later."
+                )
+            validate_float8_tp_plan(self.tp_plan, self._fp8_recipe_name)
+            model = convert_to_float8_training(model, self._fp8_recipe_name)
+
         # Apply tensor parallelism to the model
         if self.parallel_dims.tp_enabled:
             if not self.parallel_dims.dp_enabled and self.fsdp_cpu_offload:
@@ -846,6 +863,16 @@ def train(self) -> None:
                     if self._lr_scheduler is not None:
                         self._lr_scheduler.step()
 
+                    # If float8 training is enabled, perform a single all-reduce to compute the
+                    # scale for all float8 parameters efficiently instead of doing many small
+                    # all-reduces for each parameter
+                    if (
+                        self._enable_fp8_training
+                        and is_fp8_tensorwise_scaling(self._fp8_recipe_name)
+                        and self.dp_degree > 1
+                    ):
+                        precompute_float8_dynamic_scale_for_fsdp(self._model)
+
                     loss_to_log = running_loss.item() / num_tokens
                     pbar.update(1)
                     pbar.set_description(
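Read together, the recipe changes enforce a fixed ordering: the float8 swap happens in _setup_model before tensor parallelism and FSDP wrapping, and the scale precompute runs once per optimizer step, only for tensorwise scaling with data parallelism. Below is a condensed, hedged sketch of that flow on a toy single-device model; it skips FSDP/TP wrapping, uses dp_degree as a stand-in for the recipe's data-parallel degree, and would need a float8-capable GPU (e.g. H100) plus a recent torchao to actually run:

# Condensed sketch of the float8 flow, not the recipe itself. Assumes CUDA with
# float8 support and torchao >= 0.9; FSDP/TP wrapping is elided for brevity.
import torch
from torch import nn
from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp
from torchtune.training.quantization import (
    convert_to_float8_training,
    is_fp8_tensorwise_scaling,
)

fp8_recipe_name = None  # None -> tensorwise with FSDP float8 all-gather
model = nn.Sequential(nn.Linear(256, 256), nn.Linear(256, 256)).cuda()

# Step 1 (_setup_model): swap nn.Linear -> Float8Linear before any parallelism.
model = convert_to_float8_training(model, fp8_recipe_name)

optimizer = torch.optim.AdamW(model.parameters())
dp_degree = 1  # stand-in for the recipe's data-parallel degree

for _ in range(3):
    loss = model(torch.randn(8, 256, device="cuda")).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # Step 2 (train): with tensorwise scaling and FSDP sharding (dp_degree > 1),
    # fuse the per-parameter amax/scale all-reduces into a single call.
    if is_fp8_tensorwise_scaling(fp8_recipe_name) and dp_degree > 1:
        precompute_float8_dynamic_scale_for_fsdp(model)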

torchtune/models/llama3/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -15,7 +15,7 @@
     qlora_llama3_70b,
     qlora_llama3_8b,
 )
-from ._parallelism import base_llama_tp_plan
+from ._parallelism import base_llama_tp_plan, fp8_llama_tp_plan
 from ._tokenizer import Llama3Tokenizer
 
 __all__ = [
@@ -30,4 +30,5 @@
     "qlora_llama3_8b",
     "qlora_llama3_70b",
     "base_llama_tp_plan",
+    "fp8_llama_tp_plan",
 ]

torchtune/models/llama3/_parallelism.py

Lines changed: 60 additions & 27 deletions
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Dict
+from typing import Dict, Type
 
 from torch.distributed.tensor import Replicate, Shard
 from torch.distributed.tensor.parallel import (
@@ -15,32 +15,53 @@
 )
 from torch.distributed.tensor.parallel.style import ParallelStyle
 
+from torchao.float8.float8_tensor_parallel import (
+    Float8ColwiseParallel,
+    Float8RowwiseParallel,
+)
+
 
-# Define the Tensor Parallel plan for Llama3 model, which will also be shared with 3.1, 3.2, and 3.3 models
-BASE_LLAMA_TP_PLAN = {
-    "tok_embeddings": RowwiseParallel(
-        input_layouts=Replicate(), output_layouts=Shard(1)
-    ),
-    "norm": SequenceParallel(),
-    "output": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate()),
-    "layers.*.attn": PrepareModuleInput(
-        input_layouts=(Shard(1), None),
-        desired_input_layouts=(Replicate(), None),
-    ),
-    "layers.*.mlp": PrepareModuleInput(
-        input_layouts=(Shard(1),),
-        desired_input_layouts=(Replicate(),),
-    ),
-    "layers.*.sa_norm": SequenceParallel(),
-    "layers.*.mlp_norm": SequenceParallel(),
-    "layers.*.attn.q_proj": ColwiseParallel(),
-    "layers.*.attn.k_proj": ColwiseParallel(),
-    "layers.*.attn.v_proj": ColwiseParallel(),
-    "layers.*.attn.output_proj": RowwiseParallel(output_layouts=Shard(1)),
-    "layers.*.mlp.w1": ColwiseParallel(),
-    "layers.*.mlp.w2": RowwiseParallel(output_layouts=Shard(1)),
-    "layers.*.mlp.w3": ColwiseParallel(),
-}
+def _get_base_llama_tp_plan(
+    _sequence_parallel_cls: Type[ParallelStyle] = SequenceParallel,
+    _colwise_parallel_cls: Type[ParallelStyle] = ColwiseParallel,
+    _rowwise_parallel_cls: Type[ParallelStyle] = RowwiseParallel,
+) -> Dict[str, ParallelStyle]:
+    """
+    Define the Tensor Parallel plan for the Llama3 model, which is also shared with the 3.1, 3.2, and 3.3 models.
+    """
+    return {
+        "tok_embeddings": _rowwise_parallel_cls(
+            input_layouts=Replicate(), output_layouts=Shard(1)
+        ),
+        "norm": _sequence_parallel_cls(),
+        "output": _colwise_parallel_cls(
+            input_layouts=Shard(1), output_layouts=Replicate()
+        ),
+        "layers.*.attn": PrepareModuleInput(
+            input_layouts=(Shard(1), None),
+            desired_input_layouts=(Replicate(), None),
+        ),
+        "layers.*.mlp": PrepareModuleInput(
+            input_layouts=(Shard(1),),
+            desired_input_layouts=(Replicate(),),
+        ),
+        "layers.*.sa_norm": _sequence_parallel_cls(),
+        "layers.*.mlp_norm": _sequence_parallel_cls(),
+        "layers.*.attn.q_proj": _colwise_parallel_cls(),
+        "layers.*.attn.k_proj": _colwise_parallel_cls(),
+        "layers.*.attn.v_proj": _colwise_parallel_cls(),
+        "layers.*.attn.output_proj": _rowwise_parallel_cls(output_layouts=Shard(1)),
+        "layers.*.mlp.w1": _colwise_parallel_cls(),
+        "layers.*.mlp.w2": _rowwise_parallel_cls(output_layouts=Shard(1)),
+        "layers.*.mlp.w3": _colwise_parallel_cls(),
+    }
+
+
+_BASE_LLAMA_TP_PLAN = _get_base_llama_tp_plan()
+_FP8_LLAMA_TP_PLAN = _get_base_llama_tp_plan(
+    _colwise_parallel_cls=Float8ColwiseParallel,
+    _rowwise_parallel_cls=Float8RowwiseParallel,
+)
 
 
 def base_llama_tp_plan() -> Dict[str, ParallelStyle]:
@@ -50,4 +71,16 @@ def base_llama_tp_plan() -> Dict[str, ParallelStyle]:
     Returns:
         Dict[str, Any]: The tensor parallel plan for Llama3 model.
     """
-    return BASE_LLAMA_TP_PLAN
+    return _BASE_LLAMA_TP_PLAN
+
+
+def fp8_llama_tp_plan() -> Dict[str, ParallelStyle]:
+    """
+    Return the tensor parallel plan for the Llama3 model that uses float8 all-gather for both
+    rowwise and colwise computation. It is currently only compatible with float8 fine-tuning
+    using "tensorwise" scaling. This tensor parallel plan is shared between the 3.1, 3.2, and 3.3 models.
+
+    Returns:
+        Dict[str, Any]: The float8-enabled tensor parallel plan for the Llama3 model.
+    """
+    return _FP8_LLAMA_TP_PLAN
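For context, a short sketch of what the new helper returns and how it pairs with the validation added in torchtune/training/quantization.py (next file): the float8 plan simply substitutes the Float8 colwise/rowwise styles, and is rejected for non-tensorwise recipes. The specific keys checked below are taken from the plan above; the rest is illustrative:

# Sketch: inspect the float8 TP plan and exercise the recipe-name guard.
from torchao.float8.float8_tensor_parallel import (
    Float8ColwiseParallel,
    Float8RowwiseParallel,
)
from torchtune.models.llama3 import fp8_llama_tp_plan
from torchtune.training.quantization import validate_float8_tp_plan

plan = fp8_llama_tp_plan()
assert isinstance(plan["layers.*.attn.q_proj"], Float8ColwiseParallel)
assert isinstance(plan["layers.*.mlp.w2"], Float8RowwiseParallel)

# Accepted: tensorwise (or unset) scaling works with the float8 plan.
validate_float8_tp_plan(plan, "tensorwise")
validate_float8_tp_plan(plan, None)

# Rejected: float8 TP styles are only supported with tensorwise scaling.
try:
    validate_float8_tp_plan(plan, "rowwise")
except ValueError as err:
    print(err)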

torchtune/training/quantization.py

Lines changed: 68 additions & 3 deletions
@@ -4,18 +4,25 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Callable, Optional
+from typing import Callable, Dict, Optional
 
 from torch import nn
+from torch.distributed.tensor.parallel.style import ParallelStyle
 
 from torchao.dtypes import TensorCoreTiledLayout
-
+from torchao.float8 import (
+    convert_to_float8_training as _convert_to_float8_training_torchao,
+    Float8LinearConfig,
+)
+from torchao.float8.float8_tensor_parallel import (
+    Float8ColwiseParallel,
+    Float8RowwiseParallel,
+)
 from torchao.quantization import (
     int4_weight_only,
     int8_dynamic_activation_int4_weight,
     quantize_,
 )
-
 from torchao.quantization.qat import (
     Int4WeightOnlyQATQuantizer,
     Int8DynActInt4WeightQATQuantizer,
@@ -26,6 +33,7 @@
     enable_4w_fake_quant,
     enable_8da4w_fake_quant,
 )
+
 from torchtune.modules.peft.lora import LoRALinear, QATLoRALinear
 
 
@@ -219,3 +227,60 @@ def swap_lora_linear_with_qat(
         activation_qat_config,
         weight_qat_config,
     )
+
+
+def convert_to_float8_training(
+    model: nn.Module,
+    fp8_recipe_name: Optional[str] = None,
+) -> nn.Module:
+    """
+    Prepare the model for float8 training by swapping all `nn.Linear` with `Float8Linear`.
+
+    Args:
+        model (nn.Module): The model to swap linear layers on
+        fp8_recipe_name (Optional[str]): name to identify one of the pre-made recipes,
+            one of "tensorwise", "rowwise", and "rowwise_with_gw_hp". If not specified,
+            defaults to "tensorwise" with "enable_fsdp_float8_all_gather=True". See
+            https://github.com/pytorch/ao/blob/v0.9.0/torchao/float8/config.py#L150
+            for more details.
+
+    Returns:
+        (nn.Module) The new model with `Float8Linear`.
+    """
+    if fp8_recipe_name is not None:
+        fp8_config = Float8LinearConfig.from_recipe_name(fp8_recipe_name)
+    else:
+        fp8_config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)
+    return _convert_to_float8_training_torchao(
+        model,
+        config=fp8_config,
+        module_filter_fn=lambda mod, fqn: fqn != "output",
+    )
+
+
+def validate_float8_tp_plan(
+    tp_plan: Optional[Dict[str, ParallelStyle]],
+    fp8_recipe_name: Optional[str] = None,
+) -> None:
+    """
+    Validate that the provided tensor parallel plan is compatible with the
+    float8 settings. Specifically, float8 tensor parallel plans are only
+    supported when using 'tensorwise' float8 recipes.
+    """
+    if tp_plan is None or is_fp8_tensorwise_scaling(fp8_recipe_name):
+        return
+    for parallel_style in tp_plan.values():
+        if isinstance(parallel_style, Float8ColwiseParallel) or isinstance(
+            parallel_style, Float8RowwiseParallel
+        ):
+            raise ValueError(
+                "%s and %s are only compatible with 'tensorwise' float8 recipes"
+                % (Float8ColwiseParallel.__name__, Float8RowwiseParallel.__name__)
+            )
+
+
+def is_fp8_tensorwise_scaling(fp8_recipe_name: Optional[str]) -> bool:
+    """
+    Return True if the fp8 recipe name refers to 'tensorwise' scaling.
+    """
+    return fp8_recipe_name is None or fp8_recipe_name == "tensorwise"
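A small usage sketch of the new wrapper, on a toy module rather than a torchtune model: every nn.Linear except the submodule whose fully qualified name is "output" is swapped for torchao's Float8Linear. The module swap itself should not need a GPU; running a forward pass would additionally require float8-capable hardware. ToyModel and its layer sizes are purely illustrative:

# Sketch: the module_filter_fn keeps the module named "output" in high precision.
from torch import nn
from torchao.float8.float8_linear import Float8Linear
from torchtune.training.quantization import convert_to_float8_training


class ToyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(128, 128, bias=False)
        self.output = nn.Linear(128, 128, bias=False)  # fqn == "output" -> left as nn.Linear

    def forward(self, x):
        return self.output(self.proj(x))


model = convert_to_float8_training(ToyModel(), fp8_recipe_name=None)
assert isinstance(model.proj, Float8Linear)
assert not isinstance(model.output, Float8Linear)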
