Skip to content

Commit 11075ee

Browse files
justinchubybmehta001
authored andcommitted
Fix pytest for TestCosSinCacheTransform (microsoft#2358)
With the latest version of pytest we get `onnxscript/rewriter/ort_fusions/cos_sin_cache_test.py::test_case_1 - Failed: Expected None, but test returned <onnxscript.rewriter.ort_fusions.models._rotary_embedding_models._TestCase1 object at 0x117c920d0>. Did you mean to use `assert` instead of `return`?` This is because the imported functions `test_case_1` and `test_case_2` are not really test cases but were treated as such by pytest. This PR hides them from the test module so they are not triggered.
1 parent 3fd79be commit 11075ee

File tree

2 files changed

+12
-18
lines changed

2 files changed

+12
-18
lines changed

onnxscript/rewriter/ort_fusions/cos_sin_cache_test.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,7 @@
99
import onnxscript.optimizer
1010
from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose, ort_run
1111
from onnxscript.rewriter.ort_fusions.cos_sin_cache import fuse_cos_sin_cache
12-
from onnxscript.rewriter.ort_fusions.models._rotary_embedding_models import (
13-
partial_rotary_test_case,
14-
test_case_1,
15-
test_case_2,
16-
)
17-
from onnxscript.rewriter.ort_fusions.models._smollm_1 import smollm_test_1
12+
from onnxscript.rewriter.ort_fusions.models import _rotary_embedding_models, _smollm_1
1813
from onnxscript.rewriter.ort_fusions.rotary_embedding import (
1914
fuse_partial_rotary_embedding,
2015
fuse_rotary_embedding,
@@ -26,19 +21,19 @@ class TestCosSinCacheTransform(unittest.TestCase):
2621
[
2722
(
2823
"smollm_test_1",
29-
smollm_test_1,
24+
_smollm_1.smollm_test_1,
3025
),
3126
(
3227
"test_case_1",
33-
test_case_1,
28+
_rotary_embedding_models.test_case_1,
3429
),
3530
(
3631
"test_case_2",
37-
test_case_2,
32+
_rotary_embedding_models.test_case_2,
3833
),
3934
(
4035
"partial_rotary_test_case",
41-
partial_rotary_test_case,
36+
_rotary_embedding_models.partial_rotary_test_case,
4237
),
4338
]
4439
)
@@ -56,7 +51,7 @@ def test_cos_sin_fusion(self, name, test_data_constructor):
5651
assert_allclose(new_outputs, original_outputs)
5752

5853
def test_partial_rotary_fusion(self):
59-
test = partial_rotary_test_case()
54+
test = _rotary_embedding_models.partial_rotary_test_case()
6055
model = test.get_onnx_model()
6156
onnxscript.optimizer.optimize(model)
6257
inputs = test.get_ort_inputs()

onnxscript/rewriter/ort_fusions/rotary_embedding_test.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,29 +7,28 @@
77
from parameterized import parameterized
88

99
import onnxscript.optimizer
10-
from onnxscript.rewriter.ort_fusions.models._rotary_embedding_models import test_case_1
11-
from onnxscript.rewriter.ort_fusions.models._smollm_1 import smollm_test_1
12-
from onnxscript.rewriter.ort_fusions.rotary_embedding import fuse_rotary_embedding
10+
from onnxscript.rewriter.ort_fusions import rotary_embedding
11+
from onnxscript.rewriter.ort_fusions.models import _rotary_embedding_models, _smollm_1
1312

1413

1514
class TestRotaryEmbedding(unittest.TestCase):
1615
@parameterized.expand(
1716
[
1817
(
1918
"test_case_1",
20-
test_case_1,
19+
_rotary_embedding_models.test_case_1,
2120
),
2221
(
2322
"smollm_test_1",
24-
smollm_test_1,
23+
_smollm_1.smollm_test_1,
2524
),
2625
]
2726
)
28-
def test_rotary_embedding_fusion(self, name, test_data_constructor):
27+
def test_rotary_embedding_fusion(self, _: str, test_data_constructor):
2928
test = test_data_constructor()
3029
model = test.get_onnx_model()
3130
onnxscript.optimizer.optimize(model)
32-
fuse_rotary_embedding(model)
31+
rotary_embedding.fuse_rotary_embedding(model)
3332
op_types = [n.op_type for n in model.graph]
3433
self.assertIn("RotaryEmbedding", op_types)
3534

0 commit comments

Comments
 (0)