Skip to content

A couple of ort fusion fixes #2136

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 26, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions onnxscript/rewriter/ort_fusions/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
# TODO: There are some potential redundancies below. Can be targeted for optimization
# once we have robust fusion.
def _pre_optimize(model: ir.Model) -> ir.Model:
optimize(model)
# TODO: Do we need this dependence on ONNX's partial-data-propagation? There are some
# extra shape-propagation and partial-data-propagation rules in ONNX that are not yet
# incorporated in our optimizer.
Expand All @@ -45,7 +44,7 @@ def _pre_optimize(model: ir.Model) -> ir.Model:
return model


def fuse_xformers(model: ir.Model) -> None:
def fuse_xformers(model: ir.Model) -> ir.Model:
model = _pre_optimize(model)
fuse_rms_normalization(model)
fuse_normalization(model)
Expand All @@ -58,8 +57,26 @@ def fuse_xformers(model: ir.Model) -> None:
# Finally: inline any intermediate fusion functions introduced that were not
# consumed by other fusions, and eliminate any remaining unused nodes.
optimize(model)
return model


def optimize_for_ort(model: ir.Model, config_name: str | None = None) -> ir.Model:
"""
Optimize the model for ORT backend.

TODO: config_name is not used yet. It should be used to select the appropriate
optimization configuration (for an EP). Currently, a default implementation is used.

Args:
model: The model to optimize.
config_name: The name of the configuration to use for optimization.
Typically it identifies the Execution Provider (EP) to optimize for.
If None, the default configuration will be used.

Returns:
The optimized model.
"""

def optimize_for_ort(model: ir.Model) -> None:
fuse_xformers(model)
rewrite(model, ORT_PATTERN_REWRITE_RULES)
return model
5 changes: 3 additions & 2 deletions onnxscript/rewriter/ort_fusions/_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ def ort_run(model_name: str, model, inputs):
def assert_allclose(outputs, expected_outputs, rtol=1e-4, atol=1e-4):
for i, (baseline_output, optimized_output) in enumerate(zip(expected_outputs, outputs)):
try:
np.testing.assert_equal(baseline_output.shape, optimized_output.shape)
np.testing.assert_allclose(baseline_output, optimized_output, rtol=rtol, atol=atol)
np.testing.assert_allclose(
baseline_output, optimized_output, rtol=rtol, atol=atol, strict=True
)
except AssertionError as e:
print(f"Failed for output {i} with rtol={rtol} and atol={atol}\n{e}")
raise
2 changes: 1 addition & 1 deletion onnxscript/rewriter/ort_fusions/fuse_xformers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_fuse_xformers(self):
onnxscript.optimizer.optimize(model)
inputs = test.get_ort_inputs()
original_outputs = ort_run("original", model, inputs)
fuse_xformers(model)
model = fuse_xformers(model)
new_outputs = ort_run("optimized", model, inputs)
assert_allclose(new_outputs, original_outputs)

Expand Down
2 changes: 1 addition & 1 deletion onnxscript/rewriter/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -1566,7 +1566,7 @@ def copy_value(value: ir.Value | None) -> ir.Value | None:
return None
if value not in value_map:
const_value = value.const_value
if isinstance(const_value, (ir.Tensor, ir.TensorProtoTensor)):
if const_value is not None:
# create a Constant node to represent the value
value_attr = ir.AttrTensor("value", const_value)
const_node = ir.Node("", "Constant", [], [value_attr])
Expand Down
Loading