Skip to content

Commit 3aefd56

Browse files
guangy10facebook-github-bot
authored andcommitted
Fix export tests (#2693)
Summary: Pull Request resolved: #2693 - Fix broken/flaky tests Reviewed By: mergennachin, kirklandsign Differential Revision: D55382428 fbshipit-source-id: f64d0fda063b42b86f2e77431790708d3cc2c512
1 parent 253f2fa commit 3aefd56

File tree

1 file changed

+60
-57
lines changed

1 file changed

+60
-57
lines changed

examples/portable/test/test_export.py

Lines changed: 60 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -6,31 +6,27 @@
66

77
import unittest
88

9-
from typing import Any, Callable
10-
119
import torch
1210
from executorch.examples.models import MODEL_NAME_TO_MODEL
1311
from executorch.examples.models.model_factory import EagerModelFactory
1412

15-
from executorch.examples.portable.utils import export_to_edge
16-
1713
from executorch.extension.pybindings.portable_lib import ( # @manual
1814
_load_for_executorch_from_buffer,
1915
)
2016

17+
from ..utils import export_to_edge
18+
2119

2220
class ExportTest(unittest.TestCase):
23-
def _assert_eager_lowered_same_result(
21+
def collect_executorch_and_eager_outputs(
2422
self,
2523
eager_model: torch.nn.Module,
2624
example_inputs,
27-
validation_fn: Callable[[Any, Any], bool],
2825
):
2926
"""
30-
Asserts that the given model has the same result as the eager mode
31-
lowered model, with example_inputs, validated by validation_fn, which
32-
takes the eager mode output and ET output, and returns True if they
33-
match.
27+
Compares the output of the given eager mode PyTorch model with the output
28+
of the equivalent executorch model, both provided with example inputs.
29+
Returns a tuple containing the outputs of the eager mode model and the executorch mode model.
3430
"""
3531
eager_model = eager_model.eval()
3632
model = torch._export.capture_pre_autograd_graph(eager_model, example_inputs)
@@ -45,100 +41,107 @@ def _assert_eager_lowered_same_result(
4541
with torch.no_grad():
4642
executorch_output = pte_model.run_method("forward", example_inputs)
4743

48-
self.assertTrue(validation_fn(eager_output, executorch_output))
44+
return (eager_output, executorch_output)
4945

50-
@staticmethod
51-
def validate_tensor_allclose(eager_output, executorch_output, rtol=1e-5, atol=1e-5):
52-
result = torch.allclose(
53-
eager_output,
54-
executorch_output[0],
55-
rtol=rtol,
56-
atol=atol,
57-
)
58-
if not result:
59-
print(f"eager output: {eager_output}")
60-
print(f"executorch output: {executorch_output}")
61-
return result
46+
def validate_tensor_allclose(
47+
self, eager_output, executorch_output, rtol=1e-5, atol=1e-5
48+
):
49+
self.assertTrue(
50+
isinstance(eager_output, type(executorch_output)),
51+
f"Outputs are not of the same type: eager type: {type(eager_output)}, executorch type: {type(executorch_output)}",
52+
)
53+
self.assertTrue(
54+
len(eager_output) == len(executorch_output),
55+
f"len(eager_output)={len(eager_output)}, len(executorch_output)={len(executorch_output)}",
56+
)
57+
result = True
58+
for i in range(len(eager_output)):
59+
result = torch.allclose(
60+
eager_output[i],
61+
executorch_output[i],
62+
rtol=rtol,
63+
atol=atol,
64+
)
65+
if not result:
66+
print(f"eager output[{i}]: {eager_output[i]}")
67+
print(f"executorch output[{i}]: {executorch_output[i]}")
68+
break
69+
return self.assertTrue(result)
6270

6371
def test_mv3_export_to_executorch(self):
6472
eager_model, example_inputs, _ = EagerModelFactory.create_model(
6573
*MODEL_NAME_TO_MODEL["mv3"]
6674
)
67-
eager_model = eager_model.eval()
68-
75+
eager_output, executorch_output = self.collect_executorch_and_eager_outputs(
76+
eager_model, example_inputs
77+
)
6978
# TODO(T166083470): Fix accuracy issue
70-
self._assert_eager_lowered_same_result(
71-
eager_model,
72-
example_inputs,
73-
lambda x, y: self.validate_tensor_allclose(x, y, rtol=1e-3, atol=1e-5),
79+
self.validate_tensor_allclose(
80+
eager_output, executorch_output[0], rtol=1e-3, atol=1e-5
7481
)
7582

7683
def test_mv2_export_to_executorch(self):
7784
eager_model, example_inputs, _ = EagerModelFactory.create_model(
7885
*MODEL_NAME_TO_MODEL["mv2"]
7986
)
80-
eager_model = eager_model.eval()
81-
82-
self._assert_eager_lowered_same_result(
83-
eager_model, example_inputs, self.validate_tensor_allclose
87+
eager_output, executorch_output = self.collect_executorch_and_eager_outputs(
88+
eager_model, example_inputs
8489
)
90+
self.validate_tensor_allclose(eager_output, executorch_output[0])
8591

8692
def test_vit_export_to_executorch(self):
8793
eager_model, example_inputs, _ = EagerModelFactory.create_model(
8894
*MODEL_NAME_TO_MODEL["vit"]
8995
)
90-
eager_model = eager_model.eval()
91-
92-
self._assert_eager_lowered_same_result(
93-
eager_model, example_inputs, self.validate_tensor_allclose
96+
eager_output, executorch_output = self.collect_executorch_and_eager_outputs(
97+
eager_model, example_inputs
9498
)
99+
self.validate_tensor_allclose(eager_output, executorch_output[0])
95100

96101
def test_w2l_export_to_executorch(self):
97102
eager_model, example_inputs, _ = EagerModelFactory.create_model(
98103
*MODEL_NAME_TO_MODEL["w2l"]
99104
)
100-
eager_model = eager_model.eval()
101-
102-
self._assert_eager_lowered_same_result(
103-
eager_model, example_inputs, self.validate_tensor_allclose
105+
eager_output, executorch_output = self.collect_executorch_and_eager_outputs(
106+
eager_model, example_inputs
104107
)
108+
self.validate_tensor_allclose(eager_output, executorch_output[0])
105109

106110
def test_ic3_export_to_executorch(self):
107111
eager_model, example_inputs, _ = EagerModelFactory.create_model(
108112
*MODEL_NAME_TO_MODEL["ic3"]
109113
)
110-
eager_model = eager_model.eval()
111-
112-
self._assert_eager_lowered_same_result(
113-
eager_model, example_inputs, self.validate_tensor_allclose
114+
eager_output, executorch_output = self.collect_executorch_and_eager_outputs(
115+
eager_model, example_inputs
116+
)
117+
# TODO(T166083470): Fix accuracy issue
118+
self.validate_tensor_allclose(
119+
eager_output, executorch_output[0], rtol=1e-3, atol=1e-5
114120
)
115121

116122
def test_resnet18_export_to_executorch(self):
117123
eager_model, example_inputs, _ = EagerModelFactory.create_model(
118124
*MODEL_NAME_TO_MODEL["resnet18"]
119125
)
120-
eager_model = eager_model.eval()
121-
122-
self._assert_eager_lowered_same_result(
123-
eager_model, example_inputs, self.validate_tensor_allclose
126+
eager_output, executorch_output = self.collect_executorch_and_eager_outputs(
127+
eager_model, example_inputs
124128
)
129+
self.validate_tensor_allclose(eager_output, executorch_output[0])
125130

126131
def test_resnet50_export_to_executorch(self):
127132
eager_model, example_inputs, _ = EagerModelFactory.create_model(
128133
*MODEL_NAME_TO_MODEL["resnet50"]
129134
)
130-
eager_model = eager_model.eval()
131-
132-
self._assert_eager_lowered_same_result(
133-
eager_model, example_inputs, self.validate_tensor_allclose
135+
eager_output, executorch_output = self.collect_executorch_and_eager_outputs(
136+
eager_model, example_inputs
134137
)
138+
self.validate_tensor_allclose(eager_output, executorch_output[0])
135139

136140
def test_dl3_export_to_executorch(self):
137141
eager_model, example_inputs, _ = EagerModelFactory.create_model(
138142
*MODEL_NAME_TO_MODEL["dl3"]
139143
)
140-
eager_model = eager_model.eval()
141-
142-
self._assert_eager_lowered_same_result(
143-
eager_model, example_inputs, self.validate_tensor_allclose
144+
eager_output, executorch_output = self.collect_executorch_and_eager_outputs(
145+
eager_model, example_inputs
144146
)
147+
self.validate_tensor_allclose(list(eager_output.values()), executorch_output)

0 commit comments

Comments
 (0)