-
Notifications
You must be signed in to change notification settings - Fork 881
Expand file tree
/
Copy pathfusion_pass.py
More file actions
793 lines (710 loc) · 29.2 KB
/
fusion_pass.py
File metadata and controls
793 lines (710 loc) · 29.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import operator as op_module
from typing import Any, cast, Dict, List, Optional, Tuple
import torch
from executorch.backends.cadence.aot.compiler_utils import get_shape
from executorch.backends.cadence.aot.pass_utils import get_arg
from executorch.backends.cadence.aot.quantizer.patterns import (
AddmmPattern,
AddPattern,
BmmPattern,
CatPattern,
Conv1dPattern,
Conv1dReluPattern0,
Conv1dReluPattern1,
Conv2dPattern,
Conv2dReluPattern0,
Conv2dReluPattern1,
LayerNormPattern,
LinearPattern,
MatmulPattern,
MaxPool2dPattern,
MaxPool2dWithoutIndicesPattern,
MixedW8A32ConvPattern,
MixedW8A32GruPattern,
MixedW8A32LinearPattern,
ReluPattern0,
ReluPattern1,
SoftmaxPattern,
)
from executorch.backends.cadence.aot.quantizer.utils import (
check_out_zero_point_is_min_range,
copy_node_metadata,
create_zero_bias_int32,
find_sequential_partitions_aten,
get_conv_args,
quantize_tensor_multiplier,
)
from executorch.exir.pass_base import ExportPass
from torch import fx
from torch.fx import GraphModule
from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.utils.fuser_utils import legalize_graph
# Loose alias for values passed through to replacement ops. The helpers below
# put fx.Nodes, Python scalars (scales / zero points), lists (stride, padding)
# and None into args/kwargs, so no precise type is spelled out.
# Use this to avoid pyre errors
# pyre-ignore[33]: `_ModelInputsType` cannot alias to `Any`.
ArgsType = Any
# Use this part for patterns with multiple aten ops
# Groups of pattern classes, used as the second argument of ``isinstance`` in
# ``QuantFusion.call`` so that one branch handles every variant of a pattern.
ReluPatterns = (ReluPattern0, ReluPattern1)
ConvPatterns = (Conv1dPattern, Conv2dPattern)
# Conv followed by ReLU: QuantFusion builds the replacement args from the conv
# node (the first input anchor) rather than from the ReLU anchor output.
ConvReluPatterns = (
Conv1dReluPattern0,
Conv1dReluPattern1,
Conv2dReluPattern0,
Conv2dReluPattern1,
)
def get_args_and_kwargs_add(
    graph_module: GraphModule,
    inputs_inputs: List[fx.Node],
    dequants_inputs: List[fx.Node],
    quant_node: fx.Node,
) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
    """
    Build the args/kwargs for the quantized add replacement op.

    For each of the two operands, the quantized input tensor is emitted
    followed by the scale (``args[1]``) and zero point (``args[2]``) of its
    dequantize node. The output scale and zero point taken from
    ``quant_node`` close the argument list. No kwargs are produced and
    ``graph_module`` is not used.
    """
    lhs_dequant = dequants_inputs[0]
    rhs_dequant = dequants_inputs[1]
    operand_args: List[ArgsType] = []
    for tensor, dequant in (
        (inputs_inputs[0], lhs_dequant),
        (inputs_inputs[1], rhs_dequant),
    ):
        operand_args.extend((tensor, dequant.args[1], dequant.args[2]))
    output_qparams = (quant_node.args[1], quant_node.args[2])
    return tuple(operand_args) + output_qparams, {}
