Skip to content

[IR] Allow pass result as pass input #2220

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion onnxscript/ir/passes/_pass_infra.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,11 @@ def destructive(self) -> bool:
"""
return not self.in_place and self.changes_input

def __call__(self, model: ir.Model) -> PassResult:
def __call__(self, model_or_result: ir.Model | PassResult, /) -> PassResult:
if isinstance(model_or_result, PassResult):
model = model_or_result.model
else:
model = model_or_result
# Check preconditions
try:
self.requires(model)
Expand Down
39 changes: 39 additions & 0 deletions onnxscript/ir/passes/_pass_infra_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from __future__ import annotations

import unittest

from onnxscript import ir
from onnxscript.ir.passes import _pass_infra


class PassBaseTest(unittest.TestCase):
def test_pass_results_can_be_used_as_pass_input(self):
class TestPass(_pass_infra.PassBase):
@property
def in_place(self) -> bool:
return True

@property
def changes_input(self) -> bool:
return False

def call(self, model: ir.Model) -> _pass_infra.PassResult:
# This is a no-op pass
return _pass_infra.PassResult(model=model, modified=False)

pass_ = TestPass()
model = ir.Model(graph=ir.Graph([], [], nodes=[]), ir_version=10)
result = pass_(model)
self.assertIsInstance(result, _pass_infra.PassResult)
# pass can take the result of another pass as input
result_1 = pass_(result)
# It can also take the model as input
result_2 = pass_(result.model)
self.assertIs(result_1.model, result_2.model)


if __name__ == "__main__":
unittest.main()
Loading