Skip to content

Commit 4f460db

Browse files
authored
[CP2.2] Machete support group scale & wint8 & v1 loader (#4166)
* support v1 loader for machete (#3999) * [Optimize] Support WINT8 and group scale for Machete (#3905) * [Optimize] Machete using group scale default (#4121)
1 parent 74d7b91 commit 4f460db

File tree

5 files changed

+166
-82
lines changed

5 files changed

+166
-82
lines changed

custom_ops/gpu_ops/machete/machete_mm.cu

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,12 @@ paddle::Tensor mm(paddle::Tensor const& A, paddle::Tensor const& B,
3030
std::optional<paddle::Tensor> const& maybe_token_scales,
3131
std::string maybe_schedule) {
3232
machete::ScalarType const b_type = machete::ScalarType::from_id(b_type_id);
33-
std::optional<int64_t> maybe_group_size_opt;
33+
std::optional<int64_t> maybe_group_size_opt = std::optional<int64_t>(maybe_group_size);
3434
std::optional<std::string> maybe_schedule_opt;
3535
if (maybe_schedule == "") {
3636
maybe_schedule_opt = std::nullopt;
37+
} else {
38+
maybe_schedule_opt = std::optional<std::string>(maybe_schedule);
3739
}
3840
return machete::mm_dispatch({.A = A,
3941
.B = B,
@@ -63,6 +65,8 @@ std::vector<paddle::Tensor> MacheteMMKernel(
6365
paddle::DataType maybe_out_type;
6466
if (b_type_str == "uint4b8") {
6567
b_type_id = machete::kU4B8.id();
68+
} else if (b_type_str == "uint8b128") {
69+
b_type_id = machete::kU8B128.id();
6670
} else {
6771
PADDLE_ENFORCE(false, "b_type_str not supported!");
6872
}

custom_ops/gpu_ops/machete/machete_prepack_B.cu

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ std::vector<paddle::Tensor> MachetePrepackBKernel(
5151

5252
if (b_type_str == "uint4b8") {
5353
b_type_id = machete::kU4B8.id();
54+
} else if (b_type_str == "uint8b128") {
55+
b_type_id = machete::kU8B128.id();
5456
} else {
5557
PADDLE_ENFORCE(false, "b_type_str not supported!");
5658
}

fastdeploy/model_executor/layers/quantization/ops/machete_mm.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def quantize_weights(
8585
w_s: Scales (None if `group_size` is None).
8686
"""
8787
assert paddle.is_floating_point(w), "w must be float type"
88-
assert quant_type in ["uint4", "uint4b8"], "only support quant_type = uint4, uint4b8"
88+
assert quant_type in ["uint4b8", "uint8b128"], "only support quant_type = uint4b8, uint8b128"
8989

9090
orig_device = w.place
9191
size_k, size_n = w.shape
@@ -103,8 +103,12 @@ def quantize_weights(
103103
max_val = paddle.max(w, axis=0, keepdim=True)
104104
min_val = paddle.min(w, axis=0, keepdim=True)
105105

106-
max_q_val = float(7.0)
107-
min_q_val = float(-8.0)
106+
if quant_type == "uint4b8":
107+
max_q_val = float(7.0)
108+
min_q_val = float(-8.0)
109+
else:
110+
max_q_val = float(127.0)
111+
min_q_val = float(-128.0)
108112

109113
w_s = paddle.ones([1], dtype=paddle.float32) # unscaled case
110114

@@ -124,18 +128,20 @@ def quantize_weights(
124128
# w_q += quant_type.bias
125129
if quant_type == "uint4b8":
126130
w_q += 8
131+
else:
132+
w_q += 128
127133

128134
# Restore original shapes
129135
if group_size is not None and group_size < size_k:
130136

131137
def reshape_w(w_tensor):
132138
w_tensor = w_tensor.reshape([group_size, -1, size_n])
133139
w_tensor = w_tensor.transpose([1, 0, 2])
134-
w_tensor = w_tensor.reshape([size_k, size_n])
140+
w_tensor = w_tensor.reshape([size_k, size_n]).contiguous()
135141
return w_tensor
136142

137143
w_q = reshape_w(w_q)
138-
w_s = w_s.reshape([-1, size_n])
144+
w_s = w_s.reshape([-1, size_n]).contiguous()
139145

140146
# Move tensors back to original device
141147
w_q = w_q.to(orig_device)
@@ -153,7 +159,8 @@ def machete_quantize_and_pack(
153159
group_size: int = -1,
154160
):
155161
w_q, w_s = quantize_weights(w, group_size, quant_type=quant_type)
156-
w_q = pack_rows(w_q, 4, *w_q.shape)
162+
num_bits = 4 if quant_type == "uint4b8" else 8
163+
w_q = pack_rows(w_q, num_bits, *w_q.shape)
157164
w_q_col = w_q.transpose([1, 0]).contiguous() # convert to col major
158165
w_q_prepack = machete_prepack_B(
159166
w_q_col,

fastdeploy/model_executor/layers/quantization/weight_only.py

Lines changed: 44 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,7 @@ def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
141141
)
142142

143143
if (
144-
self.name() == "wint4"
145-
and _ENABLE_MACHETE
144+
_ENABLE_MACHETE
146145
and envs.FD_USE_MACHETE == "1"
147146
and layer.weight_shape[1]
148147
and layer.weight_shape[1] % 128 == 0
@@ -219,12 +218,22 @@ def create_weights(self, layer, **extra_weight_attrs):
219218
quant_attrs,
220219
)
221220
else:
222-
# The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
223-
weight_scale_shape = [layer.weight_shape[1]]
224-
layer.weight_shape.reverse()
225-
if self.quant_config.name() == "wint4":
226-
layer.weight_shape[0] //= 2
227-
layer.weight_dtype = "int8"
221+
if isinstance(self, MacheteWeightOnlyLinearMethod):
222+
# Using group scale for machete, group size is 128
223+
weight_scale_shape = [(layer.weight_shape[0] + 127) // 128, layer.weight_shape[1]]
224+
if self.quant_config.name() == "wint4":
225+
layer.weight_shape[0] //= 8
226+
else:
227+
layer.weight_shape[0] //= 4
228+
layer.weight_dtype = "int32"
229+
else:
230+
# The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
231+
weight_scale_shape = [layer.weight_shape[1]]
232+
layer.weight_shape.reverse()
233+
if self.quant_config.name() == "wint4":
234+
layer.weight_shape[0] //= 2
235+
layer.weight_dtype = "int8"
236+
228237
layer.weight = layer.create_parameter(
229238
shape=layer.weight_shape,
230239
dtype=layer.weight_dtype,
@@ -260,17 +269,30 @@ def create_weights(self, layer, **extra_weight_attrs):
260269
def process_weights_after_loading(self, layer) -> None:
261270
if not layer.fd_config.load_config.load_choices == "default_v1":
262271
return
263-
quanted_weight_tensor, weight_scale_tensor = weight_quantize(
264-
layer.weight,
265-
algo=self.quant_config.algo,
266-
arch=self.quant_config.weight_only_linear_arch,
267-
)
272+
if isinstance(self, MacheteWeightOnlyLinearMethod):
273+
from fastdeploy.model_executor.layers.quantization.ops import (
274+
machete_quantize_and_pack,
275+
)
276+
277+
# Using group scale for machete, group size is 128
278+
quanted_weight_tensor, weight_scale_tensor = machete_quantize_and_pack(
279+
w=layer.weight,
280+
atype=layer._dtype,
281+
quant_type="uint4b8" if self.quant_config.name() == "wint4" else "uint8b128",
282+
group_size=128,
283+
)
284+
else:
285+
quanted_weight_tensor, weight_scale_tensor = weight_quantize(
286+
layer.weight,
287+
algo=self.quant_config.algo,
288+
arch=self.quant_config.weight_only_linear_arch,
289+
)
268290

269291
free_tensor(layer.weight)
270292

271293
layer.weight = layer.create_parameter(
272294
shape=quanted_weight_tensor.shape,
273-
dtype="int8",
295+
dtype="int8" if not isinstance(self, MacheteWeightOnlyLinearMethod) else "int32",
274296
is_bias=False,
275297
default_initializer=paddle.nn.initializer.Constant(0),
276298
)
@@ -361,32 +383,6 @@ def __init__(
361383
) -> None:
362384
super().__init__(quant_config)
363385

364-
def create_weights(self, layer, **extra_weight_attrs):
365-
366-
assert layer.bias is None, "Machete weight only linear method does not support bias."
367-
assert self.quant_config.name() == "wint4", "Machete weight only linear method only supports wint4."
368-
369-
# The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
370-
weight_scale_shape = [1, layer.weight_shape[1]]
371-
372-
# layer.weight_shape.reverse()
373-
if self.quant_config.name() == "wint4":
374-
layer.weight_shape[0] //= 8
375-
layer.weight_dtype = "int32"
376-
377-
layer.weight = layer.create_parameter(
378-
shape=layer.weight_shape,
379-
dtype=layer.weight_dtype,
380-
is_bias=False,
381-
default_initializer=paddle.nn.initializer.Constant(0),
382-
)
383-
384-
layer.weight_scale = layer.create_parameter(
385-
shape=weight_scale_shape,
386-
dtype=layer._dtype,
387-
is_bias=False,
388-
)
389-
390386
def process_prequanted_weights(self, layer, state_dict) -> None:
391387
pass
392388

@@ -395,24 +391,27 @@ def process_loaded_weights(self, layer, weight) -> None:
395391
machete_quantize_and_pack,
396392
)
397393

394+
# Using group scale for machete, group size is 128
398395
quanted_weight_tensor, weight_scale_tensor = machete_quantize_and_pack(
399396
w=weight,
400397
atype=layer._dtype,
401-
quant_type="uint4b8",
398+
quant_type="uint4b8" if self.quant_config.name() == "wint4" else "uint8b128",
399+
group_size=128,
402400
)
403401
layer.weight.set_value(quanted_weight_tensor)
404402
layer.weight_scale.set_value(weight_scale_tensor.astype(paddle.get_default_dtype()))
405403

406404
def apply(self, layer, x):
407-
assert layer.bias is None, "Machete weight only linear method does not support bias."
408-
assert self.quant_config.name() == "wint4", "Machete weight only linear method only supports wint4."
409405
from fastdeploy.model_executor.layers.quantization.ops import machete_wint_mm
410406

407+
# Using group scale for machete, group size is 128
411408
linear_out = machete_wint_mm(
412409
x,
413410
w_prepack=layer.weight,
414411
w_g_s=layer.weight_scale,
415-
weight_dtype="uint4b8",
412+
weight_dtype="uint4b8" if self.quant_config.name() == "wint4" else "uint8b128",
413+
group_size=128,
416414
)
417-
415+
if layer.with_bias:
416+
linear_out = paddle.add(linear_out, layer.bias)
418417
return linear_out

0 commit comments

Comments (0)