File tree 2 files changed +0
-24
lines changed
2 files changed +0
-24
lines changed Original file line number Diff line number Diff line change 18
18
)
19
19
from executorch .backends .cadence .aot .quantizer .fusion_pass import QuantFusion
20
20
from executorch .backends .cadence .aot .quantizer .quantizer import CadenceQuantizer
21
-
22
- from executorch .backends .cadence .aot .replace_ops import ReplaceSafeSoftmaxWithSoftmax
23
21
from executorch .backends .cadence .aot .utils import (
24
22
get_default_memory_config ,
25
23
MemoryConfig ,
26
- model_gm_has_SDPA ,
27
24
model_is_quantized ,
28
25
)
29
- from executorch .backends .transforms .decompose_sdpa import (
30
- DecomposeScaledDotProductAttention ,
31
- )
32
26
from executorch .devtools import generate_etrecord
33
27
from executorch .exir import (
34
28
EdgeCompileConfig ,
@@ -91,16 +85,6 @@ def convert_pt2(
91
85
.module ()
92
86
)
93
87
94
- if model_gm_has_SDPA (model_gm ):
95
- # Decompose SDPA
96
- DecomposeScaledDotProductAttention (False )(model_gm )
97
-
98
- # Swap _safe_softmax with _softmax (see https://github.com/pytorch/pytorch/pull/133882
99
- # for details).
100
- result = ReplaceSafeSoftmaxWithSoftmax ()(model_gm )
101
- assert result is not None
102
- model_gm = result .graph_module
103
-
104
88
# Prepare
105
89
prepared_model = prepare_pt2e (model_gm , quantizer )
106
90
Original file line number Diff line number Diff line change @@ -235,14 +235,6 @@ def print_ops_info(
235
235
)
236
236
237
237
238
- def model_gm_has_SDPA (model_gm : torch .fx .GraphModule ) -> bool :
239
- for node in model_gm .graph .nodes :
240
- if node .op == "call_function" :
241
- if node .target == torch .ops .aten .scaled_dot_product_attention .default :
242
- return True
243
- return False
244
-
245
-
246
238
def save_pte_program (
247
239
prog : ExecutorchProgramManager , model_name : str , output_dir : str = ""
248
240
) -> None :
You can’t perform that action at this time.
0 commit comments