# Helper function to get the args and kwargs for the linear replacement op
def get_args_and_kwargs_linear(
    graph_module: GraphModule,
    inputs_inputs: List[fx.Node],
    dequants_inputs: List[fx.Node],
    weights_inputs: List[fx.Node],
    dequants_weights: List[fx.Node],
    bias_inputs: List[fx.Node],
    quant_node: fx.Node,
) -> Tuple[Tuple[ArgsType], Dict[str, ArgsType]]:
    """
    Returns the args and kwargs for the linear replacement op.

    Positional args are the quantized input(s), the quantized weight(s) and
    a bias. The bias scale is ``input_scale * weight_scale``. When the
    matched pattern has no bias, a zero int32 bias is synthesized with
    ``create_zero_bias_int32`` (shape taken from weight[0]). The
    requantization ratio ``bias_scale / output_scale`` is encoded as a
    fixed-point multiplier/shift by ``quantize_tensor_multiplier``.
    """
    input_dequant = dequants_inputs[0]
    weight_dequant = dequants_weights[0]

    # pyre-fixme[58]: Unsupported operand types
    bias_scale = input_dequant.args[1] * weight_dequant.args[1]
    out_multiplier, out_shift = quantize_tensor_multiplier(
        torch.tensor([bias_scale / quant_node.args[1]])
    )

    if bias_inputs:
        bias = bias_inputs[0]
    else:
        # No bias in the matched pattern: create a zero bias with the shape of
        # weight[0].
        weight_node = weight_dequant.args[0]
        assert isinstance(weight_node, fx.Node)
        bias = create_zero_bias_int32(graph_module, weight_node, bias_scale)

    args = tuple([*inputs_inputs, *weights_inputs, bias])
    kwargs = {
        "src_zero_point": input_dequant.args[2],
        "weight_zero_point": weight_dequant.args[2],
        "out_multiplier": out_multiplier[0].item(),
        "out_shift": out_shift[0].item(),
        "out_zero_point": quant_node.args[2],
        "offset": None,
    }
    return args, kwargs
def _create_layer_norm_fill_node(
    graph_module: GraphModule,
    inputs_inputs: List[fx.Node],
    normalized_shape: ArgsType,
    fill_value: int,
    role: str,
) -> fx.Node:
    """
    Insert an ``aten.full(normalized_shape, fill_value, dtype=float32)`` node.

    Used to materialize a default layer norm weight (ones) or bias (zeros)
    when the matched op does not provide one. The node's ``val`` metadata is
    computed under the fake mode of the single quantized input, and then
    ``copy_node_metadata`` is applied from that input. ``role`` ("weight" or
    "bias") is only used in the assertion message.
    """
    fill_node = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        (
            normalized_shape,
            fill_value,
        ),
        {"dtype": torch.float32},
    )
    assert (
        len(inputs_inputs) == 1
    ), f"Expected 1 input for layer norm {role}, got {len(inputs_inputs)}"
    source = inputs_inputs[0]
    assert "val" in source.meta, "Missing val metadata on input node"
    fake_mode = source.meta["val"].fake_mode
    assert fake_mode is not None, "fake_mode is None on input node"
    with fake_mode:
        fill_node.meta["val"] = torch.full(
            normalized_shape, fill_value, dtype=torch.float32
        )
    copy_node_metadata(fill_node, source)
    return fill_node


# Helper function to get the args and kwargs for the layer norm replacement op
def get_args_and_kwargs_layer_norm(
    graph_module: GraphModule,
    inputs_inputs: List[fx.Node],
    dequants_inputs: List[fx.Node],
    other_inputs: List[fx.Node],
    quant_node: fx.Node,
) -> Tuple[Tuple[ArgsType], Dict[str, ArgsType]]:
    """
    Returns the args and kwargs for the layer norm replacement op.

    ``other_inputs`` holds the non-quantized anchored args in order:
    ``normalized_shape`` and, when present, ``weight`` and ``bias``. A missing
    (or ``None``) weight/bias is replaced by a float32 ``aten.full`` node of
    ones/zeros with shape ``normalized_shape``.

    Positional args are the quantized input followed by its per-tensor scale
    and zero point; the output quantization params go into kwargs.

    Raises:
        AssertionError: if the input is not per-tensor quantized (scale not a
            float or zero point not an int), or if a default weight/bias must
            be created and the input metadata is incomplete.
    """
    # Check if the input is per-channel quantized
    # TODO(matthiascremon): add proper support and testing for per-channel quantization
    assert isinstance(dequants_inputs[0].args[1], float) and isinstance(
        dequants_inputs[0].args[2], int
    ), "per-channel quantization is not supported for layer norm, both scale and zero_point should be scalars"
    # Make the scale and zero_point tensors
    scale = dequants_inputs[0].args[1]
    zero_point = dequants_inputs[0].args[2]
    normalized_shape = other_inputs[0]

    # Explicit `is None` checks: the anchored arg is either a node or None.
    weight = other_inputs[1] if len(other_inputs) > 1 else None
    if weight is None:
        weight = _create_layer_norm_fill_node(
            graph_module, inputs_inputs, normalized_shape, 1, "weight"
        )
    bias = other_inputs[2] if len(other_inputs) > 2 else None
    if bias is None:
        bias = _create_layer_norm_fill_node(
            graph_module, inputs_inputs, normalized_shape, 0, "bias"
        )

    # Make the args and kwargs for the replacement op
    args = tuple(inputs_inputs + [scale, zero_point])
    kwargs = {
        "normalized_shape": normalized_shape,
        "weight": weight,
        "bias": bias,
        # NOTE(review): eps is fixed here rather than read from the matched
        # layer_norm node; confirm models never use a non-default eps.
        "eps": 1e-05,
        "output_scale": quant_node.args[1],
        "output_zero_point": quant_node.args[2],
    }
    return args, kwargs
