import torch
import unittest
-from torchao.testing.utils import copy_tests, TorchAOTensorParallelTestCase
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal import common_utils
-from torchao.quantization import int8_weight_only, float8_weight_only, float8_dynamic_activation_float8_weight
+from torchao.quantization import (
+    int4_weight_only,
+    int8_weight_only,
+    float8_weight_only,
+    float8_dynamic_activation_float8_weight,
+)
from torchao.quantization.observer import PerRow, PerTensor
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Replicate, Shard, DeviceMesh
from torchao.dtypes import AffineQuantizedTensor
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

-class TestInt8woAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
+class TestAffineQuantizedTensorParallel(DTensorTestBase):
+    """Basic test case for tensor subclasses
+    """
    QUANT_METHOD_FN = staticmethod(int8_weight_only)
-copy_tests(TorchAOTensorParallelTestCase, TestInt8woAffineQuantizedTensorParallel, "int8wo_tp")
+    QUANT_METHOD_KWARGS = {}

-# Run only on H100
-if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):
-    class TestFloat8woAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
-        QUANT_METHOD_FN = staticmethod(float8_weight_only)
-    copy_tests(TorchAOTensorParallelTestCase, TestFloat8woAffineQuantizedTensorParallel, "fp8wo_tp")
+    @staticmethod
+    def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
+        """
+        Shard linear layer of the model in column-wise fashion
+        """
+        # Column-wise is wrt to A^T, so for A it is row-wise.
+        # Number of rows per rank
+        orig_weight = m.linear.weight
+        n_local_rows = orig_weight.size(0) // mesh.size()
+        rank = mesh.get_local_rank()
+        local_shard = orig_weight[rank * n_local_rows: (rank + 1) * n_local_rows, :]
+        # Construct DTensor from local shard
+        dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)])
+        # Replace parameter in module
+        m.linear.weight = torch.nn.Parameter(
+            dtensor, requires_grad=False
+        )
+        return m
+
+    @staticmethod
+    def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
+        """
+        Shard linear layer of the model in row-wise fashion
+        """
+        # Row-wise is wrt to A^T, so for A it is column-wise.
+        # Number of columns per rank
+        orig_weight = m.linear.weight
+        n_local_cols = orig_weight.size(1) // mesh.size()
+        rank = mesh.get_local_rank()
+        local_shard = orig_weight[:, rank * n_local_cols: (rank + 1) * n_local_cols]
+        # Construct DTensor from local shard
+        dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)], run_check=True)
+        # Replace parameter in module
+        m.linear.weight = torch.nn.Parameter(
+            dtensor, requires_grad=False
+        )
+        return m
+
+    def quantize(self, m: torch.nn.Module) -> torch.nn.Module:
+        """
+        Quantize the model
+        """
+        quantize_(m, self.QUANT_METHOD_FN(**self.QUANT_METHOD_KWARGS))
+        return m
+
+    def _test_tp(self, dtype):
+        device = "cuda"
+        # To make sure different ranks create the same module
+        torch.manual_seed(5)
+
+        class M(torch.nn.Module):
+            def __init__(self, in_features, out_features, **kwargs) -> None:
+                super().__init__(**kwargs)
+                self.linear = torch.nn.Linear(in_features, out_features, bias=False, device="cuda")
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return self.linear(x)
+
+        # Get rank and device
+        device = torch.device(f"cuda:{self.rank % torch.cuda.device_count()}")
+
+        # Original model
+        proj_up = M(1024, 2048).to(device).to(dtype)
+        proj_dn = M(2048, 1024).to(device).to(dtype)
+        example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype)
+        y = proj_dn(proj_up(example_input))
+        # Quantize the model
+        up_quant = self.quantize(proj_up)
+        dn_quant = self.quantize(proj_dn)
+        y_q = dn_quant(up_quant(example_input))
+
+        mesh = self.build_device_mesh()
+        mesh.device_type = "cuda"
+
+        # Shard the models
+        up_dist = self.colwise_shard(up_quant, mesh)
+        dn_dist = self.rowwise_shard(dn_quant, mesh)
+
+        # We need to turn inputs into DTensor form as well -- just a format change
+        input_dtensor = DTensor.from_local(
+            example_input, mesh, [Replicate()]
+        )
+
+        y_d = dn_dist(up_dist(input_dtensor))
+
+        if not TORCH_VERSION_AT_LEAST_2_5:
+            # Need torch 2.5 to support compiled tensor parallelism
+            return
+
+        up_compiled = torch.compile(up_dist)
+        y_up = up_compiled(input_dtensor)
+        dn_compiled = torch.compile(dn_dist)
+        y_dn = dn_compiled(y_up)
+
+
+class TestInt8woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
+    QUANT_METHOD_FN = staticmethod(int8_weight_only)
+    COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32]
+
+    @common_utils.parametrize("dtype", COMMON_DTYPES)
+    @with_comms
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_tp(self, dtype):
+        return self._test_tp(dtype)
+
+
+class TestInt4woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
+    QUANT_METHOD_FN = staticmethod(int4_weight_only)
+    COMMON_DTYPES = [torch.bfloat16]
+
+    @common_utils.parametrize("dtype", COMMON_DTYPES)
+    @with_comms
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_tp(self, dtype):
+        return self._test_tp(dtype)
+
+common_utils.instantiate_parametrized_tests(TestInt8woAffineQuantizedTensorParallel)
+common_utils.instantiate_parametrized_tests(TestInt4woAffineQuantizedTensorParallel)

# Run only on H100
if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):
-    class TestFloat8dqAffineQuantizedTensorParallel(DTensorTestBase):
-        """Basic test case for tensor subclasses
-        """
+    class TestFloat8woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
+        QUANT_METHOD_FN = staticmethod(float8_weight_only)
        COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32]
-        TENSOR_SUBCLASS = AffineQuantizedTensor
-        QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight)
-        QUANT_METHOD_KWARGS = {}
-
-        @staticmethod
-        def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
-            """
-            Shard linear layer of the model in column-wise fashion
-            """
-            # Column-wise is wrt to A^T, so for A it is row-wise.
-            # Number of rows per rank
-            orig_weight = m.linear.weight
-            n_local_rows = orig_weight.size(0) // mesh.size()
-            rank = mesh.get_local_rank()
-            local_shard = orig_weight[rank * n_local_rows: (rank + 1) * n_local_rows, :]
-            # Construct DTensor from local shard
-            dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)])
-            # Replace parameter in module
-            m.linear.weight = torch.nn.Parameter(
-                dtensor, requires_grad=False
-            )
-            return m
-
-        @staticmethod
-        def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
-            """
-            Shard linear layer of the model in row-wise fashion
-            """
-            # Row-wise is wrt to A^T, so for A it is column-wise.
-            # Number of rows per rank
-            orig_weight = m.linear.weight
-            n_local_cols = orig_weight.size(1) // mesh.size()
-            rank = mesh.get_local_rank()
-            local_shard = orig_weight[:, rank * n_local_cols: (rank + 1) * n_local_cols]
-            # Construct DTensor from local shard
-            dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)], run_check=True)
-            # Replace parameter in module
-            m.linear.weight = torch.nn.Parameter(
-                dtensor, requires_grad=False
-            )
-            return m
-
-        def quantize(self, m: torch.nn.Module) -> torch.nn.Module:
-            """
-            Quantize the model
-            """
-            quantize_(m, self.QUANT_METHOD_FN(**self.QUANT_METHOD_KWARGS))
-            return m
-
-        def _test_tp(self, dtype):
-            device = "cuda"
-            # To make sure different ranks create the same module
-            torch.manual_seed(5)
-
-            class M(torch.nn.Module):
-                def __init__(self, in_features, out_features, **kwargs) -> None:
-                    super().__init__(**kwargs)
-                    self.linear = torch.nn.Linear(in_features, out_features, bias=False, device="cuda")
-
-                def forward(self, x: torch.Tensor) -> torch.Tensor:
-                    return self.linear(x)
-
-            # Get rank and device
-            device = torch.device(f"cuda:{self.rank % torch.cuda.device_count()}")
-
-            # Original model
-            proj_up = M(1024, 2048).to(device).to(dtype)
-            proj_dn = M(2048, 1024).to(device).to(dtype)
-            example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype)
-            y = proj_dn(proj_up(example_input))
-            # Quantize the model
-            up_quant = self.quantize(proj_up)
-            dn_quant = self.quantize(proj_dn)
-            y_q = dn_quant(up_quant(example_input))
-
-            mesh = self.build_device_mesh()
-            mesh.device_type = "cuda"
-
-            # Shard the models
-            up_dist = self.colwise_shard(up_quant, mesh)
-            dn_dist = self.rowwise_shard(dn_quant, mesh)
-
-            # We need to turn inputs into DTensor form as well -- just a format change
-            input_dtensor = DTensor.from_local(
-                example_input, mesh, [Replicate()]
-            )
-
-            y_d = dn_dist(up_dist(input_dtensor))
-
-            if not TORCH_VERSION_AT_LEAST_2_5:
-                # Need torch 2.5 to support compiled tensor parallelism
-                return
-
-            up_compiled = torch.compile(up_dist)
-            y_up = up_compiled(input_dtensor)
-            dn_compiled = torch.compile(dn_dist)
-            y_dn = dn_compiled(y_up)
-
-    class TestFloat8dqTensorAffineQuantizedTensorParallel(TestFloat8dqAffineQuantizedTensorParallel):
+
+        @common_utils.parametrize("dtype", COMMON_DTYPES)
+        @with_comms
+        @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+        def test_tp(self, dtype):
+            return self._test_tp(dtype)
+
+    class TestFloat8dqTensorAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
        QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight)
        QUANT_METHOD_KWARGS = {"granularity": PerTensor()}
        COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32]
@@ -141,7 +168,7 @@ class TestFloat8dqTensorAffineQuantizedTensorParallel(TestFloat8dqAffineQuantize
        def test_tp(self, dtype):
            return self._test_tp(dtype)

-    class TestFloat8dqRowAffineQuantizedTensorParallel(TestFloat8dqAffineQuantizedTensorParallel):
+    class TestFloat8dqRowAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
        QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight)
        QUANT_METHOD_KWARGS = {"granularity": PerRow()}
        COMMON_DTYPES = [torch.bfloat16]
@@ -151,7 +178,7 @@ class TestFloat8dqRowAffineQuantizedTensorParallel(TestFloat8dqAffineQuantizedTe
        @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
        def test_tp(self, dtype):
            return self._test_tp(dtype)
-
+
    common_utils.instantiate_parametrized_tests(TestFloat8dqTensorAffineQuantizedTensorParallel)
    common_utils.instantiate_parametrized_tests(TestFloat8dqRowAffineQuantizedTensorParallel)
if __name__ == "__main__":
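
For reference, the colwise_shard/rowwise_shard helpers in the diff follow the usual convention for y = x @ A^T: "column-wise" parallelism shards the weight A along dim 0, "row-wise" shards it along dim 1. Below is a minimal, single-process sketch of the column-wise case, assuming a one-rank "gloo" process group on CPU and a plain (unquantized) linear layer purely for illustration; the names here (linear, local, x, y) are made up for the example, while the real tests run multi-rank on CUDA via DTensorTestBase/with_comms with quantized weights.

# Minimal sketch of the column-wise sharding convention used above.
# Assumes a single-rank "gloo" process group on CPU, for illustration only.
import os
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = DeviceMesh("cpu", list(range(dist.get_world_size())))
linear = torch.nn.Linear(16, 32, bias=False)  # illustrative module, not from the PR

# Column-wise parallel w.r.t. A^T == shard the weight A along dim 0 (rows).
n_local_rows = linear.weight.size(0) // mesh.size()
rank = mesh.get_local_rank()
local = linear.weight[rank * n_local_rows: (rank + 1) * n_local_rows, :]
linear.weight = torch.nn.Parameter(
    DTensor.from_local(local, mesh, [Shard(0)]), requires_grad=False
)

# Replicated input in, DTensor out; with more than one rank the output
# would be sharded along the output-feature dimension.
x = DTensor.from_local(torch.randn(4, 16), mesh, [Replicate()])
y = linear(x)
print(y.shape, y.placements)

dist.destroy_process_group()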