
Commit 1d79e4c

[Test] Hack to enable optimizer/rewriter integration into dynamo_export
- To test in torchbench.
- Somehow lintrunner changed unrelated files in this commit.

ghstack-source-id: 0180251
Pull Request resolved: #1334
1 parent 4393bc1 · commit 1d79e4c
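
For context, a rough usage sketch (not part of this commit) of how the patched export path might be exercised from a benchmark script. The toy model, input, and output path are placeholders, and the dynamo_export API surface is assumed from the PyTorch version this targets:

import torch

# Stand-in model; torchbench supplies real workloads instead.
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
x = torch.randn(2, 4)

# dynamo_export serializes the captured graph via TorchScriptGraph.to_model_proto(),
# which with this hack already runs optimizer.optimize() and ort_rewriter.rewrite()
# on the resulting ModelProto.
onnx_program = torch.onnx.dynamo_export(model, x)
onnx_program.save("model_optimized.onnx")  # placeholder output path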

File tree: 2 files changed (+15, -18 lines)

- onnxscript/function_libs/torch_lib/graph_building.py
- onnxscript/optimizer/constant_folding.py


onnxscript/function_libs/torch_lib/graph_building.py

Lines changed: 9 additions & 0 deletions
@@ -17,11 +17,15 @@
 from typing_extensions import TypeAlias

 import onnxscript
+
+# Q: Don't know why vscode only recognizes this import style instead of `from onnxscript import optimizer`.
+import onnxscript.optimizer as optimizer
 from onnxscript import evaluator
 from onnxscript import tensor as onnxscript_tensor
 from onnxscript._internal import param_manipulation, runtime_typing
 from onnxscript.function_libs.torch_lib import _flags
 from onnxscript.function_libs.torch_lib.ops import common as common_ops
+from onnxscript.rewriter import onnxruntime as ort_rewriter

 __all__ = [
     "TorchScriptTensor",
@@ -1071,4 +1075,9 @@ def to_model_proto(
                 common_ops.common_opset.domain, common_ops.common_opset.version
             )
         )
+
+        # Not the best integration point. Enables benchmarking the migration.
+        onnx_model = optimizer.optimize(onnx_model)
+        # This also creates contrib op in the model. So definitely not the best integration point.
+        onnx_model = ort_rewriter.rewrite(onnx_model)
         return onnx_model
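
The two passes wired into to_model_proto above can also be applied standalone to an already-exported model; a minimal sketch using the same imports and calls as the diff, with model.onnx as a placeholder path:

import onnx
import onnxscript.optimizer as optimizer
from onnxscript.rewriter import onnxruntime as ort_rewriter

onnx_model = onnx.load("model.onnx")  # placeholder input path

# Generic graph optimizations (constant folding and friends).
onnx_model = optimizer.optimize(onnx_model)
# ONNX Runtime-targeted rewrites; as the in-diff comment warns, this can
# introduce contrib ops into the model.
onnx_model = ort_rewriter.rewrite(onnx_model)

onnx.save(onnx_model, "model_optimized.onnx")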

onnxscript/optimizer/constant_folding.py

Lines changed: 6 additions & 18 deletions
@@ -41,9 +41,7 @@ def is_non_deterministic_op(node: onnx.NodeProto) -> bool:


 def is_constant_op(node: onnx.NodeProto) -> bool:
-    return node.op_type in {"Constant", "ConstantOfShape"} and is_onnx_domain(
-        node.domain
-    )
+    return node.op_type in {"Constant", "ConstantOfShape"} and is_onnx_domain(node.domain)


 class ConstantFolder(visitor.FunctionCallsiteProtoTransformer):
@@ -119,14 +117,10 @@ def new_constant(self, name, value):
         info.type = onnx.helper.make_tensor_type_proto(
             onnx.helper.np_dtype_to_tensor_dtype(value.dtype), value.shape
         )
-        node = onnx.helper.make_node(
-            "Constant", inputs=[], outputs=[name], value=tensor
-        )
+        node = onnx.helper.make_node("Constant", inputs=[], outputs=[name], value=tensor)
         return [node]

-    def convert_attributes(
-        self, attributes: Sequence[onnx.AttributeProto]
-    ) -> dict[str, Any]:
+    def convert_attributes(self, attributes: Sequence[onnx.AttributeProto]) -> dict[str, Any]:
         if self.scopes.current_scope().current_function_scope():
             # Need to resolve ref_attr_name if inside a function.
             attr_dict = {}
@@ -138,9 +132,7 @@ def convert_attributes(
                 )
                 if concrete_attribute is None:
                     continue
-                attr_dict[attribute.name] = onnx.helper.get_attribute_value(
-                    concrete_attribute
-                )
+                attr_dict[attribute.name] = onnx.helper.get_attribute_value(concrete_attribute)
             return attr_dict
         return {attr.name: onnx.helper.get_attribute_value(attr) for attr in attributes}

@@ -226,9 +218,7 @@ def process_node(self, node: onnx.NodeProto) -> Sequence[onnx.NodeProto] | None:
             self.add_count(op, outputs.size)
             return replacement
         else:
-            logger.warning(
-                "Skipping constant folding for op %s with multiple outputs.", op
-            )
+            logger.warning("Skipping constant folding for op %s with multiple outputs.", op)
             return None

     def process_function_node(
@@ -241,9 +231,7 @@ def process_function_node(
         # Replace function node with Constant if all outputs are constants
         ir_values = [self.lookup(output_name) for output_name in node.output]
         tensors = [
-            self.foldable_value(
-                output_name, ir_value.value if ir_value is not None else None
-            )
+            self.foldable_value(output_name, ir_value.value if ir_value is not None else None)
             for output_name, ir_value in zip(node.output, ir_values)
         ]
         if all(tensor is not None for tensor in tensors):
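
The constant_folding.py changes above are lint reformatting only, so behavior is unchanged. As an illustration of what the touched helpers do, here is a small self-contained sketch; the is_onnx_domain stub is an assumption that mirrors the check shown in the diff rather than the library's exact definition:

import numpy as np
import onnx
from onnx import helper, numpy_helper

def is_onnx_domain(domain: str) -> bool:
    # Assumed stub: the default ONNX domain is the empty string (or "ai.onnx").
    return domain in ("", "ai.onnx")

def is_constant_op(node: onnx.NodeProto) -> bool:
    # Same predicate as in the diff above.
    return node.op_type in {"Constant", "ConstantOfShape"} and is_onnx_domain(node.domain)

# Build a Constant node the same way ConstantFolder.new_constant does.
value = np.array([1.0, 2.0], dtype=np.float32)
tensor = numpy_helper.from_array(value, name="folded")
node = helper.make_node("Constant", inputs=[], outputs=["folded"], value=tensor)

assert is_constant_op(node)  # Constant node in the default ("") domain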
