@@ -4,43 +4,52 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Dict
+from typing import Dict, Type

 from torch.distributed.tensor import Replicate, Shard
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     PrepareModuleInput,
     RowwiseParallel,
     SequenceParallel,
 )
 from torch.distributed.tensor.parallel.style import ParallelStyle


-# Define the Tensor Parallel plan for Llama3 model, which will also be shared with 3.1, 3.2, and 3.3 models
-BASE_LLAMA_TP_PLAN = {
-    "tok_embeddings": RowwiseParallel(
-        input_layouts=Replicate(), output_layouts=Shard(1)
-    ),
-    "norm": SequenceParallel(),
-    "output": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate()),
-    "layers.*.attn": PrepareModuleInput(
-        input_layouts=(Shard(1), None),
-        desired_input_layouts=(Replicate(), None),
-    ),
-    "layers.*.mlp": PrepareModuleInput(
-        input_layouts=(Shard(1),),
-        desired_input_layouts=(Replicate(),),
-    ),
-    "layers.*.sa_norm": SequenceParallel(),
-    "layers.*.mlp_norm": SequenceParallel(),
-    "layers.*.attn.q_proj": ColwiseParallel(),
-    "layers.*.attn.k_proj": ColwiseParallel(),
-    "layers.*.attn.v_proj": ColwiseParallel(),
-    "layers.*.attn.output_proj": RowwiseParallel(output_layouts=Shard(1)),
-    "layers.*.mlp.w1": ColwiseParallel(),
-    "layers.*.mlp.w2": RowwiseParallel(output_layouts=Shard(1)),
-    "layers.*.mlp.w3": ColwiseParallel(),
-}
+def _get_base_llama_tp_plan(
+    _sequence_parallel_cls: Type[ParallelStyle] = SequenceParallel,
+    _colwise_parallel_cls: Type[ParallelStyle] = ColwiseParallel,
+    _rowwise_parallel_cls: Type[ParallelStyle] = RowwiseParallel,
+) -> Dict[str, ParallelStyle]:
+    """
+    Define the Tensor Parallel plan for the Llama3 model, which is also shared with the 3.1, 3.2, and 3.3 models.
+    """
+    return {
+        "tok_embeddings": _rowwise_parallel_cls(
+            input_layouts=Replicate(), output_layouts=Shard(1)
+        ),
+        "norm": _sequence_parallel_cls(),
+        "output": _colwise_parallel_cls(
+            input_layouts=Shard(1), output_layouts=Replicate()
+        ),
+        "layers.*.attn": PrepareModuleInput(
+            input_layouts=(Shard(1), None),
+            desired_input_layouts=(Replicate(), None),
+        ),
+        "layers.*.mlp": PrepareModuleInput(
+            input_layouts=(Shard(1),),
+            desired_input_layouts=(Replicate(),),
+        ),
+        "layers.*.sa_norm": _sequence_parallel_cls(),
+        "layers.*.mlp_norm": _sequence_parallel_cls(),
+        "layers.*.attn.q_proj": _colwise_parallel_cls(),
+        "layers.*.attn.k_proj": _colwise_parallel_cls(),
+        "layers.*.attn.v_proj": _colwise_parallel_cls(),
+        "layers.*.attn.output_proj": _rowwise_parallel_cls(output_layouts=Shard(1)),
+        "layers.*.mlp.w1": _colwise_parallel_cls(),
+        "layers.*.mlp.w2": _rowwise_parallel_cls(output_layouts=Shard(1)),
+        "layers.*.mlp.w3": _colwise_parallel_cls(),
+    }


 def base_llama_tp_plan() -> Dict[str, ParallelStyle]:
@@ -50,4 +59,19 @@ def base_llama_tp_plan() -> Dict[str, ParallelStyle]:
     Returns:
         Dict[str, Any]: The tensor parallel plan for Llama3 model.
     """
-    return BASE_LLAMA_TP_PLAN
+    return _get_base_llama_tp_plan()
+
+
+def fp8_llama_tp_plan() -> Dict[str, ParallelStyle]:
+    """
+    Return the tensor parallel plan for the Llama3 model that uses float8 all-gathers for both
+    rowwise and colwise computation. This is currently only compatible with float8 fine-tuning
+    with "tensorwise" scaling. The tensor parallel plan is shared between the 3.1, 3.2, and 3.3 models.
+
+    Returns:
+        Dict[str, Any]: The float8-enabled tensor parallel plan for Llama3 model.
+    """
+    return _get_base_llama_tp_plan(
+        _colwise_parallel_cls=Float8ColwiseParallel,
+        _rowwise_parallel_cls=Float8RowwiseParallel,
+    )
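
For reference on how a plan like this is consumed: the keys are fully qualified submodule names (with "*" wildcards) and the values are DTensor ParallelStyle instances, which PyTorch's parallelize_module applies over a device mesh. The following is a minimal, hypothetical usage sketch, not part of this commit; it assumes torch.distributed is already initialized with tp_size tensor-parallel ranks, that model is a Llama3 module whose submodule names match the plan's keys, and that base_llama_tp_plan is imported from this file.

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module


def apply_tp(model: torch.nn.Module, tp_size: int) -> torch.nn.Module:
    # Hypothetical helper, not part of this commit: build a 1-D tensor-parallel
    # device mesh, then let parallelize_module match the plan's wildcard FQNs
    # (e.g. "layers.*.attn.q_proj") against the model's submodules and shard
    # their parameters as DTensors.
    tp_mesh = init_device_mesh("cuda", (tp_size,))
    return parallelize_module(model, tp_mesh, base_llama_tp_plan())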
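
Note that Float8ColwiseParallel and Float8RowwiseParallel are referenced by fp8_llama_tp_plan() but their import is not visible in the hunks above. torchao ships these styles in torchao.float8.float8_tensor_parallel; a guarded import along the following lines is one plausible way to bring them in, stated here as an assumption rather than what the commit actually does.

# Assumed import location, not shown in the diff above: torchao provides
# float8-aware colwise/rowwise parallel styles for DTensor tensor parallelism.
try:
    from torchao.float8.float8_tensor_parallel import (
        Float8ColwiseParallel,
        Float8RowwiseParallel,
    )
except ImportError as err:
    raise ImportError(
        "fp8_llama_tp_plan() requires torchao with float8 support installed"
    ) from err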