@@ -210,7 +210,7 @@ def get_serialized_dtype(
         self,
         quant_params: Optional[QuantParams],
         node: torch.fx.Node,
-        fp32_static_weight: bool = False,
+        force_fp32: bool = False,
     ) -> XNNDatatype:
         # Default initialization
         dtype = XNNDatatype.xnn_datatype_fp32
@@ -267,7 +267,7 @@ def get_per_channel_dtype(
         if node_dtype is not None and node_dtype == torch.float16:
             dtype = (
                 XNNDatatype.xnn_datatype_fp32
-                if fp32_static_weight
+                if force_fp32
                 else XNNDatatype.xnn_datatype_fp16
             )
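The rename only flips this one branch: an fp16 constant is serialized as fp16 unless `force_fp32` is set, in which case it stays fp32. A minimal runnable sketch of that rule, using string stand-ins for the `XNNDatatype` enum values (`pick_dtype` is a made-up name, not from this file):

```python
import torch

def pick_dtype(node_dtype: torch.dtype, force_fp32: bool = False) -> str:
    # fp16 constants normally serialize as fp16; force_fp32 overrides
    # that and keeps the serialized tensor in fp32.
    if node_dtype == torch.float16 and not force_fp32:
        return "xnn_datatype_fp16"
    return "xnn_datatype_fp32"

assert pick_dtype(torch.float16) == "xnn_datatype_fp16"
assert pick_dtype(torch.float16, force_fp32=True) == "xnn_datatype_fp32"
assert pick_dtype(torch.float32) == "xnn_datatype_fp32"
```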
@@ -348,7 +348,7 @@ def define_tensor(  # noqa: C901
         convert_to_nhwc: bool = False,
         swap_in_out_for_weights: bool = False,
         quant_params: Optional[QuantParams] = None,
-        fp32_static_weights: bool = False,
+        force_fp32: bool = False,
         groups: int = 1,
     ) -> None:
         """
@@ -368,7 +368,7 @@ def define_tensor(  # noqa: C901
                 constant data. If used along with convert_to_nhwc, this
                 swap will happen before converting to nhwc.
             quant_params: Quantization meta data for this tensor, None if it is not quantized
-            fp32_static_weights: XNN_FLAG_FP32_STATIC_WEIGHTS for fp16 conv
+            force_fp32: forces the tensor to be serialized as fp32, used for the bias of dynamically quantized ops
             groups: number of groups for swap_in_out_for_weights
         """
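The new docstring states the motivation: dynamically quantized ops hold their weights in int8 but keep the bias in fp32, so the serializer needs a way to pin a tensor to fp32 regardless of the model's compute dtype. A runnable toy illustrating that split (an assumption-level illustration, not code from this PR):

```python
import torch

# Weights get an int8 representation...
weight = torch.randn(4, 4)
scale = float(weight.abs().max() / 127)
qweight = torch.quantize_per_tensor(weight, scale, 0, torch.qint8)

# ...while the bias stays fp32, which is what force_fp32
# enforces on the serialized XNNPACK side.
bias = torch.randn(4)

assert qweight.dtype == torch.qint8
assert bias.dtype == torch.float32
```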
@@ -405,7 +405,7 @@ def define_tensor(  # noqa: C901
             convert_to_nhwc,
             swap_in_out_for_weights,
             quant_params,
-            fp32_static_weights,
+            force_fp32,
             groups,
         )
@@ -417,9 +417,7 @@ def define_tensor(  # noqa: C901
             check_or_raise(len(dims) == 4, "Converting to nhwc requires 4d tensor")
             dims = [dims[i] for i in PERM_NCHW_TO_NHWC]

-        dtype = self.get_serialized_dtype(
-            quant_params, tensor, fp32_static_weight=fp32_static_weights
-        )
+        dtype = self.get_serialized_dtype(quant_params, tensor, force_fp32=force_fp32)

         tvalue = XNNTensorValue(
             datatype=dtype,
@@ -504,7 +502,7 @@ def get_serialized_buffer_index(
         convert_to_nhwc: bool,
         swap_in_out_for_weights: bool,
         quant_params: Optional[QuantParams],
-        fp32_static_weights: bool = False,
+        force_fp32: bool = False,
         groups: int = 1,
     ) -> int:
         """
@@ -525,7 +523,7 @@ def get_serialized_buffer_index(
                 constant data. If used along with convert_to_nhwc, this
                 swap will happen before converting to nhwc.
             quant_params: Quantization meta data for this tensor, None if it is not quantized
-            fp32_static_weights: bool to indicate whether tensor is fp32 static weights
+            force_fp32: bool to indicate whether the tensor should be serialized as fp32
             groups: groups for swap_in_out_for_weights

         Returns:
@@ -554,7 +552,7 @@ def get_serialized_buffer_index(
         # Quantize buffer if static data is indeed quantized
         if quant_params is not None and not quant_params.is_dynamic:
            const_val = quant_params.quantize_tensor(const_val).contiguous()
-        elif const_val.dtype != torch.float16 or force_fp32:
+        elif const_val.dtype != torch.float16 or force_fp32:
            # ensure that the const is fp32
            const_val = const_val.to(dtype=torch.float32).contiguous()
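Taken together, the buffer path has three cases: statically quantized data is quantized in place, fp16 data stays fp16 unless `force_fp32` is set, and everything else is upcast to fp32. A self-contained sketch of that decision (`serialize_const` is a made-up name; the real logic lives in `get_serialized_buffer_index` above):

```python
import torch

def serialize_const(
    const_val: torch.Tensor, static_quant: bool, force_fp32: bool = False
) -> torch.Tensor:
    if static_quant:
        # quant_params.quantize_tensor(...) would run here
        return const_val.contiguous()
    if const_val.dtype != torch.float16 or force_fp32:
        # ensure the const is fp32 (mirrors the elif above)
        return const_val.to(dtype=torch.float32).contiguous()
    return const_val.contiguous()

half_bias = torch.randn(8, dtype=torch.float16)
assert serialize_const(half_bias, False).dtype == torch.float16
assert serialize_const(half_bias, False, force_fp32=True).dtype == torch.float32
```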