def get_args_and_kwargs_matmul(
    inputs_inputs: List[fx.Node],
    dequants_inputs: List[fx.Node],
    quant_node: fx.Node,
) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
    """
    Build the args/kwargs for the quantized matmul/bmm replacement op.

    The requantization ratio ``(scale_lhs * scale_rhs) / scale_out`` is turned
    into a fixed-point multiplier/shift by ``quantize_tensor_multiplier``.
    Positional args interleave each quantized operand with its zero point and
    end with a trailing ``None`` (presumably an absent bias; confirm against
    the replacement op schema). ``transposed=False`` is always passed.
    """
    lhs_dequant = dequants_inputs[0]
    rhs_dequant = dequants_inputs[1]
    # pyre-ignore[58]: Unsupported operand
    product_scale = lhs_dequant.args[1] * rhs_dequant.args[1]
    multiplier_t, shift_t = quantize_tensor_multiplier(
        torch.tensor([product_scale / quant_node.args[1]])
    )
    args = (
        inputs_inputs[0],
        lhs_dequant.args[2],
        inputs_inputs[1],
        rhs_dequant.args[2],
        None,
    )
    return args, {
        "out_multiplier": multiplier_t[0].item(),
        "out_shift": shift_t[0].item(),
        "out_zero_point": quant_node.args[2],
        "transposed": False,
    }
def get_args_and_kwargs_cat(
    inputs_inputs: List[fx.Node], other_inputs: List[fx.Node], op_node: fx.Node
) -> Tuple[Tuple[ArgsType], Dict[str, ArgsType]]:
    """
    Build the args/kwargs for the quantized cat replacement op.

    The first positional arg is the list of quantized inputs, followed by any
    ``other_inputs``. The concatenation dim is read from ``op_node``: the
    second positional arg if present, else a ``dim`` kwarg, else 0 (the
    ``aten.cat`` default). Previously a ``dim`` passed as a kwarg was ignored
    and silently replaced by 0.
    """
    args = tuple([inputs_inputs] + other_inputs)
    if len(op_node.args) > 1:
        dim = op_node.args[1]
    else:
        dim = op_node.kwargs.get("dim", 0)
    # pyre-fixme[6]: Incompatible parameter type
    kwargs = {"dim": int(dim)}
    return args, kwargs
def get_args_and_kwargs_conv(
    graph_module: GraphModule,
    inputs_inputs: List[fx.Node],
    dequants_inputs: List[fx.Node],
    weights_inputs: List[fx.Node],
    dequants_weights: List[fx.Node],
    bias_inputs: List[fx.Node],
    quant_node: fx.Node,
    op_node: fx.Node,
) -> Tuple[Tuple[ArgsType], Dict[str, ArgsType]]:
    """
    Build the args/kwargs for the quantized conv replacement op.

    ``op_node`` is the conv node. Its optional positional args are stride
    (index 3), padding (4), dilation (5) and groups (6). Present
    stride/padding/dilation values are passed through ``get_conv_args``.
    Missing ones default to stride/dilation ``[1, 1]``, padding ``[0, 0]`` and
    ``groups=1``.

    The bias scale is ``input_scale * weight_scale``. If the pattern matched
    no bias, a zero int32 bias is created from the weight node. The
    requantization multiplier/shift for ``bias_scale / out_scale`` are
    precomputed so the op can later be replaced by quantized linear without
    recomputing them.
    """
    input_dequant = dequants_inputs[0]
    weight_dequant = dequants_weights[0]
    weight_scale = weight_dequant.args[1]
    weight_zero_point = weight_dequant.args[2]
    # pyre-fixme[58]: Unsupported operand types
    bias_scale = input_dequant.args[1] * weight_scale

    num_args = len(op_node.args)
    stride = get_conv_args(op_node.args[3], 1) if num_args > 3 else [1, 1]
    padding = get_conv_args(op_node.args[4], 0) if num_args > 4 else [0, 0]
    dilation = get_conv_args(op_node.args[5], 1) if num_args > 5 else [1, 1]
    groups = op_node.args[6] if num_args > 6 else 1

    if bias_inputs:
        bias = bias_inputs[0]
    else:
        # No bias in the matched pattern: create a zero bias with the shape of
        # weight[0].
        weight_node = weight_dequant.args[0]
        assert isinstance(weight_node, fx.Node)
        bias = create_zero_bias_int32(graph_module, weight_node, bias_scale)

    out_scale = quant_node.args[1]
    out_multiplier, out_shift = quantize_tensor_multiplier(
        torch.tensor([bias_scale / out_scale])
    )

    kwargs = {
        "stride": stride,
        "padding": padding,
        "dilation": dilation,
        "groups": groups,
        "input_zero_point": input_dequant.args[2],
        "weight_zero_point": weight_zero_point,
        "bias_scale": bias_scale,
        "out_scale": out_scale,
        "out_zero_point": quant_node.args[2],
        "out_multiplier": out_multiplier[0].item(),
        "out_shift": out_shift[0].item(),
    }
    return tuple(inputs_inputs + weights_inputs + [bias]), kwargs
