@@ -376,9 +376,9 @@ def fn(match):
     return fn
 
 
-def _is_valid_qconv_post_op_fusion_pattern(has_binary_post_op=False):
+def _is_valid_qconv_post_op_fusion_pattern(has_binary_post_op=False, is_fp8=False):
     return (
-        _is_valid_qconv_binary_optimization_pattern()
+        _is_valid_qconv_binary_optimization_pattern(is_fp8=is_fp8)
         if has_binary_post_op
         else _is_valid_quantized_conv_optimization_pattern()
     )
@@ -408,9 +408,11 @@ def _is_valid_qlinear_post_op_fusion_pattern(has_binary_post_op=False):
     )
 
 
-def _is_valid_qconv_binary_optimization_pattern():
+def _is_valid_qconv_binary_optimization_pattern(is_fp8=False):
     return _is_valid_quantized_op_binary_optimization_pattern(
-        torch.ops.onednn.qconv_pointwise
+        torch.ops.onednn.qconv_pointwise,
+        # we don't insert q-dq for extra input in fp8 recipe
+        extra_input_from_dequant=not is_fp8,
     )
 
 
@@ -2016,12 +2018,13 @@ def _register_qconv_post_op_fusion_pass(
     pass_number,
     computation_op,
     post_op_attr,
+    is_fp8=False,
 ):
     has_binary_post_op = post_op_attr.binary_op_name != "none"
 
     @register_freezing_graph_pattern(
         pattern,
-        extra_check=_is_valid_qconv_post_op_fusion_pattern(has_binary_post_op),
+        extra_check=_is_valid_qconv_post_op_fusion_pattern(has_binary_post_op, is_fp8=is_fp8),
         pass_number=pass_number,
     )
     def qconv(match: Match, *args, **kwargs):
@@ -2097,7 +2100,7 @@ def qconv(match: Match, *args, **kwargs):
         else:
             accum = (
                 kwargs["accum"]
-                if output_dtype in [torch.uint8, torch.int8]
+                if output_dtype in [torch.uint8, torch.int8] or is_fp8
                 else kwargs["accum_after_dequant"]
             )
             accum_scale = (
@@ -2237,6 +2240,7 @@ def _register_qconv_unary_fusion():
                 3,  # pass_number
                 computation_op,  # computation_op
                 unary_attr,  # unary_attr
+                is_fp8=is_fp8,
             )
 
         # Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output
@@ -2289,41 +2293,49 @@ def _register_qconv_unary_fusion():
                 4,  # pass_number
                 computation_op,  # computation_op
                 unary_attr,  # unary_attr
+                is_fp8=is_fp8,
             )
 
 
 def _register_qconv_binary_fusion():
-    for int8_mixed_bf16_with_inplace_add in [False, True]:
+    for int8_mixed_bf16_with_inplace_add, x_scale_zp_are_tensors in itertools.product([False, True], [False, True]):
+        qconv_binary_op = (
+            torch.ops.onednn.qconv2d_pointwise.binary_tensor
+            if x_scale_zp_are_tensors
+            else torch.ops.onednn.qconv2d_pointwise.binary
+        )
         # Priority 1 to match: QConv2d Binary or Binary-Unary pattern with int8 output
         swap_binary_inputs_list = [False, True]
         binary_replace_patterns = {}
-        for swap_inputs in swap_binary_inputs_list:
+        for swap_inputs, is_fp8 in itertools.product(swap_binary_inputs_list, [False, True]):
             binary_replace_patterns.update(
                 {
                     PostOpAttr(
                         "sum", 1.0, "none", [], ""
                     ): generate_pattern_with_output_quant(
                         generate_pattern_with_binary(
                             aten.add.Tensor,
-                            get_qconv_pt2e_pattern(users=1),
+                            get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1),
                             dequantize_accum_pattern,
                             int8_mixed_bf16_with_inplace_add,
                             swap_inputs=swap_inputs,
                         ),
+                        is_fp8=is_fp8,
                     ),
                     PostOpAttr(
                         "sum", 1.0, "relu", [], ""
                     ): generate_pattern_with_output_quant(
                         generate_pattern_with_unary(
                             generate_pattern_with_binary(
                                 aten.add.Tensor,
-                                get_qconv_pt2e_pattern(users=1),
+                                get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1),
                                 dequantize_accum_pattern,
                                 int8_mixed_bf16_with_inplace_add,
                                 swap_inputs=swap_inputs,
                             ),
                             aten.relu.default,
                         ),
+                        is_fp8=is_fp8,
                     ),
                 }
             )
@@ -2332,8 +2344,9 @@ def _register_qconv_binary_fusion():
             _register_qconv_post_op_fusion_pass(
                 patterns,
                 3,  # pass_number
-                torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
+                qconv_binary_op,  # computation_op
                 binary_unary_attr,  # binary_unary_attr
+                is_fp8=is_fp8,
             )
 
         # Priority 2 to match: QConv2d Binary-Unary pattern with fp32/bfloat16 output
@@ -2344,8 +2357,8 @@ def _register_qconv_binary_fusion():
                     PostOpAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary(
                         generate_pattern_with_binary(
                             aten.add.Tensor,
-                            get_qconv_pt2e_pattern(users=1),
-                            KeywordArg("accum_after_dequant"),
+                            get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1),
+                            KeywordArg("accum") if is_fp8 else KeywordArg("accum_after_dequant"),
                             int8_mixed_bf16_with_inplace_add,
                             swap_inputs=swap_inputs,
                         ),
@@ -2362,15 +2375,17 @@ def _register_qconv_binary_fusion():
                 _register_qconv_post_op_fusion_pass(
                     patterns,
                     3,  # pass_number
-                    torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
+                    qconv_binary_op,  # computation_op
                     binary_unary_attr,  # binary_unary_attr
+                    is_fp8=is_fp8,
                 )
             else:
                 _register_qconv_post_op_fusion_pass(
                     patterns,
                     4,  # pass_number
-                    torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
+                    qconv_binary_op,  # computation_op
                     binary_unary_attr,  # binary_unary_attr
+                    is_fp8=is_fp8,
                 )
 
         # Priority 3: QConv2d Binary pattern with fp32/bfloat16 output
@@ -2382,8 +2397,8 @@ def _register_qconv_binary_fusion():
23822397 "sum" , 1.0 , "none" , [], ""
23832398 ): generate_pattern_with_binary (
23842399 aten .add .Tensor ,
2385- get_qconv_pt2e_pattern (users = 1 ),
2386- KeywordArg ("accum_after_dequant" ),
2400+ get_qconv_pt2e_pattern (x_scale_zp_are_tensors , 1 ),
2401+ KeywordArg ("accum" ) if is_fp8 else KeywordArg ( " accum_after_dequant" ),
23872402 int8_mixed_bf16_with_inplace_add ,
23882403 swap_inputs = swap_inputs ,
23892404 ),
@@ -2397,8 +2412,9 @@ def _register_qconv_binary_fusion():
             _register_qconv_post_op_fusion_pass(
                 patterns,
                 4 if int8_mixed_bf16_with_inplace_add else 5,  # pass_number
-                torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
+                qconv_binary_op,  # computation_op
                 binary_unary_attr,  # binary_unary_attr
+                is_fp8=is_fp8,
             )
 
 
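Note on the accum change in qconv() above: with the fp8 recipe no dequantize node is inserted on the extra sum input, so the pattern binds the raw "accum" keyword even though the output dtype is not uint8/int8. A minimal standalone sketch of that selection logic follows; pick_accum_kwarg is a hypothetical helper used only for illustration and is not part of this commit.

    # Illustrative sketch only; mirrors the condition changed in qconv() above.
    import torch

    def pick_accum_kwarg(output_dtype: torch.dtype, is_fp8: bool) -> str:
        # int8/uint8 outputs (and the fp8 recipe, which has no dequant on the
        # extra input) bind the raw "accum" input; other output dtypes matched
        # the dequantized extra input instead.
        if output_dtype in [torch.uint8, torch.int8] or is_fp8:
            return "accum"
        return "accum_after_dequant"

    print(pick_accum_kwarg(torch.float8_e4m3fn, is_fp8=True))   # accum
    print(pick_accum_kwarg(torch.bfloat16, is_fp8=False))       # accum_after_dequant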