Commit 23bd2b8

[Inductor][float8] Register qconv-binary fusion pass for float8
Parent: 0160379
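This commit extends the Inductor x86 post-op fusion passes so the qconv + binary ("sum", optionally followed by relu) patterns are also registered for the float8 recipe. As a rough illustration of the graph shape these patterns target (a sketch, not code from this commit; whether it actually fuses depends on the quantization recipe and Inductor freezing passes being applied):

import torch
import torch.nn.functional as F

class ConvAdd(torch.nn.Module):
    """Conv2d + residual add (+ relu): illustrative only, the shape
    matched by the qconv-binary fusion patterns registered below."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, 3, padding=1)

    def forward(self, x, accum):
        # aten.add.Tensor on the conv output is the "sum" binary post op;
        # the trailing relu makes it the binary-unary ("sum" + "relu") variant.
        return F.relu(self.conv(x) + accum)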

torchao/quantization/pt2e/inductor_passes/x86.py (34 additions, 18 deletions)
@@ -376,9 +376,9 @@ def fn(match):
     return fn


-def _is_valid_qconv_post_op_fusion_pattern(has_binary_post_op=False):
+def _is_valid_qconv_post_op_fusion_pattern(has_binary_post_op=False, is_fp8=False):
     return (
-        _is_valid_qconv_binary_optimization_pattern()
+        _is_valid_qconv_binary_optimization_pattern(is_fp8=is_fp8)
         if has_binary_post_op
         else _is_valid_quantized_conv_optimization_pattern()
     )
@@ -408,9 +408,11 @@ def _is_valid_qlinear_post_op_fusion_pattern(has_binary_post_op=False):
     )


-def _is_valid_qconv_binary_optimization_pattern():
+def _is_valid_qconv_binary_optimization_pattern(is_fp8=False):
     return _is_valid_quantized_op_binary_optimization_pattern(
-        torch.ops.onednn.qconv_pointwise
+        torch.ops.onednn.qconv_pointwise,
+        # we don't insert q-dq for the extra input in the fp8 recipe
+        extra_input_from_dequant=not is_fp8,
     )

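The new extra_input_from_dequant argument encodes the comment above: in the int8 recipe the add's extra input arrives through a dequantize node, while the fp8 recipe inserts no q-dq pair for it, so the accumulator is matched as-is. Schematically (illustrative node names, not the exact ATen targets):

# Chain feeding the add's extra input that the matcher must expect:
int8_accum_chain = ["quantize", "dequantize", "add"]  # extra_input_from_dequant=True
fp8_accum_chain = ["add"]  # extra_input_from_dequant=False: accum consumed directly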
@@ -2016,12 +2018,13 @@ def _register_qconv_post_op_fusion_pass(
     pass_number,
     computation_op,
     post_op_attr,
+    is_fp8=False,
 ):
     has_binary_post_op = post_op_attr.binary_op_name != "none"

     @register_freezing_graph_pattern(
         pattern,
-        extra_check=_is_valid_qconv_post_op_fusion_pattern(has_binary_post_op),
+        extra_check=_is_valid_qconv_post_op_fusion_pattern(has_binary_post_op, is_fp8=is_fp8),
         pass_number=pass_number,
     )
     def qconv(match: Match, *args, **kwargs):
@@ -2097,7 +2100,7 @@ def qconv(match: Match, *args, **kwargs):
         else:
             accum = (
                 kwargs["accum"]
-                if output_dtype in [torch.uint8, torch.int8]
+                if output_dtype in [torch.uint8, torch.int8] or is_fp8
                 else kwargs["accum_after_dequant"]
             )
             accum_scale = (
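Restating the widened accumulator selection as a standalone helper (a sketch using the kwarg names bound by the pattern, not code from this commit):

import torch

def pick_accum(kwargs, output_dtype, is_fp8):
    # int8/uint8 outputs -- and now fp8 -- bind the raw "accum" input, since
    # those recipes have no dequant node ahead of the add; only the
    # fp32/bf16-output patterns match through "accum_after_dequant".
    if output_dtype in [torch.uint8, torch.int8] or is_fp8:
        return kwargs["accum"]
    return kwargs["accum_after_dequant"]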
@@ -2237,6 +2240,7 @@ def _register_qconv_unary_fusion():
                 3,  # pass_number
                 computation_op,  # computation_op
                 unary_attr,  # unary_attr
+                is_fp8=is_fp8,
             )

             # Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output
@@ -2289,41 +2293,49 @@ def _register_qconv_unary_fusion():
                 4,  # pass_number
                 computation_op,  # computation_op
                 unary_attr,  # unary_attr
+                is_fp8=is_fp8,
             )


 def _register_qconv_binary_fusion():
-    for int8_mixed_bf16_with_inplace_add in [False, True]:
+    for int8_mixed_bf16_with_inplace_add, x_scale_zp_are_tensors in itertools.product([False, True], [False, True]):
+        qconv_binary_op = (
+            torch.ops.onednn.qconv2d_pointwise.binary_tensor
+            if x_scale_zp_are_tensors
+            else torch.ops.onednn.qconv2d_pointwise.binary
+        )
         # Priority 1 to match: QConv2d Binary or Binary-Unary pattern with int8 output
         swap_binary_inputs_list = [False, True]
         binary_replace_patterns = {}
-        for swap_inputs in swap_binary_inputs_list:
+        for swap_inputs, is_fp8 in itertools.product(swap_binary_inputs_list, [False, True]):
             binary_replace_patterns.update(
                 {
                     PostOpAttr(
                         "sum", 1.0, "none", [], ""
                     ): generate_pattern_with_output_quant(
                         generate_pattern_with_binary(
                             aten.add.Tensor,
-                            get_qconv_pt2e_pattern(users=1),
+                            get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1),
                             dequantize_accum_pattern,
                             int8_mixed_bf16_with_inplace_add,
                             swap_inputs=swap_inputs,
                         ),
+                        is_fp8=is_fp8,
                     ),
                     PostOpAttr(
                         "sum", 1.0, "relu", [], ""
                     ): generate_pattern_with_output_quant(
                         generate_pattern_with_unary(
                             generate_pattern_with_binary(
                                 aten.add.Tensor,
-                                get_qconv_pt2e_pattern(users=1),
+                                get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1),
                                 dequantize_accum_pattern,
                                 int8_mixed_bf16_with_inplace_add,
                                 swap_inputs=swap_inputs,
                             ),
                             aten.relu.default,
                         ),
+                        is_fp8=is_fp8,
                     ),
                 }
             )
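The loop rewrite above uses itertools.product to enumerate every combination, so each pattern is now registered for both scale/zero-point representations, and the inner loop likewise doubles each pattern for both is_fp8 values. A quick check of what the outer product yields:

import itertools

# The outer loop now runs four times, once per combination:
for inplace_add, scale_zp_tensors in itertools.product([False, True], [False, True]):
    print(inplace_add, scale_zp_tensors)
# (False, False), (False, True), (True, False), (True, True)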
@@ -2332,8 +2344,9 @@ def _register_qconv_binary_fusion():
             _register_qconv_post_op_fusion_pass(
                 patterns,
                 3,  # pass_number
-                torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
+                qconv_binary_op,  # computation_op
                 binary_unary_attr,  # binary_unary_attr
+                is_fp8=is_fp8,
             )

         # Priority 2 to match: QConv2d Binary-Unary pattern with fp32/bfloat16 output
@@ -2344,8 +2357,8 @@ def _register_qconv_binary_fusion():
             PostOpAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary(
                 generate_pattern_with_binary(
                     aten.add.Tensor,
-                    get_qconv_pt2e_pattern(users=1),
-                    KeywordArg("accum_after_dequant"),
+                    get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1),
+                    KeywordArg("accum") if is_fp8 else KeywordArg("accum_after_dequant"),
                     int8_mixed_bf16_with_inplace_add,
                     swap_inputs=swap_inputs,
                 ),
@@ -2362,15 +2375,17 @@ def _register_qconv_binary_fusion():
             _register_qconv_post_op_fusion_pass(
                 patterns,
                 3,  # pass_number
-                torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
+                qconv_binary_op,  # computation_op
                 binary_unary_attr,  # binary_unary_attr
+                is_fp8=is_fp8,
             )
         else:
             _register_qconv_post_op_fusion_pass(
                 patterns,
                 4,  # pass_number
-                torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
+                qconv_binary_op,  # computation_op
                 binary_unary_attr,  # binary_unary_attr
+                is_fp8=is_fp8,
             )

         # Priority 3: QConv2d Binary pattern with fp32/bfloat16 output
@@ -2382,8 +2397,8 @@ def _register_qconv_binary_fusion():
                 "sum", 1.0, "none", [], ""
             ): generate_pattern_with_binary(
                 aten.add.Tensor,
-                get_qconv_pt2e_pattern(users=1),
-                KeywordArg("accum_after_dequant"),
+                get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1),
+                KeywordArg("accum") if is_fp8 else KeywordArg("accum_after_dequant"),
                 int8_mixed_bf16_with_inplace_add,
                 swap_inputs=swap_inputs,
             ),
@@ -2397,8 +2412,9 @@ def _register_qconv_binary_fusion():
             _register_qconv_post_op_fusion_pass(
                 patterns,
                 4 if int8_mixed_bf16_with_inplace_add else 5,  # pass_number
-                torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
+                qconv_binary_op,  # computation_op
                 binary_unary_attr,  # binary_unary_attr
+                is_fp8=is_fp8,
             )
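As a quick sanity check (assuming a PyTorch build that ships the onednn quantized ops), both overloads that qconv_binary_op can resolve to are reachable directly; resolving them raises AttributeError if they are not registered:

import torch

print(torch.ops.onednn.qconv2d_pointwise.binary)
print(torch.ops.onednn.qconv2d_pointwise.binary_tensor)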