def get_args_and_kwargs_relu(
    graph_module: GraphModule,
    inputs_inputs: List[fx.Node],
    dequants_inputs: List[fx.Node],
    quant_node: fx.Node,
) -> Tuple[Tuple[ArgsType], Dict[str, ArgsType]]:
    """
    Build the args/kwargs for the quantized ReLU replacement op.

    The quantized inputs are forwarded positionally. The kwargs carry the
    input and output zero points plus the fixed-point multiplier/shift that
    encode ``input_scale / output_scale``. ``graph_module`` is unused.
    """
    in_dequant = dequants_inputs[0]
    # pyre-fixme[58]: Unsupported operand types
    scale_ratio = in_dequant.args[1] / quant_node.args[1]
    multiplier, shift = quantize_tensor_multiplier(torch.tensor([scale_ratio]))
    kwargs = {
        "X_zero_point": in_dequant.args[2],
        "out_zero_point": quant_node.args[2],
        "out_multiplier": multiplier[0].item(),
        "out_shift": shift[0].item(),
    }
    return tuple(inputs_inputs), kwargs
def get_args_and_kwargs_mixed_w8a32_linear(
    graph_module: GraphModule,
    other_inputs: List[fx.Node],
    weights_inputs: List[fx.Node],
    dequants_weights: List[fx.Node],
    bias_inputs: List[fx.Node],
    dequants_biases: List[fx.Node],
) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
    """
    Build the args for the mixed W8A32 linear replacement op.

    The activation (first of ``other_inputs``) is followed by the quantized
    weight and its scale, then the quantized bias and its scale. Scales come
    from ``args[1]`` of the matching dequantize nodes. No kwargs are produced
    and ``graph_module`` is unused.
    """
    weight_scale = dequants_weights[0].args[1]
    bias_scale = dequants_biases[0].args[1]
    activation = other_inputs[0]
    weight_pair = (weights_inputs[0], weight_scale)
    bias_pair = (bias_inputs[0], bias_scale)
    return (activation, *weight_pair, *bias_pair), {}
def get_args_and_kwargs_softmax(
    graph_module: GraphModule,
    inputs_inputs: List[fx.Node],
    dequants_inputs: List[fx.Node],
    quant_node: fx.Node,
    op_node: fx.Node,
) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
    """
    Build the args for the quantized softmax replacement op.

    A dummy all-zero int32 mask node is inserted. Its shape is the shape of
    ``quant_node.args[0]`` with the last dimension floor-divided by 16. The
    mask's ``val`` metadata is computed under the input's fake mode.

    Positional args: quantized input, mask, ``op_node.args[1]`` (presumably the
    softmax dim, assuming aten softmax arg order), input scale/zero point and
    output scale/zero point. No kwargs are produced.

    NOTE(review): if ``get_shape`` yields an empty/None shape, indexing the
    last dim raises IndexError.
    """
    out_shape = get_shape(graph_module, cast(fx.Node, quant_node.args[0]))
    mask_shape = [] if not out_shape else list(out_shape)
    mask_shape[-1] //= 16
    mask_node = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        (mask_shape, 0.0),
        {"dtype": torch.int32},
    )

    assert (
        len(inputs_inputs) == 1
    ), f"Expected 1 input for softmax, got {len(inputs_inputs)}"
    source = inputs_inputs[0]
    assert "val" in source.meta, "Missing val metadata on input node"
    fake_mode = source.meta["val"].fake_mode
    assert fake_mode is not None, "fake_mode is None on input node"
    with fake_mode:
        mask_node.meta["val"] = torch.full(mask_shape, 0.0, dtype=torch.int32)
    copy_node_metadata(mask_node, source)

    in_dequant = dequants_inputs[0]
    args = (
        source,
        mask_node,
        op_node.args[1],
        in_dequant.args[1],
        in_dequant.args[2],
        quant_node.args[1],
        quant_node.args[2],
    )
    return args, {}
