 import torch
+import unittest
 from torchao.testing.utils import copy_tests, TorchAOTensorParallelTestCase
 from torch.testing._internal.common_utils import run_tests
-from torchao.quantization import int8_weight_only, float8_weight_only
+from torch.testing._internal import common_utils
+from torchao.quantization import int8_weight_only, float8_weight_only, float8_dynamic_activation_float8_weight
+from torchao.quantization.observer import PerRow, PerTensor
+import torch.distributed as dist
+from torch.distributed._tensor import DTensor, Replicate, Shard, DeviceMesh
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+    DTensorTestBase,
+    with_comms,
+    NUM_DEVICES,
+)
+from torchao.quantization.quant_api import quantize_
+from torchao.dtypes import AffineQuantizedTensor
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

 class TestInt8woAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
     QUANT_METHOD_FN = staticmethod(int8_weight_only)
@@ -13,5 +26,133 @@ class TestFloat8woAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
     QUANT_METHOD_FN = staticmethod(float8_weight_only)
 copy_tests(TorchAOTensorParallelTestCase, TestFloat8woAffineQuantizedTensorParallel, "fp8wo_tp")

+# Run only on H100
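+# The float8 dynamic quantization tests below are gated on compute capability >= 9.0 (H100 or newer)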
+if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):
+    class TestFloat8dqAffineQuantizedTensorParallel(DTensorTestBase):
+        """Basic test case for tensor subclasses
+        """
+        COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32]
+        TENSOR_SUBCLASS = AffineQuantizedTensor
+        QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight)
+        QUANT_METHOD_KWARGS = {}
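+        # Subclasses below override QUANT_METHOD_KWARGS to select PerTensor or PerRow granularity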
+
+        @staticmethod
+        def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
+            """
+            Shard the linear layer of the model in a column-wise fashion
+            """
+            # Column-wise is w.r.t. A^T, so for A it is row-wise.
+            # Number of rows per rank
+            orig_weight = m.linear.weight
+            n_local_rows = orig_weight.size(0) // mesh.size()
+            rank = mesh.get_local_rank()
+            local_shard = orig_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :]
+            # Construct DTensor from local shard
+            dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)])
+            # Replace parameter in module
+            m.linear.weight = torch.nn.Parameter(
+                dtensor, requires_grad=False
+            )
+            return m
+
+        @staticmethod
+        def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
+            """
+            Shard the linear layer of the model in a row-wise fashion
+            """
+            # Row-wise is w.r.t. A^T, so for A it is column-wise.
+            # Number of columns per rank
+            orig_weight = m.linear.weight
+            n_local_cols = orig_weight.size(1) // mesh.size()
+            rank = mesh.get_local_rank()
+            local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols]
+            # Construct DTensor from local shard
+            dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)], run_check=True)
+            # Replace parameter in module
+            m.linear.weight = torch.nn.Parameter(
+                dtensor, requires_grad=False
+            )
+            return m
+
+        def quantize(self, m: torch.nn.Module) -> torch.nn.Module:
+            """
+            Quantize the model
+            """
+            quantize_(m, self.QUANT_METHOD_FN(**self.QUANT_METHOD_KWARGS))
+            return m
+
+        def _test_tp(self, dtype):
+            device = "cuda"
+            # To make sure different ranks create the same module
+            torch.manual_seed(5)
+
+            class M(torch.nn.Module):
+                def __init__(self, in_features, out_features, **kwargs) -> None:
+                    super().__init__(**kwargs)
+                    self.linear = torch.nn.Linear(in_features, out_features, bias=False, device="cuda")
+
+                def forward(self, x: torch.Tensor) -> torch.Tensor:
+                    return self.linear(x)
+
+            # Get rank and device
+            device = torch.device(f"cuda:{self.rank % torch.cuda.device_count()}")
+
+            # Original model
+            proj_up = M(1024, 2048).to(device).to(dtype)
+            proj_dn = M(2048, 1024).to(device).to(dtype)
+            example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype)
+            y = proj_dn(proj_up(example_input))
+            # Quantize the model
+            up_quant = self.quantize(proj_up)
+            dn_quant = self.quantize(proj_dn)
+            y_q = dn_quant(up_quant(example_input))
+
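+            # Build a 1-D device mesh over the participating ranks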
+            mesh = self.build_device_mesh()
+            mesh.device_type = "cuda"
+
+            # Shard the models
+            up_dist = self.colwise_shard(up_quant, mesh)
+            dn_dist = self.rowwise_shard(dn_quant, mesh)
+
+            # We need to turn inputs into DTensor form as well -- just a format change
+            input_dtensor = DTensor.from_local(
+                example_input, mesh, [Replicate()]
+            )
+
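+            # Forward pass through the sharded, quantized modules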
+            y_d = dn_dist(up_dist(input_dtensor))
+
+            if not TORCH_VERSION_AT_LEAST_2_5:
+                # Need torch 2.5 to support compiled tensor parallelism
+                return
+
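+            # Also exercise the sharded quantized modules under torch.compile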
+            up_compiled = torch.compile(up_dist)
+            y_up = up_compiled(input_dtensor)
+            dn_compiled = torch.compile(dn_dist)
+            y_dn = dn_compiled(y_up)
+
+    class TestFloat8dqTensorAffineQuantizedTensorParallel(TestFloat8dqAffineQuantizedTensorParallel):
+        QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight)
+        QUANT_METHOD_KWARGS = {"granularity": PerTensor()}
+        COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32]
+
+        @common_utils.parametrize("dtype", COMMON_DTYPES)
+        @with_comms
+        @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+        def test_tp(self, dtype):
+            return self._test_tp(dtype)
+
+    class TestFloat8dqRowAffineQuantizedTensorParallel(TestFloat8dqAffineQuantizedTensorParallel):
+        QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight)
+        QUANT_METHOD_KWARGS = {"granularity": PerRow()}
+        COMMON_DTYPES = [torch.bfloat16]
+
+        @common_utils.parametrize("dtype", COMMON_DTYPES)
+        @with_comms
+        @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+        def test_tp(self, dtype):
+            return self._test_tp(dtype)
+
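+    # instantiate_parametrized_tests generates one concrete test per dtype for each class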
+    common_utils.instantiate_parametrized_tests(TestFloat8dqTensorAffineQuantizedTensorParallel)
+    common_utils.instantiate_parametrized_tests(TestFloat8dqRowAffineQuantizedTensorParallel)
 if __name__ == "__main__":
     run_tests()