Skip to content

Commit 96aeb57

Browse files
Migrate executorch/ tests from exir.capture to torch.export + to_edge
Differential Revision: D95605454
Pull Request resolved: #18111
1 parent 6381e58 commit 96aeb57

File tree

9 files changed

+120
-1645
lines changed

9 files changed

+120
-1645
lines changed

exir/backend/test/BUCK

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -158,38 +158,6 @@ fbcode_target(_kind = runtime.python_library,
158158
],
159159
)
160160

161-
fbcode_target(_kind = runtime.python_test,
162-
name = "test_backends",
163-
srcs = [
164-
"test_backends.py",
165-
],
166-
preload_deps = [
167-
"//executorch/configurations:optimized_native_cpu_ops",
168-
"//executorch/kernels/quantized:custom_ops_generated_lib",
169-
"//executorch/runtime/executor/test:test_backend_compiler_lib",
170-
],
171-
deps = [
172-
":backend_with_compiler_demo",
173-
":hta_partitioner_demo",
174-
":op_partitioner_demo",
175-
":demo_backend",
176-
"//caffe2:torch",
177-
"//caffe2/functorch:functorch_src",
178-
"//executorch/exir:delegate",
179-
"//executorch/exir:graph_module",
180-
"//executorch/exir:lib",
181-
"//executorch/exir:lowered_backend_module",
182-
"//executorch/exir:print_program",
183-
"//executorch/exir:schema",
184-
"//executorch/exir/backend:backend_api",
185-
"//executorch/exir/backend:compile_spec_schema",
186-
"//executorch/exir/backend:partitioner",
187-
"//executorch/exir/dialects:lib",
188-
"//executorch/extension/pybindings:portable_lib", # @manual
189-
"//executorch/extension/pytree:pylib",
190-
],
191-
)
192-
193161
fbcode_target(_kind = runtime.python_test,
194162
name = "test_to_backend_multi_method",
195163
srcs = [

exir/backend/test/hta_partitioner_demo.py

Lines changed: 22 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import torch
1111
from executorch import exir
12+
from executorch.exir import to_edge
1213
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
1314
generate_pattern_op_partitions,
1415
)
@@ -20,7 +21,7 @@
2021
)
2122
from executorch.exir.backend.test.demo_backend import DemoBackend
2223
from executorch.exir.backend.utils import tag_constant_data
23-
from torch.export import ExportedProgram
24+
from torch.export import export, ExportedProgram
2425
from torch.fx.passes.infra.partitioner import Partition
2526

2627

@@ -63,56 +64,30 @@ def forward(self, x_raw, h, c):
6364
input_h = torch.ones([1, 32])
6465
input_c = torch.ones([1, 32])
6566

66-
pattern_lstm_conv_lifted = (
67-
exir.capture(
68-
LSTMConvPattern(),
69-
(input_x, input_h, input_c),
70-
exir.CaptureConfig(enable_aot=True),
71-
)
72-
.to_edge(
73-
# torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical.
74-
exir.EdgeCompileConfig(_check_ir_validity=False)
75-
)
76-
.exported_program.graph_module
77-
)
7867
pattern_lstm_conv = (
79-
exir.capture(
80-
LSTMConvPattern(),
81-
(input_x, input_h, input_c),
82-
exir.CaptureConfig(),
83-
)
84-
.to_edge(
68+
to_edge(
69+
export(LSTMConvPattern(), (input_x, input_h, input_c), strict=True),
8570
# torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical.
86-
exir.EdgeCompileConfig(_check_ir_validity=False)
71+
compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
8772
)
88-
.exported_program.graph_module
73+
.exported_program()
74+
.graph_module
8975
)
9076

91-
def sub(x, y):
92-
return torch.sub(x, y)
77+
class SubModule(torch.nn.Module):
78+
def forward(self, x, y):
79+
return torch.sub(x, y)
9380

94-
pattern_sub_lifted = (
95-
exir.capture(
96-
sub,
97-
(input_x, input_h),
98-
exir.CaptureConfig(enable_aot=True, _unlift=False),
99-
)
100-
.to_edge(exir.EdgeCompileConfig(_use_edge_ops=True))
101-
.exported_program.graph_module
102-
)
10381
pattern_sub = (
104-
exir.capture(
105-
sub,
106-
(input_x, input_h),
107-
exir.CaptureConfig(),
82+
to_edge(
83+
export(SubModule(), (input_x, input_h), strict=True),
84+
compile_config=exir.EdgeCompileConfig(_use_edge_ops=True),
10885
)
109-
.to_edge()
110-
.exported_program.graph_module
86+
.exported_program()
87+
.graph_module
11188
)
11289
self.patterns = [
113-
pattern_lstm_conv_lifted.graph,
11490
pattern_lstm_conv.graph,
115-
pattern_sub_lifted.graph,
11691
pattern_sub.graph,
11792
]
11893

@@ -239,33 +214,17 @@ def forward(self, x_raw, h, c):
239214
input_h = torch.ones([1, 32])
240215
input_c = torch.ones([1, 32])
241216

242-
pattern_lstm_conv_lifted = (
243-
exir.capture(
244-
LSTMConvPattern(),
245-
(input_x, input_h, input_c),
246-
exir.CaptureConfig(enable_aot=True),
247-
)
248-
.to_edge(
249-
# torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical.
250-
exir.EdgeCompileConfig(_check_ir_validity=False)
251-
)
252-
.exported_program.graph_module
253-
)
254-
pattern_lstm_conv_unlifted = (
255-
exir.capture(
256-
LSTMConvPattern(),
257-
(input_x, input_h, input_c),
258-
exir.CaptureConfig(),
259-
)
260-
.to_edge(
217+
pattern_lstm_conv = (
218+
to_edge(
219+
export(LSTMConvPattern(), (input_x, input_h, input_c), strict=True),
261220
# torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical.
262-
exir.EdgeCompileConfig(_check_ir_validity=False)
221+
compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
263222
)
264-
.exported_program.graph_module
223+
.exported_program()
224+
.graph_module
265225
)
266226
self.patterns = [
267-
pattern_lstm_conv_lifted.graph,
268-
pattern_lstm_conv_unlifted.graph,
227+
pattern_lstm_conv.graph,
269228
]
270229
# Only (lstm + conv) pattern is lowerable
271230

0 commit comments

Comments (0)