Skip to content

Commit 20b08ee

Browse files
metal lowbit kernels: check contiguity of scales and zeros
Differential Revision: D65957327 Pull Request resolved: #1287
1 parent d4ca98f commit 20b08ee

File tree

3 files changed

+8
-5
lines changed

3 files changed

+8
-5
lines changed

torchao/experimental/kernels/mps/test/test_lowbit.mm

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@ void init() {
101101
int32_t ceil_K_group_size = (K + qGroupSize - 1) / qGroupSize;
102102
for (int idx = 0; idx < N * ceil_K_group_size; ++idx) {
103103
s_ptr[idx] = (idx + 1.0) / N;
104-
z_ptr[idx] = int_distrib(generator);
104+
auto zp = int_distrib(generator);
105+
z_ptr[idx] = -zp * s_ptr[idx];
105106
}
106107
for (int idx = 0; idx < M * N; ++idx) {
107108
c_ptr[idx] = -1.0;

torchao/experimental/ops/mps/register.mm

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,13 +58,15 @@ void check_linear_mps_args(
5858
": expect S to be 2d tensor with shape [:, ",
5959
N,
6060
"]");
61+
TORCH_CHECK(S.is_contiguous(), __func__, " : expect S to be contiguous.");
6162

6263
TORCH_CHECK(
6364
Z.dim() == 2 && Z.size(1) == N,
6465
__func__,
6566
": expect Z to be 2d tensor with shape [:, ",
6667
N,
6768
"]");
69+
TORCH_CHECK(Z.is_contiguous(), __func__, " : expect Z to be contiguous.");
6870
}
6971

7072
template <int nbit>

torchao/experimental/ops/mps/test/test_lowbit.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,18 +46,18 @@ class TestLowBitQuantWeightsLinear(unittest.TestCase):
4646
]
4747

4848
def _init_tensors(self, group_size, M, K, N, nbit, device="mps"):
49-
max_abs = 1 << (nbit - 1)
5049
ceil_K_group_size = (K + group_size - 1) // group_size
51-
A = 2 * torch.rand(M, K, dtype=torch.float32, device=device) - 1
52-
W = torch.randint(0, 2 * max_abs, (N, K), dtype=torch.uint8, device=device)
50+
A = torch.rand(M, K, dtype=torch.float32, device=device)
51+
W = torch.randint(0, 1 << nbit, (N, K), dtype=torch.uint8, device=device)
5352
S = torch.rand(ceil_K_group_size, N, dtype=torch.float32, device=device) + 0.01
5453
Z = torch.randint(
5554
0,
56-
2 * max_abs,
55+
1 << nbit,
5756
(ceil_K_group_size, N),
5857
dtype=torch.float32,
5958
device=device,
6059
)
60+
Z = -Z * S
6161
return A, W, S, Z
6262

6363
def _reference_linear_lowbit_quant_weights(self, A, W, group_size, S, Z, nbit):

0 commit comments

Comments
 (0)