Skip to content

Commit 8aadb7d

Browse files
committed
Fix CI
1 parent bd3b79a commit 8aadb7d

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

test/quantization/test_quant_primitives.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,11 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
503503
n_bit = 4
504504
groupsize = 128
505505

506-
w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
506+
if TORCH_VERSION_AFTER_2_5:
507+
input_uint8 = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
508+
w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input_uint8, scales, zeros, n_bit, groupsize)
509+
else:
510+
w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
507511
w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
508512

509513
self.assertTrue(torch.equal(w_bf16, w_bf16_ref))

0 commit comments

Comments
 (0)