
Commit 554cd27

chunit-quic authored and facebook-github-bot committed
Qualcomm AI Engine Direct - Enable per channel linear op (#2822)
Summary:
- Add per channel weight quantization for linear op
- Bias quantization for per channel weight linear op is not supported yet

Pull Request resolved: #2822

Reviewed By: kirklandsign

Differential Revision: D55731629

Pulled By: cccclai

fbshipit-source-id: 831a47c897b34e1a749325df56a8bbd0acda80e1
1 parent 26365f1 commit 554cd27
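For context, a minimal usage sketch (not part of this commit) of how the new toggles are intended to be driven from user code, assuming the import path the backend's own tests use:

    from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer

    # Enable per channel weight quantization for conv and, with this commit, linear.
    quantizer = QnnQuantizer()
    quantizer.set_per_channel_conv_quant(True)    # conv1d / conv2d weights
    quantizer.set_per_channel_linear_quant(True)  # aten.linear weights (new)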

File tree

backends/qualcomm/builders/op_linear.py
backends/qualcomm/quantizer/quantizer.py
backends/qualcomm/quantizer/utils.py
backends/qualcomm/tests/models.py
backends/qualcomm/tests/test_qnn_delegate.py
backends/qualcomm/tests/utils.py

6 files changed: +69 −17 lines

backends/qualcomm/builders/op_linear.py (+14)
@@ -40,6 +40,14 @@ def define_node(
         linear_input_tensors.append(input_tensor_wrapper)
 
         weight_node = node.args[1]
+        if (
+            quant_attrs := weight_node.meta.get("quant_attrs")
+        ) and "scales" in quant_attrs:
+            # Dimension of weight is [m, n], per channel quant params is [m]
+            # Change to [m, 1] to fit the tensor.div(s).add(z)
+            quant_attrs["scales"] = quant_attrs["scales"].reshape([-1, 1])
+            quant_attrs["zero_points"] = quant_attrs["zero_points"].reshape([-1, 1])
+
         weight_tensor = get_parameter(weight_node, self.edge_program)
         weight_tensor_wrapper = self.define_tensor(
             weight_node,
@@ -52,6 +60,12 @@ def define_node(
 
         if len(node.args) >= 3:
             bias_node = node.args[2]
+
+            # TODO remove this when qnn sdk support
+            if "scales" in bias_node.meta.get("quant_attrs"):
+                print(
+                    f"[WARNING] Fallback linear bias, {bias_node}. per channel bias quantization is not support yet."
+                )
             bias_tensor = get_parameter(bias_node, self.edge_program)
             bias_tensor_wrapper = self.define_tensor(
                 bias_node,
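The reshape in the first hunk is about broadcasting; a standalone torch sketch (illustration only, not backend code) of why per channel params of shape [m] are expanded to [m, 1] before tensor.div(s).add(z) on an [m, n] weight:

    import torch

    m, n = 5, 4
    weight = torch.randn(m, n)            # [m, n] linear weight
    scales = torch.rand(m) + 0.01         # one scale per output channel, shape [m]
    zero_points = torch.zeros(m)          # shape [m]

    # A flat [m] tensor does not broadcast against [m, n] (trailing dims m vs n),
    # but [m, 1] broadcasts row-wise as intended.
    q = weight.div(scales.reshape(-1, 1)).add(zero_points.reshape(-1, 1)).round()
    print(q.shape)  # torch.Size([5, 4])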

backends/qualcomm/quantizer/quantizer.py (+18 −8)
@@ -267,7 +267,7 @@ def __init__(self):
         self.custom_quant_annotations: Sequence[Callable] = []
         self.discard_nodes: Set[str] = set()
 
-        self.enable_per_channel_conv_quant: bool = True
+        self.use_per_channel_weight_quant_ops: Set[OpOverload] = set()
         # the weight quantized for activation 8 bits and 16 bits
         self.per_channel_weight_dtype: Dict = {
             "8bit_act": torch.int8,
@@ -290,16 +290,13 @@ def _annotate_custom_annotation(self, gm: GraphModule) -> None:
     def _get_quant_config(self, op: str | OpOverload) -> Optional[QuantizationConfig]:
         """
         Priority:
-            1. per channel config when enable_per_channel_conv_quant is True
+            1. is one of use_per_channel_weight_quant_ops
             2. int8 / int16 config
         """
         if type(op) == str:
             return
 
-        if self.enable_per_channel_conv_quant and op in [
-            torch.ops.aten.conv1d.default,
-            torch.ops.aten.conv2d.default,
-        ]:
+        if op in self.use_per_channel_weight_quant_ops:
             if op in self.bit16_quant_ops:
                 return get_ptq_per_channel_weight_config(
                     torch.uint16, self.per_channel_weight_dtype["16bit_act"]
@@ -316,6 +313,12 @@ def _get_quant_config(self, op: str | OpOverload) -> Optional[QuantizationConfig]:
 
         print(f"No quant config is implemented for op, {op}")
 
+    def _update_per_channel_weight_quant_ops(self, ops: Set[OpOverload], enable: bool):
+        if enable:
+            self.use_per_channel_weight_quant_ops.update(ops)
+        else:
+            self.use_per_channel_weight_quant_ops.difference(ops)
+
     def add_16bit_quant_ops(self, ops: Set[OpOverload]) -> None:
         for op in ops:
             assert (
@@ -368,8 +371,15 @@ def set_per_channel_weight_dtype(
         if weight_dtype_for_16bit_act:
             self.per_channel_weight_dtype["16bit_act"] = weight_dtype_for_16bit_act
 
-    def set_per_channel_quant(self, enable: bool) -> None:
-        self.enable_per_channel_conv_quant = enable
+    def set_per_channel_conv_quant(self, enable: bool) -> None:
+        conv_ops = {torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default}
+        self._update_per_channel_weight_quant_ops(conv_ops, enable)
+
+    def set_per_channel_linear_quant(self, enable: bool) -> None:
+        linear_ops = {
+            torch.ops.aten.linear.default,
+        }
+        self._update_per_channel_weight_quant_ops(linear_ops, enable)
 
     def transform_for_annotation(self, model: GraphModule) -> GraphModule:
         model = RemoveClone()(model).graph_module
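A standalone sketch of the set-based toggle that replaces the old boolean flag, with plain strings standing in for torch OpOverload objects; note that set.difference() returns a new set, so in-place removal relies on difference_update():

    # Illustration only: strings stand in for OpOverload objects.
    use_per_channel_weight_quant_ops = set()

    def update_per_channel_weight_quant_ops(ops, enable):
        if enable:
            use_per_channel_weight_quant_ops.update(ops)
        else:
            # difference() would return a new set and leave the field unchanged;
            # difference_update() mutates it in place.
            use_per_channel_weight_quant_ops.difference_update(ops)

    update_per_channel_weight_quant_ops({"aten.conv1d", "aten.conv2d"}, True)
    update_per_channel_weight_quant_ops({"aten.linear"}, True)
    update_per_channel_weight_quant_ops({"aten.linear"}, False)
    print(sorted(use_per_channel_weight_quant_ops))  # ['aten.conv1d', 'aten.conv2d']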

backends/qualcomm/quantizer/utils.py (+5 −5)
@@ -520,11 +520,11 @@ def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None:
     )
     nodes_to_mark_annotated = [node, weight_node]
     if bias_node:
-        _annotate_input_qspec_map(
-            node,
-            bias_node,
-            quantization_config.bias,
-        )
+        if callable(quantization_config.bias):
+            bias_config = quantization_config.bias(node)
+        else:
+            bias_config = quantization_config.bias
+        _annotate_input_qspec_map(node, bias_node, bias_config)
         nodes_to_mark_annotated.append(bias_node)
     _annotate_output_qspec(node, quantization_config.output_activation)
     _mark_nodes_as_annotated(nodes_to_mark_annotated)
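The callable branch exists because per channel weight configs typically supply the bias spec as a factory that derives node-specific parameters (bias scale commonly follows input scale times per-channel weight scale). A hypothetical helper (not in the diff) showing the same dispatch pattern:

    def resolve_bias_qspec(quantization_config, node):
        # quantization_config.bias may be a fixed spec, or a callable that
        # builds a node-specific (e.g. derived per channel) spec.
        bias = quantization_config.bias
        return bias(node) if callable(bias) else bias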

backends/qualcomm/tests/models.py (+2 −2)
@@ -409,9 +409,9 @@ def forward(self, x):
 
 
 class Linear(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, use_bias: bool = True):
         super().__init__()
-        self.linear = torch.nn.Linear(4, 5).eval()
+        self.linear = torch.nn.Linear(4, 5, use_bias).eval()
 
     def forward(self, x):
         return self.linear(x)

backends/qualcomm/tests/test_qnn_delegate.py (+27 −1)
@@ -505,7 +505,33 @@ def test_qnn_backend_16a4w_linear(self):
         module = Linear()  # noqa: F405
         sample_input = (torch.randn([3, 4]),)
         module = self.get_qdq_module(
-            module, sample_input, quant_dtype=QuantDtype.use_16a4w
+            module,
+            sample_input,
+            quant_dtype=QuantDtype.use_16a4w,
+        )
+        self.lower_module_and_test_output(module, sample_input)
+
+    def test_qnn_backend_16a4w_per_channel_linear(self):
+        module = Linear(use_bias=False)  # noqa: F405
+        sample_input = (torch.randn([3, 4]),)
+        module = self.get_qdq_module(
+            module,
+            sample_input,
+            is_linear_per_channel=True,
+            quant_dtype=QuantDtype.use_16a4w,
+        )
+        self.lower_module_and_test_output(module, sample_input)
+
+    # Is not enabled in the current qnn sdk release
+    @unittest.expectedFailure
+    def test_qnn_backend_16a4w_per_channel_linear_with_bias(self):
+        module = Linear()  # noqa: F405
+        sample_input = (torch.randn([3, 4]),)
+        module = self.get_qdq_module(
+            module,
+            sample_input,
+            is_linear_per_channel=True,
+            quant_dtype=QuantDtype.use_16a4w,
         )
         self.lower_module_and_test_output(module, sample_input)
 

backends/qualcomm/tests/utils.py (+3 −1)
@@ -225,14 +225,16 @@ def get_qdq_module(
         module: torch.nn.Module,
         inputs: Tuple[torch.Tensor],
         is_conv_per_channel: Optional[bool] = True,
+        is_linear_per_channel: Optional[bool] = False,
         custom_quant_annotations: Tuple[Callable] = (),
         quant_dtype: QuantDtype = QuantDtype.use_8a8w,
     ) -> torch.fx.GraphModule:
         m = torch._export.capture_pre_autograd_graph(module, inputs)
 
         quantizer = QnnQuantizer()
         quantizer.add_custom_quant_annotations(custom_quant_annotations)
-        quantizer.set_per_channel_quant(is_conv_per_channel)
+        quantizer.set_per_channel_conv_quant(is_conv_per_channel)
+        quantizer.set_per_channel_linear_quant(is_linear_per_channel)
 
         if quant_dtype == QuantDtype.use_8a8w:
             pass # default setting