def get_args_and_kwargs_mixed_w8a32_conv(
    graph_module: GraphModule,
    other_inputs: List[fx.Node],
    weights_inputs: List[fx.Node],
    dequants_weights: List[fx.Node],
    bias_inputs: List[fx.Node],
    dequants_biases: List[fx.Node],
    op_node: fx.Node,
) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
    """
    Build the args for the mixed W8A32 conv replacement op.

    Only the default conv configuration is accepted. If present, stride
    (arg 3) must be ``[1]``, padding (arg 4) ``[0]``, dilation (arg 5)
    ``[1]`` and groups (arg 6) ``1``. Exactly one weight and one bias
    dequantize node are expected.

    The activation is permuted with ``[0, 2, 1]`` and the weight with
    ``[2, 0, 1]`` via inserted ``aten.permute`` nodes. Their ``val`` metadata
    is computed under the source's fake mode. The args are the permuted
    activation, permuted weight, weight scale, bias and bias scale.

    Raises:
        AssertionError: on an unsupported conv configuration, an unexpected
            number of dequantize nodes, or missing fake metadata.
    """
    conv_args = op_node.args
    # Stride, padding, dilation, groups not supported yet
    for position, expected in ((3, [1]), (4, [0]), (5, [1]), (6, 1)):
        if len(conv_args) > position:
            assert conv_args[position] == expected
    assert len(dequants_weights) == 1
    assert len(dequants_biases) == 1
    weight_scale = dequants_weights[0].args[1]
    bias_scale = dequants_biases[0].args[1]

    def _permute_with_meta(source: fx.Node, dims: List[int], what: str) -> fx.Node:
        # Insert a permute node and give it fake-tensor metadata derived from
        # the source node.
        permuted = graph_module.graph.call_function(
            torch.ops.aten.permute.default,
            (source, dims),
        )
        assert "val" in source.meta, f"Missing val metadata on {what} node"
        source_val = source.meta["val"]
        assert (
            source_val.fake_mode is not None
        ), f"fake_mode is None on {what} node"
        with source_val.fake_mode:
            permuted.meta["val"] = torch.ops.aten.permute.default(source_val, dims)
        copy_node_metadata(permuted, source)
        return permuted

    transposed_inputs = _permute_with_meta(other_inputs[0], [0, 2, 1], "input")  # NCL -> NLC
    transposed_weights = _permute_with_meta(weights_inputs[0], [2, 0, 1], "weight")  # NCL -> LNC

    args = (
        transposed_inputs,
        transposed_weights,
        weight_scale,
        bias_inputs[0],
        bias_scale,
    )
    return args, {}
def get_args_and_kwargs_max_pool2d(
    inputs_inputs: List[fx.Node],
    op_node: fx.Node,
) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
    """
    Returns the args and kwargs for the max_pool2d replacement op.

    Max pooling is order-preserving, so the max is taken directly on the
    quantized values without any requantization. NOTE(review): this assumes
    the output is quantized with the same scale/zero point as the input (no
    quantization params are passed to the op); confirm the quantizer enforces
    that.

    Pooling parameters are read from ``op_node`` with ``get_arg``. Missing or
    empty values fall back to kernel ``[1, 1]``, stride = kernel size, padding
    ``[0, 0]``, dilation ``[1, 1]`` and ``ceil_mode=False``.
    """
    kernel_size = get_arg(op_node, "kernel_size", Optional[list[int]]) or [1, 1]
    pooling_kwargs = {
        "kernel_size": kernel_size,
        "stride": get_arg(op_node, "stride", Optional[list[int]]) or kernel_size,
        "padding": get_arg(op_node, "padding", Optional[list[int]]) or [0, 0],
        "dilation": get_arg(op_node, "dilation", Optional[list[int]]) or [1, 1],
        "ceil_mode": get_arg(op_node, "ceil_mode", Optional[bool]) or False,
    }
    return (inputs_inputs[0],), pooling_kwargs
