
Commit babdd34

Merge remote-tracking branch 'origin/main' into aqt_refactor

2 parents: 84b4d38 + 2ba1a61

30 files changed: +347 -223 lines

ruff.toml

Lines changed: 4 additions & 3 deletions
@@ -4,12 +4,13 @@
 # Example: To lint all files in every subfolder of 'test', add "test/**/*"
 include = [
     "torchao/float8/**/*.py",
-    "test/dtypes/test_nf4.py",
     "torchao/quantization/**/*.py",
-    "test/quantization/test_observer.py",
-    "test/dtypes/test_affine_quantized_float.py",
     "torchao/dtypes/**/*.py",
+    "torchao/sparsity/**/*.py",
     "torchao/prototype/low_bit_optim/**.py",
+    "test/quantization/test_observer.py",
+    "test/dtypes/test_affine_quantized_float.py",
+    "test/dtypes/test_nf4.py",
     "test/prototype/low_bit_optim/**.py",
 ]

torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@
  * This file is generated by gen_metal_shader_lib.py
  */

-#ifdef ATEN
+#ifdef USE_ATEN
 using namespace at::native::mps;
 #else
 #include <torchao/experimental/kernels/mps/src/OperationUtils.h>

torchao/experimental/kernels/mps/src/lowbit.h

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 #include <fstream>
 #include <sstream>

-#ifdef ATEN
+#ifdef USE_ATEN
 #include <ATen/native/mps/OperationUtils.h>
 using namespace at::native::mps;
 inline void finalize_block(MPSStream* mpsStream) {}

torchao/experimental/ops/mps/setup.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@
             name="torchao_mps_ops",
             sources=["register.mm"],
             include_dirs=[os.getenv("TORCHAO_ROOT")],
-            extra_compile_args=["-DATEN=1"],
+            extra_compile_args=["-DUSE_ATEN=1"],
         ),
     ],
     cmdclass={"build_ext": BuildExtension},
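
Taken together, the three ATEN renames above are one change: the build defines the macro, and the C++/Metal sources branch on it with #ifdef USE_ATEN. A minimal sketch of the setup.py this hunk sits in, assuming the standard torch CppExtension wrapper (the surrounding setup() call is inferred, not shown in the diff):

import os

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

# Sketch only: defining USE_ATEN here is what makes the `#ifdef USE_ATEN`
# branches in lowbit.h and the generated Metal shader library pick
# ATen's at::native::mps utilities instead of torchao's standalone ones.
setup(
    name="torchao_mps_ops",
    ext_modules=[
        CppExtension(
            name="torchao_mps_ops",
            sources=["register.mm"],
            include_dirs=[os.getenv("TORCHAO_ROOT")],
            extra_compile_args=["-DUSE_ATEN=1"],
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)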

torchao/sparsity/__init__.py

Lines changed: 9 additions & 6 deletions
@@ -4,20 +4,23 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

-from .wanda import WandaSparsifier  # noqa: F403
-from .utils import PerChannelNormObserver  # noqa: F403
+from torchao.quantization.quant_api import (
+    int8_dynamic_activation_int8_semi_sparse_weight,
+)
+
 from .sparse_api import (
     apply_fake_sparsity,
-    sparsify_,
     semi_sparse_weight,
-    int8_dynamic_activation_int8_semi_sparse_weight
+    sparsify_,
 )
+from .utils import PerChannelNormObserver  # noqa: F403
+from .wanda import WandaSparsifier  # noqa: F403

 __all__ = [
     "WandaSparsifier",
     "PerChannelNormObserver",
     "apply_fake_sparsity",
-    "sparsify_"
+    "sparsify_",
     "semi_sparse_weight",
-    "int8_dynamic_activation_int8_semi_sparse_weight"
+    "int8_dynamic_activation_int8_semi_sparse_weight",
 ]
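
The public surface is unchanged by the import reshuffle; int8_dynamic_activation_int8_semi_sparse_weight now just comes from torchao.quantization.quant_api instead of sparse_api. A hedged usage sketch of the re-exported API (a CUDA GPU with semi-structured sparse support is assumed, and real use typically wants model.half().cuda() first):

import torch.nn as nn

from torchao.sparsity import semi_sparse_weight, sparsify_

# Toy model for illustration; sparsify_ converts eligible Linear weights
# to semi-structured (2:4) sparse form in place.
model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))
sparsify_(model, semi_sparse_weight())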

torchao/sparsity/marlin/__init__.py

Lines changed: 76 additions & 76 deletions
@@ -1,11 +1,11 @@
+from typing import Tuple
+
 import torch
-from typing import Tuple, Dict, List

 import torchao.sparsity.marlin.utils as utils
 from torchao.sparsity.marlin.utils import const
 from torchao.sparsity.utils import mask_creator

-
 __all__ = [
     "inject_24",
     "marlin_24_workspace",
@@ -14,11 +14,13 @@
 ]


-def inject_24(w: torch.Tensor, size_k: int, size_n: int) -> Tuple[torch.Tensor, torch.Tensor]:
+def inject_24(
+    w: torch.Tensor, size_k: int, size_n: int
+) -> Tuple[torch.Tensor, torch.Tensor]:
     """Injects 2:4 sparsity into a weight tensor. The sparsity is applied in a 2:4 ratio, where for every
     group of 4 weights, 2 will be pruned based on their value. The mask will be created based on the
     ranked weight values.
-
+
     Args:
         w (torch.Tensor): The weight tensor to inject sparsity into.
         size_k (int): The number of input features.
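
For orientation, a hedged sketch of calling the reformatted signature. The (size_k, size_n) layout and the (pruned weight, mask) return order are read off the docstring, not verified:

import torch

from torchao.sparsity.marlin import inject_24

size_k, size_n = 128, 256
w = torch.randn(size_k, size_n, device="cuda")  # CUDA assumed, as elsewhere in this module
w_24, mask = inject_24(w, size_k, size_n)
# 2:4 sparsity: in every group of four weights, exactly two survive.
assert mask.sum().item() == w.numel() // 2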
@@ -32,33 +34,35 @@ def inject_24(w: torch.Tensor, size_k: int, size_n: int) -> Tuple[torch.Tensor,


 def marlin_24_workspace(
-        out_features: int,
-        min_thread_n: int = const.MIN_THREAD_N,
-        max_parallel: int = const.MAX_PARALLEL
-    ) -> torch.Tensor:
-    """Creates a workspace for marlin 2:4 quantization. The workspace is used to coordinate the locks
+    out_features: int,
+    min_thread_n: int = const.MIN_THREAD_N,
+    max_parallel: int = const.MAX_PARALLEL,
+) -> torch.Tensor:
+    """Creates a workspace for marlin 2:4 quantization. The workspace is used to coordinate the locks
     during the execution of the kernel.
-
+
     Args:
         out_features (int): The number of output features.
         min_thread_n (int, optional): The minimum number of threads per block. Defaults to `MARLIN_24_MIN_THREAD_N`.
         max_parallel (int, optional): The maximum number of parallel threads. Defaults to `MARLIN_24_MAX_PARALLEL`.
     Returns:
         torch.Tensor: The workspace tensor fully initialized with zeros.
     """
-    assert (out_features % min_thread_n == 0), f"out_features = {out_features}, min_thread_n = {min_thread_n}"
-    max_workspace_size = ((out_features // min_thread_n) * max_parallel)
+    assert (
+        out_features % min_thread_n == 0
+    ), f"out_features = {out_features}, min_thread_n = {min_thread_n}"
+    max_workspace_size = (out_features // min_thread_n) * max_parallel
     return torch.zeros(max_workspace_size, dtype=torch.int, device="cuda")


 def pack_to_marlin_24(
-        q_w_24: torch.Tensor,
-        scales: torch.Tensor,
-        num_bits: int,
-        group_size: int,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    q_w_24: torch.Tensor,
+    scales: torch.Tensor,
+    num_bits: int,
+    group_size: int,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """Packs the quantized weights and scales into the marlin 2:4 format.
-
+
     Args:
         q_w_24 (torch.Tensor): The quantized weight tensor with 2:4 sparsity applied.
         scales (torch.Tensor): The scale tensor.
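
The reshaped assert above is also the function's contract: out_features must be divisible by min_thread_n. A minimal call, assuming the const defaults and a CUDA device (the workspace is always allocated there):

import torch

from torchao.sparsity.marlin import marlin_24_workspace

# 4096 is an illustrative transformer hidden size; any multiple of
# const.MIN_THREAD_N works, anything else trips the assert above.
workspace = marlin_24_workspace(out_features=4096)
assert workspace.is_cuda and (workspace == 0).all()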
@@ -89,13 +93,13 @@ def pack_to_marlin_24(


 def unpack_from_marlin_24(
-        q_w_24_comp: torch.Tensor,
-        scales: torch.Tensor,
-        meta: torch.Tensor,
-        original_shape: torch.Size,
-        group_size: int,
-        num_bits: int
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    q_w_24_comp: torch.Tensor,
+    scales: torch.Tensor,
+    meta: torch.Tensor,
+    original_shape: torch.Size,
+    group_size: int,
+    num_bits: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
     """Unpacks the quantized weights and scales from the marlin 2:4 format.
     Args:
         q_w_24_comp (torch.Tensor): The packed quantized weights.
@@ -109,10 +113,8 @@ def unpack_from_marlin_24(
     """
     in_features, out_features = original_shape

-    # Unpacks the scales
-    unpacked_scales = _from_marlin_scale(
-        scales, *original_shape, group_size, num_bits
-    )
+    # Unpacks the scales
+    unpacked_scales = _from_marlin_scale(scales, *original_shape, group_size, num_bits)

     in_features_comp = in_features // 2
@@ -130,14 +132,11 @@ def unpack_from_marlin_24(


 def _compress_quantized_24_weight(
-        q_24: torch.Tensor,
-        size_k: int,
-        size_n: int,
-        num_bits: int
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Compresses the quantized weights to a 2:4 sparse format. Normalizes the weights over 0
+    q_24: torch.Tensor, size_k: int, size_n: int, num_bits: int
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Compresses the quantized weights to a 2:4 sparse format. Normalizes the weights over 0
     before compressing them.
-
+
     Args:
         q_24 (torch.Tensor): The quantized weight tensor.
         size_k (int): The number of input features.
@@ -168,14 +167,10 @@ def _compress_quantized_24_weight(


 def _decompress_quantized_24_weight(
-        q_24_comp: torch.Tensor,
-        meta: torch.Tensor,
-        size_k: int,
-        size_n: int,
-        num_bits: int
-    ) -> torch.Tensor:
+    q_24_comp: torch.Tensor, meta: torch.Tensor, size_k: int, size_n: int, num_bits: int
+) -> torch.Tensor:
     """Decompresses the quantized weights from a 2:4 sparse format and restores the original shape.
-
+
     Args:
         q_24_comp (torch.Tensor): The compressed quantized weight tensor in 2:4 sparse format.
         meta (torch.Tensor): The meta tensor.
@@ -210,13 +205,13 @@ def _decompress_quantized_24_weight(


 def _to_marlin_weights(
-        q_w: torch.Tensor,
-        size_k: int,
-        size_n: int,
-        num_bits: int,
-    ) -> torch.Tensor:
+    q_w: torch.Tensor,
+    size_k: int,
+    size_n: int,
+    num_bits: int,
+) -> torch.Tensor:
     """Converts a quantized and 2:4 sparse format weight tensor to the marlin 2:4 format.
-
+
     Args:
         q_w (torch.Tensor): The quantized weight tensor in 2:4 sparse format.
         size_k (int): The number of input features.
@@ -236,7 +231,11 @@ def _to_marlin_weights(
     # Original implementation uses numpy + uint32 but we need to use int64 because torch.uint32
     # does not support rshift_cpu.
     q_w = q_w.cpu().to(torch.int64)
-    q_packed = torch.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=torch.int64, device=q_w.device)
+    q_packed = torch.zeros(
+        (q_w.shape[0], q_w.shape[1] // pack_factor),
+        dtype=torch.int64,
+        device=q_w.device,
+    )
     for i in range(pack_factor):
         q_packed |= q_w[:, i::pack_factor] << (num_bits * i)
@@ -245,13 +244,10 @@ def _to_marlin_weights(
245244

246245

247246
def _from_marlin_weights(
248-
q_packed: torch.Tensor,
249-
size_k: int,
250-
size_n: int,
251-
num_bits: int
252-
) -> torch.Tensor:
247+
q_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
248+
) -> torch.Tensor:
253249
"""Converts a weight tensor in the marlin 2:4 format to a regular quantized 2:4 sparse format.
254-
250+
255251
Args:
256252
q_packed (torch.Tensor): The weight tensor in the marlin 2:4 format.
257253
size_k (int): The number of input features.
@@ -269,52 +265,54 @@ def _from_marlin_weights(
     # Original implementation uses numpy + uint32 but we need to use int64 because torch.uint32
     # does not support rshift_cpu.
     q_packed = q_packed.cpu().to(torch.int64)
-    q_w_unpacked = torch.zeros((q_packed.shape[0], q_packed.shape[1] * pack_factor), dtype=torch.int64, device=q_packed.device)
+    q_w_unpacked = torch.zeros(
+        (q_packed.shape[0], q_packed.shape[1] * pack_factor),
+        dtype=torch.int64,
+        device=q_packed.device,
+    )
     for i in range(pack_factor):
-        q_w_unpacked[:, i::pack_factor] = (q_packed >> (num_bits * i)) & ((1 << num_bits) - 1)
+        q_w_unpacked[:, i::pack_factor] = (q_packed >> (num_bits * i)) & (
+            (1 << num_bits) - 1
+        )

     q_w_unpacked = q_w_unpacked.to(orig_device, dtype=torch.int32)

-    q_w_comp = utils.reverse_marlin_permute_weights(q_w_unpacked, size_k, size_n, perm_24)
+    q_w_comp = utils.reverse_marlin_permute_weights(
+        q_w_unpacked, size_k, size_n, perm_24
+    )
     return q_w_comp


 def _to_marlin_scales(
-        scales: torch.Tensor,
-        size_k: int,
-        size_n: int,
-        group_size: int,
-        num_bits: int
-    ) -> torch.Tensor:
+    scales: torch.Tensor, size_k: int, size_n: int, group_size: int, num_bits: int
+) -> torch.Tensor:
     """Converts a scale tensor to the format necessary for marlin.
     Args:
         scales (torch.Tensor): The scale tensor.
         size_k (int): The number of input features.
         size_n (int): The number of output features.
         group_size (int): The group size that was applied during quantization.
         num_bits (int): The number of bits used for quantization.
-
+
     Returns:
         torch.Tensor: The scale tensor in the marlin format.
     """
     _, scale_perm_24, scale_perm_single_24 = utils.get_perms_24(num_bits)
     if group_size < size_k and group_size != -1:
         scales = scales.reshape((-1, len(scale_perm_24)))[:, scale_perm_24]
     else:
-        scales = scales.reshape((-1, len(scale_perm_single_24)))[:, scale_perm_single_24]
+        scales = scales.reshape((-1, len(scale_perm_single_24)))[
+            :, scale_perm_single_24
+        ]
     scales = scales.reshape((-1, size_n)).contiguous()
     return scales


 def _from_marlin_scale(
-        scales: torch.Tensor,
-        size_k: int,
-        size_n: int,
-        group_size: int,
-        num_bits: int
-    ) -> torch.Tensor:
+    scales: torch.Tensor, size_k: int, size_n: int, group_size: int, num_bits: int
+) -> torch.Tensor:
     """Converts a scale tensor from the marlin format to their original format.
-
+
     Args:
         scales (torch.Tensor): The scale tensor in the marlin format.
         size_k (int): The number of input features.
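
Both scale converters reduce to one trick: view the flat scales as rows of permutation length, then reorder columns with advanced indexing. A toy version with a made-up 4-element permutation (the real scale_perm_24 comes from utils.get_perms_24 and depends on num_bits):

import torch

scale_perm = [2, 0, 3, 1]  # hypothetical permutation, for illustration only
scales = torch.arange(8, dtype=torch.float32).reshape(2, 4)
permuted = scales.reshape((-1, len(scale_perm)))[:, scale_perm]
# Row [0, 1, 2, 3] becomes [2, 0, 3, 1]; indexing with the inverse permutation undoes it.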
@@ -329,5 +327,7 @@ def _from_marlin_scale(
         scales = scales.reshape((-1, len(scale_perm_24)))[:, scale_perm_24]
         return scales.reshape((size_k // group_size, size_n))
     else:
-        scales = scales.reshape((-1, len(scale_perm_single_24)))[:, scale_perm_single_24]
-        return scales.reshape((1, -1))
+        scales = scales.reshape((-1, len(scale_perm_single_24)))[
+            :, scale_perm_single_24
+        ]
+        return scales.reshape((1, -1))
