Commit a5994ac

[XNNPACK] Serialize weights as fp16 rather than fp32 (#9753)
### Summary
Previously we used the FP32_STATIC_WEIGHTS flag in XNNPACK to coerce fp32 weights into fp16 for linear and conv. This let us mimic fp16 computation, because the weights were converted and packed as fp16 at runtime. However, it meant we lost the benefit of a smaller .pte file, since the weights were serialized as fp32 rather than fp16. Additionally, the weights still had to be loaded as fp32 and converted at runtime, which hurts performance.

### Test plan
```
python -m unittest backends.xnnpack.test.ops.test_linear.TestLinear.test_fp16_linear
python -m unittest backends.xnnpack.test.ops.test_linear.TestLinear
python -m unittest backends.xnnpack.test.ops.test_conv2d.TestConv2d
```

Llama 3.2 with bf16 weights:

Before:
```
-rw-r--r--  1 maxren  staff  5468937344 Mar 28 17:00 llama3_2_fp16_direct_convert_runtime.pte
```

After:
```
-rw-r--r--  1 maxren  staff  2997443712 Mar 28 16:57 llama3_2_fp16_direct_convert_runtime.pte
```
1 parent 16e5901 commit a5994ac
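
For intuition on the size numbers above: an fp16 element occupies 2 bytes versus 4 bytes for fp32, so serializing fp16 weights directly (instead of upcasting them and relying on XNN_FLAG_FP32_STATIC_WEIGHTS at runtime) roughly halves the constant-data payload. A minimal sketch in plain PyTorch, not the ExecuTorch serializer; the weight shape is made up for illustration:

```python
import torch

# Hypothetical weight shape, just to illustrate per-element storage cost.
weight_fp16 = torch.randn(4096, 4096).to(torch.float16)   # serialized as-is after this change
weight_fp32 = weight_fp16.to(torch.float32)               # old behavior: upcast before serialization

bytes_fp16 = weight_fp16.numel() * weight_fp16.element_size()  # 2 bytes per element
bytes_fp32 = weight_fp32.numel() * weight_fp32.element_size()  # 4 bytes per element

print(bytes_fp16, bytes_fp32)  # 33554432 vs 67108864 -> roughly the 2x .pte shrink seen above
```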

File tree: 4 files changed (+22 −21 lines)

backends/xnnpack/operators/node_visitor.py (+9 −11)

@@ -210,7 +210,7 @@ def get_serialized_dtype(
         self,
         quant_params: Optional[QuantParams],
         node: torch.fx.Node,
-        fp32_static_weight: bool = False,
+        force_fp32: bool = False,
     ) -> XNNDatatype:
         # Default initialization
         dtype = XNNDatatype.xnn_datatype_fp32
@@ -267,7 +267,7 @@ def get_per_channel_dtype(
         if node_dtype is not None and node_dtype == torch.float16:
             dtype = (
                 XNNDatatype.xnn_datatype_fp32
-                if fp32_static_weight
+                if force_fp32
                 else XNNDatatype.xnn_datatype_fp16
             )
 
@@ -348,7 +348,7 @@ def define_tensor( # noqa: C901
         convert_to_nhwc: bool = False,
         swap_in_out_for_weights: bool = False,
         quant_params: Optional[QuantParams] = None,
-        fp32_static_weights: bool = False,
+        force_fp32: bool = False,
         groups: int = 1,
     ) -> None:
         """
@@ -368,7 +368,7 @@ def define_tensor( # noqa: C901
                 constant data. If used along with convert_to_nhwc, this
                 swap will happen before converting to nhwc.
             quant_params: Quantization meta data for this tensor, None if it is not quantized
-            fp32_static_weights: XNN_FLAG_FP32_STATIC_WEIGHTS for fp16 conv
+            force_fp32: forces tensor to be serialize as fp32, used for bias of dynamically quantized ops
             groups: number of groups for swap_in_out_for_weights
         """
 
@@ -405,7 +405,7 @@ def define_tensor( # noqa: C901
             convert_to_nhwc,
             swap_in_out_for_weights,
             quant_params,
-            fp32_static_weights,
+            force_fp32,
             groups,
         )
 
@@ -417,9 +417,7 @@ def define_tensor( # noqa: C901
             check_or_raise(len(dims) == 4, "Converting to nhwc requires 4d tensor")
             dims = [dims[i] for i in PERM_NCHW_TO_NHWC]
 
-        dtype = self.get_serialized_dtype(
-            quant_params, tensor, fp32_static_weight=fp32_static_weights
-        )
+        dtype = self.get_serialized_dtype(quant_params, tensor, force_fp32=force_fp32)
 
         tvalue = XNNTensorValue(
             datatype=dtype,
@@ -504,7 +502,7 @@ def get_serialized_buffer_index(
         convert_to_nhwc: bool,
         swap_in_out_for_weights: bool,
         quant_params: Optional[QuantParams],
-        fp32_static_weights: bool = False,
+        force_fp32: bool = False,
         groups: int = 1,
     ) -> int:
         """
@@ -525,7 +523,7 @@ def get_serialized_buffer_index(
                 constant data. If used along with convert_to_nhwc, this
                 swap will happen before converting to nhwc.
             quant_params: Quantization meta data for this tensor, None if it is not quantize
-            fp32_static_weights: bool to indicate whether tensor is fp32 static weights
+            force_fp32: bool to indicate whether tensor is fp32 static weights
             groups: groups for swap_in_out_for_weights
 
         Returns:
@@ -554,7 +552,7 @@ def get_serialized_buffer_index(
         # Quantize buffer if static data is indeed quantized
        if quant_params is not None and not quant_params.is_dynamic:
             const_val = quant_params.quantize_tensor(const_val).contiguous()
-        elif const_val.dtype != torch.float16 or fp32_static_weights:
+        elif const_val.dtype != torch.float16 or force_fp32:
             # ensure that the const is fp32
             const_val = const_val.to(dtype=torch.float32).contiguous()
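
The net effect of the node_visitor.py change, restated as a standalone sketch for the non-quantized path (a hypothetical helper that mirrors the logic above; it is not the actual get_serialized_dtype):

```python
import torch

def serialized_dtype_for_const(const_val: torch.Tensor, force_fp32: bool = False) -> torch.dtype:
    """Hypothetical summary of the non-quantized constant path above."""
    if const_val.dtype == torch.float16 and not force_fp32:
        # fp16 constants are now serialized as fp16: smaller .pte, no runtime upcast
        return torch.float16
    # anything else, or force_fp32=True (e.g. bias of dynamically quantized ops),
    # is written as fp32, matching const_val.to(dtype=torch.float32) above
    return torch.float32

assert serialized_dtype_for_const(torch.zeros(2, dtype=torch.float16)) == torch.float16
assert serialized_dtype_for_const(torch.zeros(2, dtype=torch.float16), force_fp32=True) == torch.float32
```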

backends/xnnpack/operators/op_conv2d.py (+3 −3)

@@ -82,7 +82,6 @@ def define_node(
         weight_quant_params = QuantParams.from_weights(
             kernel_node, self._exported_program
         )
-        fp32_static_weights = kernel_node.meta["val"].dtype == torch.float16
 
         if weight_quant_params is not None and weight_quant_params.per_channel:
             if is_transpose:
@@ -102,8 +101,8 @@ def define_node(
             convert_to_nhwc=True,
             swap_in_out_for_weights=is_depthwise_conv or is_transpose,
             quant_params=weight_quant_params,
-            fp32_static_weights=fp32_static_weights,
             groups=groups if is_transpose else 1,
+            force_fp32=True,
         )
         kwargs["filter_id"] = vals_to_ids[get_input_node(node, 1)]
 
@@ -127,13 +126,14 @@ def define_node(
             bias_quant_params = QuantParams.from_bias(
                 bias_node, weight_quant_params, input_quant_params
             )
+
             self.define_tensor(
                 get_input_node(node, 2),
                 xnn_graph,
                 vals_to_ids,
                 convert_to_nhwc=False,
                 quant_params=bias_quant_params,
-                fp32_static_weights=fp32_static_weights,
+                force_fp32=True,
             )
             kwargs["bias_id"] = vals_to_ids[get_input_node(node, 2)]

backends/xnnpack/operators/op_linear.py (+7 −2)

@@ -59,7 +59,6 @@ def define_node(
             xnn_graph,
             vals_to_ids,
             quant_params=weight_quant_params,
-            fp32_static_weights=True,
         )
         filter_id = vals_to_ids[weight_node]
 
@@ -69,12 +68,18 @@ def define_node(
             bias_quant_params = QuantParams.from_bias(
                 bias_node, weight_quant_params, input_quant_params
             )
+            # For dynamic quantization, there are no kernels with fp16 bias
+            # So we need to force the fp16 bias to fp32
+            force_fp32 = False
+            if input_quant_params is not None and input_quant_params.is_dynamic:
+                force_fp32 = True
+
             self.define_tensor(
                 get_input_node(node, 2),
                 xnn_graph,
                 vals_to_ids,
                 quant_params=bias_quant_params,
-                fp32_static_weights=True,
+                force_fp32=force_fp32,
             )
             bias_id = vals_to_ids[bias_node]
         else:
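
The linear change keys force_fp32 off whether the activation is dynamically quantized, since (per the added comment) there are no kernels with fp16 bias in that mode. A standalone sketch of that bias handling; the helper name and the input_is_dynamic flag are illustrative, not the real QuantParams API:

```python
import torch

def prepare_linear_bias(bias: torch.Tensor, input_is_dynamic: bool) -> torch.Tensor:
    # Dynamically quantized linear has no fp16-bias kernels, so an fp16 bias is
    # upcast and serialized as fp32; otherwise an fp16 bias can stay fp16.
    if input_is_dynamic and bias.dtype == torch.float16:
        return bias.to(dtype=torch.float32).contiguous()
    return bias.contiguous()

bias = torch.zeros(16, dtype=torch.float16)
assert prepare_linear_bias(bias, input_is_dynamic=True).dtype == torch.float32
assert prepare_linear_bias(bias, input_is_dynamic=False).dtype == torch.float16
```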

backends/xnnpack/test/ops/test_linear.py (+3 −5)

@@ -605,9 +605,7 @@ def _test_qd8_linear_per_tensor_unsupported(self, dtype: torch.dtype = torch.flo
 
         if legacy_partitioner:
             tester.to_edge()
-            tester.partition(
-                Partition(DynamicallyQuantizedPartitioner)
-            ).dump_artifact()
+            tester.partition(Partition(DynamicallyQuantizedPartitioner))
             # should have [add]mm node
             if uses_bias:
                 tester.check(
@@ -624,7 +622,7 @@ def _test_qd8_linear_per_tensor_unsupported(self, dtype: torch.dtype = torch.flo
         else:
             tester.to_edge_transform_and_lower(
                 ToEdgeTransformAndLower([DynamicallyQuantizedPartitioner])
-            ).dump_artifact()
+            )
             # should not have a delegate node
             tester.check_not(
                 [
@@ -717,7 +715,7 @@ def test_fp16_linear(self):
                     num_batch_dims=num_batch_dims,
                     uses_bias=use_bias,
                     dtype=torch.float16,
-                    atol=5e-2,  # TODO(T212995726): Investigate right atol for rand[n] inputs
+                    atol=5e-3,  # TODO(T212995726): Investigate right atol for rand[n] inputs
                 )
 
     def test_fp32_linear(self):