def get_args_and_kwargs_mixed_w8a32_gru(
    graph_module: GraphModule,
    other_inputs: List[fx.Node],
    weights_inputs: List[fx.Node],
    dequants_weights: List[fx.Node],
    bias_inputs: List[fx.Node],
    dequants_biases: List[fx.Node],
    op_node: fx.Node,
) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
    """
    Build the args for the mixed W8A32 GRU replacement op.

    Exactly two weight and two bias dequantize nodes are expected. Index 0/1
    correspond to the ``w_i``/``w_h`` and ``b_i``/``b_h`` pairs, presumably
    input-to-hidden and hidden-to-hidden. The args are the first two
    ``other_inputs``, then each weight and then each bias, every tensor
    followed by the scale (``args[1]``) of its dequantize node.
    ``graph_module`` and ``op_node`` are unused and no kwargs are produced.
    """
    assert len(dequants_weights) == 2
    assert len(dequants_biases) == 2
    quantized_params: List[ArgsType] = []
    for tensors, dequants in (
        (weights_inputs, dequants_weights),
        (bias_inputs, dequants_biases),
    ):
        for i in range(2):
            quantized_params.append(tensors[i])
            quantized_params.append(dequants[i].args[1])
    args = (other_inputs[0], other_inputs[1], *quantized_params)
    return args, {}
