Skip to content

Improve redundant slice removal #2441

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions onnxscript/rewriter/collapse_slices.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,14 @@ def _identity_to_itself(op, data, **_):

def _potential_redundant_slice(op, data, starts, ends, axes, steps):
"""To identify a slice op"""
return op.Slice(data, starts, ends, axes, steps)
return op.Slice(data, starts, ends, axes, steps, _outputs=["slice_output"])


def _same_shape(op, data: ir.Value, slice_output: ir.Value, **_):
"""Check if the shape of the slice output is the same as the data."""
if data.shape is None or slice_output.shape is None:
return False
return data.shape == slice_output.shape


# Register the rewrite rules
Expand All @@ -83,5 +90,13 @@ def _potential_redundant_slice(op, data, starts, ends, axes, steps):
_check_if_redundant_slice,
)

# NOTE: The order of the rules is important. Larger pattern should be checked first.
rules = RewriteRuleSet([remove_redundant_slice])
remove_redundant_slice2 = RewriteRule(
_potential_redundant_slice,
_identity_to_itself,
_same_shape,
)

# NOTE: The second rule subsumes the first one. So, we may be able to remove the first one,
# provided shape-inference is run before the rewriter and computes the shape of the slice output.

rules = RewriteRuleSet([remove_redundant_slice, remove_redundant_slice2])
22 changes: 20 additions & 2 deletions onnxscript/rewriter/collapse_slices_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,11 @@ def test_slice_is_redundant_when_ends_reaches_int64_max(self):
(np.random.rand(512, 16, 112).astype(np.float32),),
)

def test_slice_pattern_is_not_matched_when_input_is_dynamic(self):
def test_slice_unequal_dynamic_shape(self):
model_proto = onnx.parser.parse_model(
f"""
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[L, M, N] data) => (float[L, M, N] output)
agraph (float[L, M, N] data) => (float[P, M, N] output)
{{
starts = Constant<value: tensor = int64[1] {{0}}>()
ends = Constant<value: tensor = int64[1] {{{9}}}>()
Expand All @@ -82,3 +82,21 @@ def test_slice_pattern_is_not_matched_when_input_is_dynamic(self):
model = ir.serde.deserialize_model(model_proto)
count = collapse_slices.rules.apply_to_model(model)
self.assertEqual(count, 0)

def test_slice_equal_dynamic_shape(self):
model_proto = onnx.parser.parse_model(
f"""
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[L, M, N] data) => (float[L, M, N] output)
{{
starts = Constant<value: tensor = int64[1] {{0}}>()
ends = Constant<value: tensor = int64[1] {{{9}}}>()
axes = Constant<value: tensor = int64[1] {{0}}>()
steps = Constant<value: tensor = int64[1] {{1}}>()
output = Slice (data, starts, ends, axes, steps)
}}
"""
)
model = ir.serde.deserialize_model(model_proto)
count = collapse_slices.rules.apply_to_model(model)
self.assertEqual(count, 1)
Loading