+from typing import Tuple
+
 import torch
-from typing import Tuple, Dict, List
 
 import torchao.sparsity.marlin.utils as utils
 from torchao.sparsity.marlin.utils import const
 from torchao.sparsity.utils import mask_creator
 
-
 __all__ = [
     "inject_24",
     "marlin_24_workspace",
     "pack_to_marlin_24",
     "unpack_from_marlin_24",
 ]
 
 
-def inject_24(w: torch.Tensor, size_k: int, size_n: int) -> Tuple[torch.Tensor, torch.Tensor]:
+def inject_24(
+    w: torch.Tensor, size_k: int, size_n: int
+) -> Tuple[torch.Tensor, torch.Tensor]:
     """Injects 2:4 sparsity into a weight tensor. The sparsity is applied in a 2:4 ratio, where for every
     group of 4 weights, 2 will be pruned based on their value. The mask will be created based on the
     ranked weight values.
-
+
     Args:
         w (torch.Tensor): The weight tensor to inject sparsity into.
         size_k (int): The number of input features.
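As a quick illustration of the 2:4 rule the docstring describes (keep the 2 largest-magnitude weights in every contiguous group of 4), here is a minimal standalone sketch. It is not the torchao implementation, which builds its mask via `mask_creator` and also returns the mask alongside the pruned tensor:

```python
import torch

def sparsify_2_4(w: torch.Tensor) -> torch.Tensor:
    # View the weights in groups of 4 and rank each group by magnitude.
    groups = w.reshape(-1, 4)
    keep = groups.abs().topk(2, dim=-1).indices
    # Zero everything except the top-2 entries of each group.
    mask = torch.zeros_like(groups, dtype=torch.bool).scatter_(-1, keep, True)
    return (groups * mask).reshape(w.shape)

w_24 = sparsify_2_4(torch.randn(4, 8))
assert (w_24.reshape(-1, 4) != 0).sum(dim=-1).max() <= 2
```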
@@ -32,33 +34,35 @@ def inject_24(w: torch.Tensor, size_k: int, size_n: int) -> Tuple[torch.Tensor,
 
 
 def marlin_24_workspace(
-        out_features: int,
-        min_thread_n: int = const.MIN_THREAD_N,
-        max_parallel: int = const.MAX_PARALLEL
-    ) -> torch.Tensor:
-    """Creates a workspace for marlin 2:4 quantization. The workspace is used to coordinate the locks
+    out_features: int,
+    min_thread_n: int = const.MIN_THREAD_N,
+    max_parallel: int = const.MAX_PARALLEL,
+) -> torch.Tensor:
+    """Creates a workspace for marlin 2:4 quantization. The workspace is used to coordinate the locks
     during the execution of the kernel.
-
+
     Args:
         out_features (int): The number of output features.
         min_thread_n (int, optional): The minimum number of threads per block. Defaults to `MARLIN_24_MIN_THREAD_N`.
         max_parallel (int, optional): The maximum number of parallel threads. Defaults to `MARLIN_24_MAX_PARALLEL`.
     Returns:
         torch.Tensor: The workspace tensor fully initialized with zeros.
     """
-    assert (out_features % min_thread_n == 0), f"out_features = {out_features}, min_thread_n = {min_thread_n}"
-    max_workspace_size = ((out_features // min_thread_n) * max_parallel)
+    assert (
+        out_features % min_thread_n == 0
+    ), f"out_features = {out_features}, min_thread_n = {min_thread_n}"
+    max_workspace_size = (out_features // min_thread_n) * max_parallel
     return torch.zeros(max_workspace_size, dtype=torch.int, device="cuda")
 
 
 def pack_to_marlin_24(
-        q_w_24: torch.Tensor,
-        scales: torch.Tensor,
-        num_bits: int,
-        group_size: int,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    q_w_24: torch.Tensor,
+    scales: torch.Tensor,
+    num_bits: int,
+    group_size: int,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """Packs the quantized weights and scales into the marlin 2:4 format.
-
+
     Args:
         q_w_24 (torch.Tensor): The quantized weight tensor with 2:4 sparsity applied.
         scales (torch.Tensor): The scale tensor.
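For intuition, the workspace sizing in `marlin_24_workspace` above is just this arithmetic; the concrete constants below are placeholders for illustration, not the actual `const.MIN_THREAD_N` / `const.MAX_PARALLEL` values:

```python
import torch

out_features = 4096
min_thread_n = 128  # placeholder for const.MIN_THREAD_N
max_parallel = 16   # placeholder for const.MAX_PARALLEL

# One lock slot per thread-block column of the output, times the parallelism cap.
assert out_features % min_thread_n == 0
max_workspace_size = (out_features // min_thread_n) * max_parallel  # 32 * 16 = 512
workspace = torch.zeros(max_workspace_size, dtype=torch.int)  # zero-initialized, as documented
```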
@@ -89,13 +93,13 @@ def pack_to_marlin_24(
 
 
 def unpack_from_marlin_24(
-        q_w_24_comp: torch.Tensor,
-        scales: torch.Tensor,
-        meta: torch.Tensor,
-        original_shape: torch.Size,
-        group_size: int,
-        num_bits: int
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    q_w_24_comp: torch.Tensor,
+    scales: torch.Tensor,
+    meta: torch.Tensor,
+    original_shape: torch.Size,
+    group_size: int,
+    num_bits: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
     """Unpacks the quantized weights and scales from the marlin 2:4 format.
     Args:
         q_w_24_comp (torch.Tensor): The packed quantized weights.
@@ -109,10 +113,8 @@ def unpack_from_marlin_24(
     """
     in_features, out_features = original_shape
 
-    # Unpacks the scales
-    unpacked_scales = _from_marlin_scale(
-        scales, *original_shape, group_size, num_bits
-    )
+    # Unpacks the scales
+    unpacked_scales = _from_marlin_scale(scales, *original_shape, group_size, num_bits)
 
     in_features_comp = in_features // 2
 
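Shape-wise, the `in_features // 2` above reflects that 2:4 compression stores only half of the input dimension. A rough sketch with illustrative sizes (the shapes are inferred from the signatures, not asserted by the source):

```python
in_features, out_features = 128, 256  # original_shape
group_size = 64

in_features_comp = in_features // 2   # 2 of every 4 weights survive
q_w_24_comp_shape = (in_features_comp, out_features)      # (64, 256)
scales_shape = (in_features // group_size, out_features)  # (2, 256)
```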
@@ -130,14 +132,11 @@ def unpack_from_marlin_24(
 
 
 def _compress_quantized_24_weight(
-        q_24: torch.Tensor,
-        size_k: int,
-        size_n: int,
-        num_bits: int
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Compresses the quantized weights to a 2:4 sparse format. Normalizes the weights over 0
+    q_24: torch.Tensor, size_k: int, size_n: int, num_bits: int
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Compresses the quantized weights to a 2:4 sparse format. Normalizes the weights over 0
     before compressing them.
-
+
     Args:
         q_24 (torch.Tensor): The quantized weight tensor.
         size_k (int): The number of input features.
@@ -168,14 +167,10 @@ def _compress_quantized_24_weight(
 
 
 def _decompress_quantized_24_weight(
-        q_24_comp: torch.Tensor,
-        meta: torch.Tensor,
-        size_k: int,
-        size_n: int,
-        num_bits: int
-    ) -> torch.Tensor:
+    q_24_comp: torch.Tensor, meta: torch.Tensor, size_k: int, size_n: int, num_bits: int
+) -> torch.Tensor:
     """Decompresses the quantized weights from a 2:4 sparse format and restores the original shape.
-
+
     Args:
         q_24_comp (torch.Tensor): The compressed quantized weight tensor in 2:4 sparse format.
         meta (torch.Tensor): The meta tensor.
@@ -210,13 +205,13 @@ def _decompress_quantized_24_weight(
 
 
 def _to_marlin_weights(
-        q_w: torch.Tensor,
-        size_k: int,
-        size_n: int,
-        num_bits: int,
-    ) -> torch.Tensor:
+    q_w: torch.Tensor,
+    size_k: int,
+    size_n: int,
+    num_bits: int,
+) -> torch.Tensor:
     """Converts a quantized and 2:4 sparse format weight tensor to the marlin 2:4 format.
-
+
     Args:
         q_w (torch.Tensor): The quantized weight tensor in 2:4 sparse format.
         size_k (int): The number of input features.
@@ -236,7 +231,11 @@ def _to_marlin_weights(
     # Original implementation uses numpy + uint32 but we need to use int64 because torch.uint32
     # does not support rshift_cpu.
     q_w = q_w.cpu().to(torch.int64)
-    q_packed = torch.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=torch.int64, device=q_w.device)
+    q_packed = torch.zeros(
+        (q_w.shape[0], q_w.shape[1] // pack_factor),
+        dtype=torch.int64,
+        device=q_w.device,
+    )
     for i in range(pack_factor):
         q_packed |= q_w[:, i::pack_factor] << (num_bits * i)
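The loop above interleaves values so that column `j` of `q_w` occupies bits `[num_bits*j, num_bits*(j+1))` of one packed word. A tiny self-contained check of just the packing loop, ignoring the marlin permutation that precedes it in the real function:

```python
import torch

num_bits = 4
pack_factor = 32 // num_bits  # 8 nibbles per 32-bit word

q_w = torch.arange(8, dtype=torch.int64).reshape(1, 8)  # one row: values 0..7
q_packed = torch.zeros((1, 8 // pack_factor), dtype=torch.int64)
for i in range(pack_factor):
    q_packed |= q_w[:, i::pack_factor] << (num_bits * i)

assert q_packed.item() == 0x76543210  # value j sits in nibble j
```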
@@ -245,13 +244,10 @@ def _to_marlin_weights(
 
 
 def _from_marlin_weights(
-        q_packed: torch.Tensor,
-        size_k: int,
-        size_n: int,
-        num_bits: int
-    ) -> torch.Tensor:
+    q_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
+) -> torch.Tensor:
     """Converts a weight tensor in the marlin 2:4 format to a regular quantized 2:4 sparse format.
-
+
     Args:
         q_packed (torch.Tensor): The weight tensor in the marlin 2:4 format.
         size_k (int): The number of input features.
@@ -269,52 +265,54 @@ def _from_marlin_weights(
     # Original implementation uses numpy + uint32 but we need to use int64 because torch.uint32
     # does not support rshift_cpu.
     q_packed = q_packed.cpu().to(torch.int64)
-    q_w_unpacked = torch.zeros((q_packed.shape[0], q_packed.shape[1] * pack_factor), dtype=torch.int64, device=q_packed.device)
+    q_w_unpacked = torch.zeros(
+        (q_packed.shape[0], q_packed.shape[1] * pack_factor),
+        dtype=torch.int64,
+        device=q_packed.device,
+    )
     for i in range(pack_factor):
-        q_w_unpacked[:, i::pack_factor] = (q_packed >> (num_bits * i)) & ((1 << num_bits) - 1)
+        q_w_unpacked[:, i::pack_factor] = (q_packed >> (num_bits * i)) & (
+            (1 << num_bits) - 1
+        )
 
     q_w_unpacked = q_w_unpacked.to(orig_device, dtype=torch.int32)
 
-    q_w_comp = utils.reverse_marlin_permute_weights(q_w_unpacked, size_k, size_n, perm_24)
+    q_w_comp = utils.reverse_marlin_permute_weights(
+        q_w_unpacked, size_k, size_n, perm_24
+    )
     return q_w_comp
 
 
 def _to_marlin_scales(
-        scales: torch.Tensor,
-        size_k: int,
-        size_n: int,
-        group_size: int,
-        num_bits: int
-    ) -> torch.Tensor:
+    scales: torch.Tensor, size_k: int, size_n: int, group_size: int, num_bits: int
+) -> torch.Tensor:
     """Converts a scale tensor to the format necessary for marlin.
     Args:
         scales (torch.Tensor): The scale tensor.
         size_k (int): The number of input features.
         size_n (int): The number of output features.
         group_size (int): The group size that was applied during quantization.
         num_bits (int): The number of bits used for quantization.
-
+
     Returns:
         torch.Tensor: The scale tensor in the marlin format.
     """
     _, scale_perm_24, scale_perm_single_24 = utils.get_perms_24(num_bits)
     if group_size < size_k and group_size != -1:
         scales = scales.reshape((-1, len(scale_perm_24)))[:, scale_perm_24]
     else:
-        scales = scales.reshape((-1, len(scale_perm_single_24)))[:, scale_perm_single_24]
+        scales = scales.reshape((-1, len(scale_perm_single_24)))[
+            :, scale_perm_single_24
+        ]
     scales = scales.reshape((-1, size_n)).contiguous()
     return scales
 
 
 def _from_marlin_scale(
-        scales: torch.Tensor,
-        size_k: int,
-        size_n: int,
-        group_size: int,
-        num_bits: int
-    ) -> torch.Tensor:
+    scales: torch.Tensor, size_k: int, size_n: int, group_size: int, num_bits: int
+) -> torch.Tensor:
     """Converts a scale tensor from the marlin format to their original format.
-
+
     Args:
         scales (torch.Tensor): The scale tensor in the marlin format.
         size_k (int): The number of input features.
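The scale permutations applied in `_to_marlin_scales` and undone in `_from_marlin_scale` are inverses of each other. The reshuffle pattern, with a made-up 8-entry permutation standing in for the real `scale_perm_24` table:

```python
import torch

perm = torch.tensor([0, 4, 1, 5, 2, 6, 3, 7])  # stand-in, not marlin's actual table
inv_perm = torch.argsort(perm)

scales = torch.randn(4, 8)
marlin_scales = scales.reshape((-1, len(perm)))[:, perm]        # to marlin order
restored = marlin_scales.reshape((-1, len(perm)))[:, inv_perm]  # and back
assert torch.equal(restored.reshape(scales.shape), scales)
```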
@@ -329,5 +327,7 @@ def _from_marlin_scale(
         scales = scales.reshape((-1, len(scale_perm_24)))[:, scale_perm_24]
         return scales.reshape((size_k // group_size, size_n))
     else:
-        scales = scales.reshape((-1, len(scale_perm_single_24)))[:, scale_perm_single_24]
-        return scales.reshape((1, -1))
+        scales = scales.reshape((-1, len(scale_perm_single_24)))[
+            :, scale_perm_single_24
+        ]
+        return scales.reshape((1, -1))
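Taken together, the public entry points in this file are meant to compose roughly like this. A hedged sketch only: the quantization step and the exact return layouts are elided, and argument names simply follow the signatures above:

```python
import torch
from torchao.sparsity.marlin import inject_24, marlin_24_workspace

k, n = 128, 256
w = torch.randn(k, n, device="cuda")
w_24, mask = inject_24(w, size_k=k, size_n=n)    # prune to 2:4 sparsity
workspace = marlin_24_workspace(out_features=n)  # lock buffer for the kernel
# After quantizing w_24 to integers `q_w_24` with per-group `scales`:
# packed, marlin_scales, meta = pack_to_marlin_24(q_w_24, scales, num_bits, group_size)
# q_rt, scales_rt = unpack_from_marlin_24(packed, marlin_scales, meta, w.shape, group_size, num_bits)
```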