class QuantFusion(ExportPass):
    """
    Replace matched dequantize -> op -> quantize partitions with fused
    quantized replacement ops.

    For each pattern in ``self.patterns``, partitions are found with
    ``find_sequential_partitions_aten`` and anchored with
    ``pattern.get_anchors``. The ``dequantize_per_tensor`` nodes feeding the
    anchored inputs/weights/biases are collected. A pattern-specific
    ``get_args_and_kwargs_*`` helper then builds the replacement arguments,
    and a ``pattern.replacement_op()`` node is inserted after the matched op.
    Users of the output quantize node (or of the op itself when the pattern
    has no output anchor) are redirected to the fused node. Finally the graph
    is legalized, dead code is eliminated, and the output node is moved back
    to the end if needed.
    """

    # Meta key written by ``mark_fused`` and read by ``is_fused``. A fixed
    # literal (not ``cls.__qualname__``) so that subclasses read the same key
    # that ``mark_fused`` writes.
    _FUSED_META_KEY: str = "QuantFusion"

    # pyre-ignore[2]: Parameter `patterns` has no type specified
    def __init__(self, patterns) -> None:
        super().__init__()
        # pyre-ignore[4]: Parameter `patterns` of class `QuantFusion` has no type specified
        self.patterns = patterns

    @staticmethod
    def _anchor_arg(node: fx.Node, idx: ArgsType) -> ArgsType:
        """
        Return the argument an anchor points at.

        ``idx`` is either a positional index into ``node.args`` or an
        ``(outer, inner)`` pair addressing one element of a list argument
        (presumably for list args such as ``cat``'s tensor list).
        """
        if isinstance(idx, int):
            return node.args[idx]
        return node.args[idx[0]][idx[1]]

    @staticmethod
    def _is_dequant(arg: ArgsType) -> bool:
        """Whether ``arg`` is a ``dequantize_per_tensor`` graph node."""
        return (
            isinstance(arg, fx.Node)
            and arg.target
            == torch.ops.quantized_decomposed.dequantize_per_tensor.default
        )

    @classmethod
    def _collect_dequants(cls, anchor_entries: ArgsType) -> List[fx.Node]:
        """
        Return the dequantize nodes referenced by ``anchor_entries``, in order.

        Each entry is ``(node, idx, *extra)``; extra fields are ignored. The
        same unpacking is used for inputs, weights and biases. Anchored
        arguments that are not ``dequantize_per_tensor`` nodes are skipped.
        """
        dequants: List[fx.Node] = []
        for node, idx, *_spec in anchor_entries:
            arg = cls._anchor_arg(node, idx)
            if cls._is_dequant(arg):
                dequants.append(arg)
        return dequants

    @staticmethod
    def _find_quant_node(op_node: fx.Node) -> fx.Node:
        """
        Return the first user of ``op_node``, assumed to be its output
        quantize node.

        For ops that return tuples (e.g. max_pool2d_with_indices), the first
        user is a ``getitem``. It must select element 0, and the quantize
        node is then that getitem's first user.
        """
        assert len(op_node.users) > 0, "op_node has no users"
        quant_node = list(op_node.users.keys())[0]
        if quant_node.target is op_module.getitem:
            assert (
                len(quant_node.args) >= 2 and quant_node.args[1] == 0
            ), f"Expected getitem[0] for the values output, but got getitem[{quant_node.args[1] if len(quant_node.args) >= 2 else '?'}]"
            assert (
                len(list(quant_node.users.keys())) > 0
            ), "getitem node has no users"
            quant_node = list(quant_node.users.keys())[0]
        return quant_node

    def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
        for pattern in self.patterns:
            fused_partitions = find_sequential_partitions_aten(
                graph_module,
                pattern.partition_types(),
            )
            for fused_partition in fused_partitions:
                anchors, op_node = pattern.get_anchors(graph_module, fused_partition)
                if not anchors or anchors.empty:
                    continue
                # Skip partitions that overlap an already-fused partition.
                if any(self.is_fused(p.nodes) for p in fused_partition):
                    continue

                for p in fused_partition:
                    self.mark_fused(p.nodes)

                dequants_inputs = self._collect_dequants(anchors.inputs)
                dequants_weights = self._collect_dequants(anchors.weights)
                dequants_biases = self._collect_dequants(anchors.biases)

                # The quantized tensors that feed each dequantize node.
                inputs_inputs = [node.args[0] for node in dequants_inputs]
                weights_inputs = [node.args[0] for node in dequants_weights]
                bias_inputs = [node.args[0] for node in dequants_biases]
                other_inputs = [
                    self._anchor_arg(node, idx)
                    for node, idx, *_spec in anchors.others
                ]

                assert op_node is not None, "op_node is None"
                quant_node = self._find_quant_node(op_node)

                with graph_module.graph.inserting_after(op_node):
                    # Default: forward every collected input unchanged.
                    args = tuple(
                        inputs_inputs + weights_inputs + other_inputs + bias_inputs
                    )
                    kwargs = {}
                    if isinstance(pattern, AddPattern):
                        args, kwargs = get_args_and_kwargs_add(
                            graph_module,
                            inputs_inputs,
                            dequants_inputs,
                            quant_node,
                        )
                    elif isinstance(pattern, CatPattern):
                        # Skip fusion if inputs_inputs is empty to avoid creating cat([])
                        if not inputs_inputs:
                            continue
                        args, kwargs = get_args_and_kwargs_cat(
                            inputs_inputs, other_inputs, op_node
                        )
                    elif isinstance(pattern, ConvReluPatterns):
                        # For ConvReLU, we are fusing Conv+ReLU
                        # This means that the op we want to get
                        # the replacement args and kwargs for is the
                        # *conv* op, which is the anchor input, NOT
                        # the anchor output (which is the ReLU)
                        check_out_zero_point_is_min_range(
                            quant_node.args[2], quant_node.args[5]
                        )
                        anchor_input_node = anchors.inputs[0][0]
                        args, kwargs = get_args_and_kwargs_conv(
                            graph_module,
                            inputs_inputs,
                            dequants_inputs,
                            weights_inputs,
                            dequants_weights,
                            bias_inputs,
                            quant_node,
                            anchor_input_node,
                        )
                    elif isinstance(pattern, ConvPatterns):
                        args, kwargs = get_args_and_kwargs_conv(
                            graph_module,
                            inputs_inputs,
                            dequants_inputs,
                            weights_inputs,
                            dequants_weights,
                            bias_inputs,
                            quant_node,
                            op_node,
                        )
                    elif isinstance(pattern, LinearPattern):
                        args, kwargs = get_args_and_kwargs_linear(
                            graph_module,
                            inputs_inputs,
                            dequants_inputs,
                            weights_inputs,
                            dequants_weights,
                            bias_inputs,
                            quant_node,
                        )
                    elif isinstance(pattern, LayerNormPattern):
                        args, kwargs = get_args_and_kwargs_layer_norm(
                            graph_module,
                            inputs_inputs,
                            dequants_inputs,
                            other_inputs,
                            quant_node,
                        )
                    elif isinstance(pattern, (BmmPattern, MatmulPattern)):
                        args, kwargs = get_args_and_kwargs_matmul(
                            inputs_inputs,
                            dequants_inputs,
                            quant_node,
                        )
                    elif isinstance(pattern, AddmmPattern):
                        # Transpose the weight tensor
                        transposed_weights = graph_module.graph.call_function(
                            torch.ops.aten.transpose.int,
                            (weights_inputs[0], 0, 1),
                        )
                        assert (
                            "val" in weights_inputs[0].meta
                        ), "Missing val metadata on weight node"
                        original_val = weights_inputs[0].meta["val"]
                        assert (
                            original_val.fake_mode is not None
                        ), "fake_mode is None on weight node"
                        with original_val.fake_mode:
                            transposed_weights.meta["val"] = (
                                torch.ops.aten.transpose.int(original_val, 0, 1)
                            )
                        copy_node_metadata(transposed_weights, weights_inputs[0])
                        # Call linear with transposed weight
                        args, kwargs = get_args_and_kwargs_linear(
                            graph_module,
                            inputs_inputs,
                            dequants_inputs,
                            [transposed_weights],
                            dequants_weights,
                            bias_inputs,
                            quant_node,
                        )
                    elif isinstance(pattern, ReluPatterns):
                        args, kwargs = get_args_and_kwargs_relu(
                            graph_module,
                            inputs_inputs,
                            dequants_inputs,
                            quant_node,
                        )
                    elif isinstance(pattern, SoftmaxPattern):
                        args, kwargs = get_args_and_kwargs_softmax(
                            graph_module,
                            inputs_inputs,
                            dequants_inputs,
                            quant_node,
                            op_node,
                        )
                    elif isinstance(pattern, MixedW8A32LinearPattern):
                        args, kwargs = get_args_and_kwargs_mixed_w8a32_linear(
                            graph_module,
                            other_inputs,
                            weights_inputs,
                            dequants_weights,
                            bias_inputs,
                            dequants_biases,
                        )
                    elif isinstance(pattern, MixedW8A32ConvPattern):
                        args, kwargs = get_args_and_kwargs_mixed_w8a32_conv(
                            graph_module,
                            other_inputs,
                            weights_inputs,
                            dequants_weights,
                            bias_inputs,
                            dequants_biases,
                            op_node,
                        )
                    elif isinstance(pattern, MixedW8A32GruPattern):
                        args, kwargs = get_args_and_kwargs_mixed_w8a32_gru(
                            graph_module,
                            other_inputs,
                            weights_inputs,
                            dequants_weights,
                            bias_inputs,
                            dequants_biases,
                            op_node,
                        )
                    elif isinstance(
                        pattern, (MaxPool2dPattern, MaxPool2dWithoutIndicesPattern)
                    ):
                        args, kwargs = get_args_and_kwargs_max_pool2d(
                            inputs_inputs,
                            op_node,
                        )

                    fused = graph_module.graph.call_function(
                        pattern.replacement_op(),
                        args,
                        kwargs,
                    )

                    # Patterns with an output anchor replace the quantize node;
                    # otherwise the op node itself is replaced.
                    if len(anchors.output) > 0:
                        fused.meta = quant_node.meta
                        quant_node.replace_all_uses_with(fused)
                    else:
                        fused.meta = op_node.meta
                        op_node.replace_all_uses_with(fused)
                    if op_node.op == "output":
                        _ = graph_module.graph.output((fused,))

        legalize_graph(graph_module)
        graph_module.graph.eliminate_dead_code()

        # The graph's output node must be last. If other nodes ended up after
        # it, re-create it at the end with the same argument and metadata.
        nodes_list = list(graph_module.graph.nodes)
        if len(nodes_list) > 0 and nodes_list[-1].op != "output":
            output_nodes = [n for n in nodes_list if n.op == "output"]
            output_arg = output_nodes[0].args[0]
            original_meta = output_nodes[0].meta.copy()
            for out_node in output_nodes:
                graph_module.graph.erase_node(out_node)
            new_output_node = graph_module.graph.output(output_arg)
            new_output_node.meta.update(original_meta)

        graph_module.recompile()
        return PassResult(graph_module, True)

    @classmethod
    # pyre-ignore[2]: Parameter `nodes` has no type specified
    def is_fused(cls, nodes) -> bool:
        """Whether any of ``nodes`` was marked by ``mark_fused``."""
        return any(cls._FUSED_META_KEY in n.meta for n in nodes)

    @classmethod
    # pyre-ignore[2]: Parameter `nodes` has no type specified
    def mark_fused(cls, nodes) -> None:
        """Mark every node in ``nodes`` as belonging to a fused partition."""
        for n in nodes:
            n.meta[cls._FUSED_META_KEY] = True