k quant #2169

Open · wants to merge 12 commits into master

Changes from all commits
200 changes: 189 additions & 11 deletions neural_compressor/adaptor/ox_utils/weight_only.py
@@ -40,19 +40,19 @@
ONNXRT1161_VERSION = Version("1.16.1")


def get_blob_size(group_size, has_zp): # pragma: no cover
def get_blob_size(group_size, num_bits, has_zp): # pragma: no cover
"""Get blob_size.

Args:
group_size (int): how many elements share one scale/zp
num_bits (int): number of bits used to represent each quantized element
has_zp (bool): whether zero_point is not None
"""
if Version(ort.__version__) > ONNXRT1161_VERSION:
blob_size = group_size // 2
blob_size = group_size * num_bits // 8
elif has_zp:
blob_size = group_size // 2 + 4 + 1
blob_size = group_size * num_bits // 8 + 4 + 1
else:
blob_size = group_size // 2 + 4
blob_size = group_size * num_bits // 8 + 4
return blob_size


@@ -86,7 +86,7 @@ def make_matmul_weight_only_node(
matmul_weight_only_node: MatMulFpQ4 or MatMulNBits node
new_inits: initializers of the new node
"""
blob_size = get_blob_size(group_size, zero_point is not None)
blob_size = get_blob_size(group_size, num_bits, zero_point is not None)
packed = np.zeros((q_weight.shape[0], blob_size), dtype="uint8")
q_weight_name = node.input[1] + "_Q{}G{}".format(str(num_bits), str(group_size))
input_names = [node.input[0], q_weight_name]
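
As a quick sanity check on the generalized blob-size formula, here is a standalone sketch (not part of the patch) that evaluates the ort > 1.16.1 branch for both supported bit widths; the 4-bit case reproduces the old group_size // 2 value.

# Standalone sketch of the new blob-size arithmetic (ort > 1.16.1 branch only).
for num_bits in (4, 8):
    group_size = 32
    blob_size = group_size * num_bits // 8
    print(num_bits, blob_size)
# 4 bits: 32 * 4 // 8 = 16 bytes per group (two codes per byte, same as the old group_size // 2)
# 8 bits: 32 * 8 // 8 = 32 bytes per group (one code per byte)
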
@@ -97,8 +97,14 @@
op_type = "MatMulNBits"

# pack quantized weight
q_weight_pairs = q_weight[:, ::2] | q_weight[:, 1::2] << 4
packed[:, :] = q_weight_pairs[:, :blob_size]
if num_bits == 4:
q_weight_pairs = q_weight[:, ::2] | q_weight[:, 1::2] << 4
packed[:, :] = q_weight_pairs[:, :blob_size]
elif num_bits == 8:
packed = q_weight
else:
logger.error("MatMulNBits does not have kernel support for num_bits = {}.".format(num_bits))

packed = np.reshape(packed, (-1, k_blocks, blob_size))

# build scale tensor
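
To see what the packing branch produces, a toy example (illustrative only, not part of the patch): with num_bits=4, even columns land in the low nibble and odd columns in the high nibble, so a row of group_size 4-bit codes packs into group_size // 2 bytes; with num_bits=8 each code already occupies one byte and is passed through unchanged.

import numpy as np

q = np.array([[1, 2, 3, 4]], dtype=np.uint8)   # four 4-bit codes in one row
packed = q[:, ::2] | q[:, 1::2] << 4           # even cols -> low nibble, odd cols -> high nibble
print([hex(int(v)) for v in packed[0]])        # ['0x21', '0x43']
# unpacking recovers the original codes
print(packed & 0x0F, packed >> 4)              # [[1 3]] [[2 4]]
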
@@ -247,6 +253,170 @@ def quant_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ra
return q_weight, scale, zero_point


def quant_tensor_k_quant_cpu(data, num_bits=4, group_size=32):
"""Quantize tensor per group based on k quant.

Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c

Args:
data : input weight
num_bits (int, optional): number of quantization bits. Defaults to 4.
group_size (int, optional): how many elements share one scale/zp. Defaults to 32.

Returns:
output: quantized weight
scale: scale
zero_point: zero point
"""
data = np.reshape(data, (-1, group_size)).astype(np.float32) # nb = data.shape[0], (nb, group_size)
maxq = 2**num_bits - 1
minq = 0
sum_x2 = np.sum(data**2, axis=1, keepdims=True) # (nb, 1)
av_x = np.sqrt(sum_x2 / group_size) # (nb, 1)
weights = np.add(av_x, np.abs(data)) # (nb, group_size)
rmin = np.min(data, axis=1, keepdims=True) # (nb, 1)
rmax = np.max(data, axis=1, keepdims=True) # (nb, 1)
sum_w = np.sum(weights, axis=1, keepdims=True) # (nb, 1)
sum_x = np.sum(weights * data, axis=1, keepdims=True) # (nb, 1)
iscale = np.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
mask = rmin != rmax
iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
scale = 1 / iscale
quant_data = np.clip(np.round(iscale * (data - rmin)), minq, maxq) # (nb, group_size)
diff = scale * quant_data + rmin - data # (nb, group_size)
best_mad = np.sum(weights * diff**2, axis=1, keepdims=True) # (nb, 1)
nstep = 20
rdelta = 0.1
# nstep * rdelta = -2 * rrmin, maxq - minq = 2**num_bits - 1
rrmin = -1
for is_ in range(nstep):
iscale_new = np.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
factor = np.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
mask = rmin != rmax
iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
quant_data_new = np.clip(np.round(iscale_new * (data - rmin)), minq, maxq) # (nb, group_size)
mul_weights_quant_data_new = weights * quant_data_new
sum_l = np.sum(mul_weights_quant_data_new, axis=1, keepdims=True) # (nb, 1)
sum_l2 = np.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True) # (nb, 1)
sum_xl = np.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True) # (nb, 1)
D = np.subtract(sum_w * sum_l2, sum_l**2) # (nb, 1)

this_scale = (sum_w * sum_xl - sum_x * sum_l) / D # (nb, 1)
this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D # (nb, 1)

diff = this_scale * quant_data_new + this_min - data # (nb, group_size)
mad = np.sum(weights * diff**2, axis=1, keepdims=True) # (nb, 1)

mad_1 = np.array(mad)
best_mad_1 = np.array(best_mad)
idx_to_replace = np.where(mad_1 < best_mad_1)[0]
quant_data[idx_to_replace, :] = quant_data_new[idx_to_replace, :]
best_mad[idx_to_replace] = mad[idx_to_replace]
scale[idx_to_replace] = this_scale[idx_to_replace]
rmin[idx_to_replace] = this_min[idx_to_replace]

zero_point = np.clip(((-rmin) / scale).round(), 0, maxq).astype("uint8")
scale = scale.astype(np.float64)
q_weight = np.empty_like(data, dtype=scale.dtype)
np.divide(data, scale, out=q_weight)
np.add(q_weight, zero_point, out=q_weight)
np.round(q_weight, out=q_weight)
np.clip(q_weight, minq, maxq, out=q_weight)

return q_weight, scale, zero_point
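
For reference, the refinement loop above is a grid search over candidate roundings: for each candidate quant_data_new it solves a per-group weighted least-squares fit of data ≈ scale * q + min in closed form (this_scale, this_min) and keeps the candidate with the lowest weighted squared error. The sketch below (an illustrative check, not part of the patch) reproduces that closed form for a single group against np.linalg.lstsq.

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(32).astype(np.float32)                          # one group of weights
w = np.abs(x) + np.sqrt(np.mean(x**2))                                  # importance weights, as above
q = np.clip(np.round((x - x.min()) * 15 / (x.max() - x.min())), 0, 15)  # candidate 4-bit codes

# closed form used inside the loop
sum_w, sum_x = w.sum(), (w * x).sum()
sum_l, sum_l2, sum_xl = (w * q).sum(), (w * q * q).sum(), (w * q * x).sum()
D = sum_w * sum_l2 - sum_l**2
this_scale = (sum_w * sum_xl - sum_x * sum_l) / D
this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D

# same fit via weighted lstsq on data ~ scale * q + min
A = np.sqrt(w)[:, None] * np.stack([q, np.ones_like(q)], axis=1)
b = np.sqrt(w) * x
ref_scale, ref_min = np.linalg.lstsq(A, b, rcond=None)[0]
print(np.allclose([this_scale, this_min], [ref_scale, ref_min], atol=1e-4))  # True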


def quant_tensor_k_quant_cuda(data, num_bits=4, group_size=32):
"""Quantize tensor per group based on k quant.

Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c

Args:
data : input weight
num_bits (int, optional): number of quantization bits. Defaults to 4.
group_size (int, optional): how many elements share one scale/zp. Defaults to 32.

Returns:
output: quantized weight
scale: scale
zero_point: zero point
"""
try:
import cupy as cp
import torch

if torch.cuda.is_available():
data = cp.asarray(data)
data = data.reshape((-1, group_size)).astype(cp.float32) # nb = data.shape[0], (nb, group_size)
maxq = 2**num_bits - 1
minq = 0
sum_x2 = cp.sum(data**2, axis=1, keepdims=True) # (nb, 1)
av_x = cp.sqrt(sum_x2 / group_size) # (nb, 1)
weights = cp.add(av_x, cp.abs(data)) # (nb, group_size)
rmin = cp.min(data, axis=1, keepdims=True) # (nb, 1)
rmax = cp.max(data, axis=1, keepdims=True) # (nb, 1)
sum_w = cp.sum(weights, axis=1, keepdims=True) # (nb, 1)
sum_x = cp.sum(weights * data, axis=1, keepdims=True) # (nb, 1)
iscale = cp.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
mask = rmin != rmax
iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
scale = 1 / iscale
quant_data = cp.clip(cp.round(iscale * (data - rmin)), minq, maxq) # (nb, group_size)
diff = scale * quant_data + rmin - data # (nb, group_size)
best_mad = cp.sum(weights * diff**2, axis=1, keepdims=True) # (nb, 1)
nstep = 20
rdelta = 0.1
rrmin = -1
for is_ in range(nstep):
iscale_new = cp.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
factor = cp.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
mask = rmin != rmax
iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
quant_data_new = cp.clip(cp.round(iscale_new * (data - rmin)), minq, maxq) # (nb, group_size)
mul_weights_quant_data_new = weights * quant_data_new
sum_l = cp.sum(mul_weights_quant_data_new, axis=1, keepdims=True) # (nb, 1)
sum_l2 = cp.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True) # (nb, 1)
sum_xl = cp.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True) # (nb, 1)
D = cp.subtract(sum_w * sum_l2, sum_l**2) # (nb, 1)

this_scale = (sum_w * sum_xl - sum_x * sum_l) / D # (nb, 1)
this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D # (nb, 1)

diff = this_scale * quant_data_new + this_min - data # (nb, group_size)
mad = cp.sum(weights * diff**2, axis=1, keepdims=True) # (nb, 1)

mad_1 = cp.array(mad)
best_mad_1 = cp.array(best_mad)
idx_to_replace = cp.where(mad_1 < best_mad_1)[0]
quant_data[idx_to_replace, :] = quant_data_new[idx_to_replace, :]
best_mad[idx_to_replace] = mad[idx_to_replace]
scale[idx_to_replace] = this_scale[idx_to_replace]
rmin[idx_to_replace] = this_min[idx_to_replace]

zero_point = cp.clip(((-rmin) / scale).round(), 0, maxq).astype("uint8")
scale = scale.astype(cp.float64)
q_weight = cp.empty_like(data, dtype=scale.dtype)
cp.divide(data, scale, out=q_weight)
cp.add(q_weight, zero_point, out=q_weight)
cp.round(q_weight, out=q_weight)
cp.clip(q_weight, minq, maxq, out=q_weight)

return q_weight.get(), scale.get(), zero_point.get()
else:
logger.warning(
"Try to use k-quant quantization on CUDA. However, CUDA is not available."
"Fall back to k-quant quantization on CPU."
)
return quant_tensor_k_quant_cpu(data, num_bits, group_size)
except ImportError:
logger.info(
"Now we are using k-quant quantization on cpu, which is time consuming."
"Please consider install cupy to speed up on CUDA. See https://cupy.dev/"
"Please also install torch to check CUDA availability."
)
return quant_tensor_k_quant_cpu(data, num_bits, group_size)
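
A minimal usage sketch (illustrative only; it assumes the patched module is importable and that the weight's inner dimension is a multiple of group_size, which rtn_quantize guarantees via pad_tensor):

import numpy as np
from neural_compressor.adaptor.ox_utils.weight_only import quant_tensor_k_quant_cpu

w = np.random.randn(128, 256).astype(np.float32)
q, scale, zp = quant_tensor_k_quant_cpu(w, num_bits=4, group_size=32)
print(q.shape, scale.shape, zp.shape)      # (1024, 32) (1024, 1) (1024, 1), one row per group
dq = ((q - zp) * scale).reshape(w.shape)   # dequantize to check the round-trip error
print(np.abs(dq - w).max())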


def qdq_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ratio=1.0):
"""Quant dequant tensor per group.

@@ -299,6 +469,7 @@ def rtn_quantize(
ratios={},
accuracy_level=0,
providers=["CPUExecutionProvider"],
algorithm="rtn",
):
"""Quant the model with round to nearst method.

@@ -362,7 +533,10 @@

weight = pad_tensor(weight, group_size, k_blocks)

satisfy_MatMulNBits_condition = Version(ort.__version__) > ONNXRT1161_VERSION and num_bits == 4
enable_MatMulNBits_8bits = True
satisfy_MatMulNBits_condition = (Version(ort.__version__) > ONNXRT1161_VERSION and num_bits == 4) or (
enable_MatMulNBits_8bits and num_bits == 8
)
satisfy_MatMulFpQ4_condition = (
Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32
)
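
To make the new gate concrete, the sketch below (illustrative only; it assumes ONNXRT116_VERSION is Version("1.16.0"), consistent with the MatMulFpQ4 comment in the next hunk) evaluates both conditions for a few configurations; note that the 8-bit branch is controlled purely by the enable_MatMulNBits_8bits flag.

from packaging.version import Version

ONNXRT116_VERSION = Version("1.16.0")   # assumed value; the constant itself is defined earlier in the module
ONNXRT1161_VERSION = Version("1.16.1")
enable_MatMulNBits_8bits = True

for ort_ver, num_bits, group_size in [("1.16.1", 4, 32), ("1.17.3", 4, 128), ("1.17.3", 8, 32)]:
    v = Version(ort_ver)
    nbits_ok = (v > ONNXRT1161_VERSION and num_bits == 4) or (enable_MatMulNBits_8bits and num_bits == 8)
    fpq4_ok = v >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32
    print(f"ort {ort_ver}, {num_bits} bits, group {group_size}: MatMulNBits={nbits_ok}, MatMulFpQ4={fpq4_ok}")
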
@@ -372,9 +546,13 @@
): # pragma: no cover
# MatMulFpQ4 supports 4 bits and 32 group_size with ort 1.16.0 and 1.16.1 versions, supported by CPU EP
# MatMulNBits supports 4 bits and 2^n group_size with ort > 1.16.1, supported by CPU EP AND CUDA EP
q_weight, scale, zp = quant_tensor(
weight.T, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1)
)
if algorithm == "k_quant":
q_weight, scale, zp = quant_tensor_k_quant_cuda(weight.T, num_bits, group_size)
else:
q_weight, scale, zp = quant_tensor(
weight.T, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1)
)

q_matmul_node, new_inits = make_matmul_weight_only_node(
node=node,
weight_shape=org_w_shape,