 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Dict
+from typing import Dict, Type
 
 from torch.distributed.tensor import Replicate, Shard
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     PrepareModuleInput,
     RowwiseParallel,
     SequenceParallel,
 )
 from torch.distributed.tensor.parallel.style import ParallelStyle
 
+from torchao.float8.float8_tensor_parallel import (
+    Float8ColwiseParallel,
+    Float8RowwiseParallel,
+)
 
-# Define the Tensor Parallel plan for Llama3 model, which will also be shared with 3.1, 3.2, and 3.3 models
-BASE_LLAMA_TP_PLAN = {
-    "tok_embeddings": RowwiseParallel(
-        input_layouts=Replicate(), output_layouts=Shard(1)
-    ),
-    "norm": SequenceParallel(),
-    "output": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate()),
-    "layers.*.attn": PrepareModuleInput(
-        input_layouts=(Shard(1), None),
-        desired_input_layouts=(Replicate(), None),
-    ),
-    "layers.*.mlp": PrepareModuleInput(
-        input_layouts=(Shard(1),),
-        desired_input_layouts=(Replicate(),),
-    ),
-    "layers.*.sa_norm": SequenceParallel(),
-    "layers.*.mlp_norm": SequenceParallel(),
-    "layers.*.attn.q_proj": ColwiseParallel(),
-    "layers.*.attn.k_proj": ColwiseParallel(),
-    "layers.*.attn.v_proj": ColwiseParallel(),
-    "layers.*.attn.output_proj": RowwiseParallel(output_layouts=Shard(1)),
-    "layers.*.mlp.w1": ColwiseParallel(),
-    "layers.*.mlp.w2": RowwiseParallel(output_layouts=Shard(1)),
-    "layers.*.mlp.w3": ColwiseParallel(),
-}
+
+def _get_base_llama_tp_plan(
+    _sequence_parallel_cls: Type[ParallelStyle] = SequenceParallel,
+    _colwise_parallel_cls: Type[ParallelStyle] = ColwiseParallel,
+    _rowwise_parallel_cls: Type[ParallelStyle] = RowwiseParallel,
+) -> Dict[str, ParallelStyle]:
+    """
+    Define the tensor parallel plan for the Llama3 model, which is also shared with the 3.1, 3.2, and 3.3 models.
+    """
+    return {
+        "tok_embeddings": _rowwise_parallel_cls(
+            input_layouts=Replicate(), output_layouts=Shard(1)
+        ),
+        "norm": _sequence_parallel_cls(),
+        "output": _colwise_parallel_cls(
+            input_layouts=Shard(1), output_layouts=Replicate()
+        ),
+        "layers.*.attn": PrepareModuleInput(
+            input_layouts=(Shard(1), None),
+            desired_input_layouts=(Replicate(), None),
+        ),
+        "layers.*.mlp": PrepareModuleInput(
+            input_layouts=(Shard(1),),
+            desired_input_layouts=(Replicate(),),
+        ),
+        "layers.*.sa_norm": _sequence_parallel_cls(),
+        "layers.*.mlp_norm": _sequence_parallel_cls(),
+        "layers.*.attn.q_proj": _colwise_parallel_cls(),
+        "layers.*.attn.k_proj": _colwise_parallel_cls(),
+        "layers.*.attn.v_proj": _colwise_parallel_cls(),
+        "layers.*.attn.output_proj": _rowwise_parallel_cls(output_layouts=Shard(1)),
+        "layers.*.mlp.w1": _colwise_parallel_cls(),
+        "layers.*.mlp.w2": _rowwise_parallel_cls(output_layouts=Shard(1)),
+        "layers.*.mlp.w3": _colwise_parallel_cls(),
+    }
 
 
 def base_llama_tp_plan() -> Dict[str, ParallelStyle]:
@@ -50,4 +64,19 @@ def base_llama_tp_plan() -> Dict[str, ParallelStyle]:
     Returns:
         Dict[str, ParallelStyle]: The tensor parallel plan for the Llama3 model.
     """
-    return BASE_LLAMA_TP_PLAN
+    return _get_base_llama_tp_plan()
+
+
+def fp8_llama_tp_plan() -> Dict[str, ParallelStyle]:
+    """
+    Return the tensor parallel plan for the Llama3 model that uses float8 all-gather for both
+    rowwise and colwise computation. It is currently only compatible with float8 fine-tuning
+    with "tensorwise" scaling. This tensor parallel plan is shared between the 3.1, 3.2, and 3.3 models.
+
+    Returns:
+        Dict[str, ParallelStyle]: The float8-enabled tensor parallel plan for the Llama3 model.
+    """
+    return _get_base_llama_tp_plan(
+        _colwise_parallel_cls=Float8ColwiseParallel,
+        _rowwise_parallel_cls=Float8RowwiseParallel,
+    )
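
Not part of the commit, but useful context for reviewers: the dicts returned by base_llama_tp_plan() and fp8_llama_tp_plan() are the kind of plan consumed by torch.distributed.tensor.parallel.parallelize_module. The sketch below is illustrative only and rests on a few assumptions not stated in the diff: the file lives at torchtune.models.llama3._parallelism, llama3_8b is just an example model builder, tp_size=2 is an arbitrary example value, parallelize_module resolves the wildcard ("*") module paths used in the plans, and the float8 conversion ordering (convert linears first, then parallelize) follows common float8 tensor-parallel setups rather than anything in this diff.

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

from torchao.float8 import convert_to_float8_training
from torchtune.models.llama3 import llama3_8b
from torchtune.models.llama3._parallelism import (  # module path assumed
    base_llama_tp_plan,
    fp8_llama_tp_plan,
)

tp_size = 2  # example value; must equal the number of ranks in the TP mesh
tp_mesh = init_device_mesh("cuda", (tp_size,), mesh_dim_names=("tp",))

# bf16/fp32 fine-tuning: shard/replicate submodules in place per the base plan.
model = llama3_8b()
parallelize_module(model, tp_mesh, parallelize_plan=base_llama_tp_plan())

# float8 fine-tuning with tensorwise scaling (torchao's default dynamic scaling):
# convert nn.Linear modules to Float8Linear first, then apply the float8-aware
# plan so the TP all-gathers run in float8.
fp8_model = llama3_8b()
convert_to_float8_training(fp8_model)
parallelize_module(fp8_model, tp_mesh, parallelize_plan=fp8_llama_tp_plan())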