Skip to content

Commit d1911a1

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 99f10df commit d1911a1

File tree

1 file changed

+22
-22
lines changed

1 file changed

+22
-22
lines changed

neural_compressor/adaptor/ox_utils/weight_only.py

+22-22
Original file line numberDiff line numberDiff line change
@@ -262,44 +262,44 @@ def quant_tensor_k_quant_cpu(data, num_bits=4, group_size=32):
262262
scale: scale
263263
zero_point: zero point
264264
"""
265-
data = np.reshape(data, (-1, group_size)).astype(np.float32) # nb = data.shape[0], (nb, group_size)
265+
data = np.reshape(data, (-1, group_size)).astype(np.float32) # nb = data.shape[0], (nb, group_size)
266266
maxq = 2**num_bits - 1
267267
minq = 0
268-
sum_x2 = np.sum(data**2, axis=1, keepdims=True) # (nb, 1)
269-
av_x = np.sqrt(sum_x2 / group_size) # (nb, 1)
270-
weights = np.add(av_x, np.abs(data)) # (nb, group_size)
271-
rmin = np.min(data, axis=1, keepdims=True) # (nb, 1)
272-
rmax = np.max(data, axis=1, keepdims=True) # (nb, 1)
273-
sum_w = np.sum(weights, axis=1, keepdims=True) # (nb, 1)
274-
sum_x = np.sum(weights * data, axis=1, keepdims=True) # (nb, group_size)
275-
iscale = np.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
268+
sum_x2 = np.sum(data**2, axis=1, keepdims=True) # (nb, 1)
269+
av_x = np.sqrt(sum_x2 / group_size) # (nb, 1)
270+
weights = np.add(av_x, np.abs(data)) # (nb, group_size)
271+
rmin = np.min(data, axis=1, keepdims=True) # (nb, 1)
272+
rmax = np.max(data, axis=1, keepdims=True) # (nb, 1)
273+
sum_w = np.sum(weights, axis=1, keepdims=True) # (nb, 1)
274+
sum_x = np.sum(weights * data, axis=1, keepdims=True) # (nb, group_size)
275+
iscale = np.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
276276
mask = rmin != rmax
277277
iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
278278
scale = 1 / iscale
279-
quant_data = np.clip(np.round(iscale * (data - rmin)), minq, maxq) # (nb, group_size)
280-
diff = scale * quant_data + rmin - data # (nb, group_size)
281-
best_mad = np.sum(weights * diff ** 2, axis=1, keepdims=True) # (nb, 1)
279+
quant_data = np.clip(np.round(iscale * (data - rmin)), minq, maxq) # (nb, group_size)
280+
diff = scale * quant_data + rmin - data # (nb, group_size)
281+
best_mad = np.sum(weights * diff**2, axis=1, keepdims=True) # (nb, 1)
282282
nstep = 20
283283
rdelta = 0.1
284284
# nstep * rdelta = -2 * rrmin, maxq - minq = 2**num_bits - 1
285285
rrmin = -1
286286
for is_ in range(nstep):
287-
iscale_new = np.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
287+
iscale_new = np.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
288288
factor = np.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
289289
mask = rmin != rmax
290290
iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
291-
quant_data_new = np.clip(np.round(iscale_new * (data - rmin)), minq, maxq) # (nb, group_size)
291+
quant_data_new = np.clip(np.round(iscale_new * (data - rmin)), minq, maxq) # (nb, group_size)
292292
mul_weights_quant_data_new = weights * quant_data_new
293-
sum_l = np.sum(mul_weights_quant_data_new, axis=1, keepdims=True) # (nb, 1)
294-
sum_l2 = np.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True) # (nb, 1)
295-
sum_xl = np.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True) # (nb, 1)
296-
D = np.subtract(sum_w * sum_l2, sum_l ** 2) # (nb, 1)
293+
sum_l = np.sum(mul_weights_quant_data_new, axis=1, keepdims=True) # (nb, 1)
294+
sum_l2 = np.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True) # (nb, 1)
295+
sum_xl = np.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True) # (nb, 1)
296+
D = np.subtract(sum_w * sum_l2, sum_l**2) # (nb, 1)
297297

298-
this_scale = (sum_w * sum_xl - sum_x * sum_l) / D # (nb, 1)
299-
this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D # (nb, 1)
298+
this_scale = (sum_w * sum_xl - sum_x * sum_l) / D # (nb, 1)
299+
this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D # (nb, 1)
300300

301-
diff = this_scale * quant_data_new + this_min - data # (nb, group_size)
302-
mad = np.sum(weights * diff ** 2, axis=1, keepdims=True) # (nb, 1)
301+
diff = this_scale * quant_data_new + this_min - data # (nb, group_size)
302+
mad = np.sum(weights * diff**2, axis=1, keepdims=True) # (nb, 1)
303303

304304
mad_1 = np.array(mad)
305305
best_mad_1 = np.array(best_mad)

0 commit comments

Comments
 (0)