|
9 | 9 |
|
10 | 10 | import torch |
11 | 11 | from executorch import exir |
| 12 | +from executorch.exir import to_edge |
12 | 13 | from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import ( |
13 | 14 | generate_pattern_op_partitions, |
14 | 15 | ) |
|
20 | 21 | ) |
21 | 22 | from executorch.exir.backend.test.demo_backend import DemoBackend |
22 | 23 | from executorch.exir.backend.utils import tag_constant_data |
23 | | -from torch.export import ExportedProgram |
| 24 | +from torch.export import export, ExportedProgram |
24 | 25 | from torch.fx.passes.infra.partitioner import Partition |
25 | 26 |
|
26 | 27 |
|
@@ -63,56 +64,30 @@ def forward(self, x_raw, h, c): |
63 | 64 | input_h = torch.ones([1, 32]) |
64 | 65 | input_c = torch.ones([1, 32]) |
65 | 66 |
|
66 | | - pattern_lstm_conv_lifted = ( |
67 | | - exir.capture( |
68 | | - LSTMConvPattern(), |
69 | | - (input_x, input_h, input_c), |
70 | | - exir.CaptureConfig(enable_aot=True), |
71 | | - ) |
72 | | - .to_edge( |
73 | | - # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical. |
74 | | - exir.EdgeCompileConfig(_check_ir_validity=False) |
75 | | - ) |
76 | | - .exported_program.graph_module |
77 | | - ) |
78 | 67 | pattern_lstm_conv = ( |
79 | | - exir.capture( |
80 | | - LSTMConvPattern(), |
81 | | - (input_x, input_h, input_c), |
82 | | - exir.CaptureConfig(), |
83 | | - ) |
84 | | - .to_edge( |
| 68 | + to_edge( |
| 69 | + export(LSTMConvPattern(), (input_x, input_h, input_c), strict=True), |
85 | 70 | # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical. |
86 | | - exir.EdgeCompileConfig(_check_ir_validity=False) |
| 71 | + compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), |
87 | 72 | ) |
88 | | - .exported_program.graph_module |
| 73 | + .exported_program() |
| 74 | + .graph_module |
89 | 75 | ) |
90 | 76 |
|
91 | | - def sub(x, y): |
92 | | - return torch.sub(x, y) |
| 77 | + class SubModule(torch.nn.Module): |
| 78 | + def forward(self, x, y): |
| 79 | + return torch.sub(x, y) |
93 | 80 |
|
94 | | - pattern_sub_lifted = ( |
95 | | - exir.capture( |
96 | | - sub, |
97 | | - (input_x, input_h), |
98 | | - exir.CaptureConfig(enable_aot=True, _unlift=False), |
99 | | - ) |
100 | | - .to_edge(exir.EdgeCompileConfig(_use_edge_ops=True)) |
101 | | - .exported_program.graph_module |
102 | | - ) |
103 | 81 | pattern_sub = ( |
104 | | - exir.capture( |
105 | | - sub, |
106 | | - (input_x, input_h), |
107 | | - exir.CaptureConfig(), |
| 82 | + to_edge( |
| 83 | + export(SubModule(), (input_x, input_h), strict=True), |
| 84 | + compile_config=exir.EdgeCompileConfig(_use_edge_ops=True), |
108 | 85 | ) |
109 | | - .to_edge() |
110 | | - .exported_program.graph_module |
| 86 | + .exported_program() |
| 87 | + .graph_module |
111 | 88 | ) |
112 | 89 | self.patterns = [ |
113 | | - pattern_lstm_conv_lifted.graph, |
114 | 90 | pattern_lstm_conv.graph, |
115 | | - pattern_sub_lifted.graph, |
116 | 91 | pattern_sub.graph, |
117 | 92 | ] |
118 | 93 |
|
@@ -239,33 +214,17 @@ def forward(self, x_raw, h, c): |
239 | 214 | input_h = torch.ones([1, 32]) |
240 | 215 | input_c = torch.ones([1, 32]) |
241 | 216 |
|
242 | | - pattern_lstm_conv_lifted = ( |
243 | | - exir.capture( |
244 | | - LSTMConvPattern(), |
245 | | - (input_x, input_h, input_c), |
246 | | - exir.CaptureConfig(enable_aot=True), |
247 | | - ) |
248 | | - .to_edge( |
249 | | - # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical. |
250 | | - exir.EdgeCompileConfig(_check_ir_validity=False) |
251 | | - ) |
252 | | - .exported_program.graph_module |
253 | | - ) |
254 | | - pattern_lstm_conv_unlifted = ( |
255 | | - exir.capture( |
256 | | - LSTMConvPattern(), |
257 | | - (input_x, input_h, input_c), |
258 | | - exir.CaptureConfig(), |
259 | | - ) |
260 | | - .to_edge( |
| 217 | + pattern_lstm_conv = ( |
| 218 | + to_edge( |
| 219 | + export(LSTMConvPattern(), (input_x, input_h, input_c), strict=True), |
261 | 220 | # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical. |
262 | | - exir.EdgeCompileConfig(_check_ir_validity=False) |
| 221 | + compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), |
263 | 222 | ) |
264 | | - .exported_program.graph_module |
| 223 | + .exported_program() |
| 224 | + .graph_module |
265 | 225 | ) |
266 | 226 | self.patterns = [ |
267 | | - pattern_lstm_conv_lifted.graph, |
268 | | - pattern_lstm_conv_unlifted.graph, |
| 227 | + pattern_lstm_conv.graph, |
269 | 228 | ] |
270 | 229 | # Only (lstm + conv) pattern is lowerable |
271 | 230 |
|
|
0 commit comments