Skip to content

Commit c626f55

Browse files
committed
adjust the test
1 parent 832f3dc commit c626f55

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

src/onnx_ir/passes/common/common_subexpression_elimination_test.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -356,18 +356,18 @@ def test_the_constant_nodes_with_the_tensors_larger_than_size_limit_are_not_csed
356356
"""Test if the constant nodes with the tensors larger than size limit are not CSEd.
357357
358358
def f(x):
359-
a = x + [1, 2]
360-
b = x + [1, 2]
359+
a = x + [1, 2, 3, 4]
360+
b = x + [1, 2, 3, 4]
361361
return a + b
362362
"""
363363

364364
@script()
365-
def test_model(x: FLOAT[2, 2]) -> FLOAT[2, 2]:
366-
a = op.Add(x, op.Constant(value=np.array([1.0, 2.0], dtype=np.float32)))
367-
b = op.Add(x, op.Constant(value=np.array([1.0, 2.0], dtype=np.float32)))
365+
def test_model(x: FLOAT[4, 4]) -> FLOAT[4, 4]:
366+
a = op.Add(x, op.Constant(value=np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)))
367+
b = op.Add(x, op.Constant(value=np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)))
368368
return a + b
369369

370370
model_proto = test_model.to_model_proto()
371371
model = ir.serde.deserialize_model(model_proto)
372372
# Add and Constant nodes should not be CSEd
373-
self.check_graph(model, [np.random.rand(2, 2)], delta_nodes=[0], size_limit=4)
373+
self.check_graph(model, [np.random.rand(4, 4)], delta_nodes=[0], size_limit=3)

0 commit comments

Comments
 (0)