Skip to content

Commit 1f7884d

Browse files
committed
Remove debug prints and add a bias to the asymmetric quantization tests
1 parent 86c2aeb commit 1f7884d

File tree

2 files changed: 2 additions and 12 deletions.

jetstream_pt/layers.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -246,7 +246,6 @@ def quantize_weight_from_nn_linear(self, weight):
246246
weight, (1,), self.n_bit, self.is_symmetric_weight, self.block_size
247247
)
248248
w_dq = dequantize_tensor(w_q, scale, zp)
249-
print("check qweight cosine dist: ", _calc_cosine_dist(weight, w_dq))
250249
self._load_quantized_weights(w_q, scale, zp)
251250

252251
def forward(self, inputs):

tests/test_quantization.py

Lines changed: 2 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -143,18 +143,9 @@ def quantize_dequantize_weight(w, n_bit):
143143
w_q_asym, s_asym, zp_asym = quantize_tensor(
144144
w, (1,), n_bit=n_bit, symmetric=False
145145
)
146-
# print(f"w_q_asym {w_q_asym}, s_asym {s_asym}, zp_asym {zp_asym}")
147146
w_dq_asym = dequantize_tensor(w_q_asym, s_asym, zp_asym)
148-
# print(f"w_dq_asym {w_dq_asym}")
149-
# self._print_diff(w, w_dq)
150-
# self._print_diff(w, w_dq_asym)
151147
# Asymmetric is more accurate than symmetric.
152-
print((w - w_dq_asym))
153-
print((w - w_dq))
154-
self.assertLess(
155-
(w - w_dq_asym).to(torch.float32).norm(),
156-
(w - w_dq).to(torch.float32).norm(),
157-
)
148+
self.assertLess((w - w_dq_asym).norm(), (w - w_dq).norm(),)
158149
# Blockwise quant.
159150
w_block_q, s_block, _ = quantize_tensor(
160151
w, (1,), n_bit=n_bit, symmetric=True, block_size=2
@@ -174,7 +165,7 @@ def quantize_dequantize_weight(w, n_bit):
174165
# Blockwise asymmetric is more accurate than blockwise symmetric.
175166
self.assertLess((w - w_block_asym_dq).norm(), (w - w_block_dq).norm())
176167

177-
w = torch.randn(2, 8)
168+
w = torch.randn(2, 8) + 2 # Add a bias to normal dist to test asymmetric quant.
178169
for bit in [4, 8]:
179170
with self.subTest(bit=bit):
180171
quantize_dequantize_weight(w, bit)

0 commit comments

Comments
 (0)