Commit d839583

BowenBao authored and pytorchmergebot committed
[ONNX][dynamo_export] Skip instance_norm decomp for export (#120866)
Otherwise, instance_norm is decomposed into batch_norm with training set to True, and the downstream exporter has no way to tell that training is not actually needed. ONNX does define an InstanceNormalization operator, but because of the decomposition the model is unnecessarily exported as batch norm plus glue code.

Depends on microsoft/onnxscript#1284

Pull Request resolved: #120866
Approved by: https://github.com/thiagocrepaldi, https://github.com/titaiwangms
1 parent 581fe26 commit d839583
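
To make the change concrete, here is a minimal repro sketch (not part of the patch; the wrapper module and input shape are illustrative) showing how to check which normalization op the dynamo exporter emits:

```python
import torch

# Illustrative only: a module whose forward hits aten::instance_norm.
class InstanceNormModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.instance_norm(x)

onnx_program = torch.onnx.dynamo_export(InstanceNormModel(), torch.randn(1, 1, 2, 2))
op_types = {node.op_type for node in onnx_program.model_proto.graph.node}

# With the decomposition skipped, the graph carries InstanceNormalization;
# before this change it carried BatchNormalization (training mode) plus glue ops.
print("InstanceNormalization" in op_types)
```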

File tree

4 files changed: +73 -12


.ci/docker/common/install_onnx.sh

+2 -2

@@ -32,8 +32,8 @@ pip_install coloredlogs packaging

 pip_install onnxruntime==1.17.0
 pip_install onnx==1.15.0
-# pip_install "onnxscript@git+https://github.com/microsoft/onnxscript@1d6362db06706c13447e590ecf5ac3238efc1880" --no-deps
-pip_install onnxscript==0.1.0.dev20240216 --no-deps
+# pip_install "onnxscript@git+https://github.com/microsoft/onnxscript@3e869ef8ccf19b5ebd21c10d3e9c267c9a9fa729" --no-deps
+pip_install onnxscript==0.1.0.dev20240301 --no-deps

 # Cache the transformers model to be used later by ONNX tests. We need to run the transformers
 # package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/

test/onnx/test_fx_op_consistency.py

+8 -10

@@ -893,16 +893,6 @@ def skip_torchlib_forward_compatibility(
         dtypes=(torch.float16,),
         reason=onnx_test_common.reason_onnx_runtime_does_not_support("GroupNormalization", "float16"),
     ),
-    xfail(
-        "nn.functional.instance_norm",
-        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
-        reason="fixme: Assertion error: result mismatch",
-    ),
-    xfail(
-        "nn.functional.instance_norm",
-        model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
-        reason="Functionalize pass failed",
-    ),
     xfail(
         "nn.functional.local_response_norm",
         dtypes=(torch.int64,),

@@ -1548,6 +1538,13 @@ def skip_torchlib_forward_compatibility(
             "Reshape", "empty tensor"
         ),
     ),
+    xfail(
+        "nn.functional.instance_norm",
+        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
+        matcher=lambda sample: sample.kwargs.get("running_mean") is not None
+        or sample.input.dtype in (torch.float16,),
+        reason="fixme: KeyError: 'self___kwargs__running_mean'",
+    ),
     xfail(
         "nn.functional.max_pool3d",
         matcher=lambda sample: sample.kwargs.get("ceil_mode") is True

@@ -1962,6 +1959,7 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
         "nn.functional.hardsigmoid": [1e-3, 5e-3],
         "nn.functional.hardswish": [1e-3, 5e-3],
         "nn.functional.hinge_embedding_loss": [4e-1, 3e-3],
+        "nn.functional.instance_norm": [1e-2, 1e-3],
         "nn.functional.interpolate": [1e-2, 1e-3],
         "nn.functional.kl_div": [2e-3, 2e-4],
         "nn.functional.multilabel_soft_margin_loss": [4e-2, 5e-3],

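The last hunk registers a per-op tolerance pair for instance_norm. A hedged sketch of how such a pair could be consumed, assuming the ordering is [rtol, atol] (the test harness applies it internally when comparing exporter output against eager PyTorch):

```python
import torch

# Assumption: the pair is interpreted as [rtol, atol]; this mimics the
# comparison the consistency harness performs on exporter output.
rtol, atol = 1e-2, 1e-3

expected = torch.nn.functional.instance_norm(torch.randn(2, 3, 4, 4))
actual = expected + 1e-4 * torch.randn_like(expected)  # stand-in for ONNX Runtime output

torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)
```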
test/onnx/test_fx_to_onnx_decomp_skip.py

+9

@@ -39,6 +39,15 @@ def func(x: torch.Tensor):
         # If decomposition is skipped, the model will contain a Resize op instead of fine grained subgraph.
         assert_op_in_onnx_model(onnx_program.model_proto, "Resize")

+    def test_instance_norm(self):
+        def func(x: torch.Tensor):
+            return torch.nn.functional.instance_norm(x)
+
+        onnx_program = torch.onnx.dynamo_export(func, torch.randn(1, 1, 2, 2))
+        # If decomposition is skipped, the model will contain an InstanceNormalization op
+        # instead of BatchNormalization op w/ training=True.
+        assert_op_in_onnx_model(onnx_program.model_proto, "InstanceNormalization")
+

 if __name__ == "__main__":
     common_utils.run_tests()
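
The new test leans on this file's `assert_op_in_onnx_model` helper. A minimal sketch of what such a check amounts to, assuming it only needs to scan the top-level graph (the real helper may differ, e.g. by recursing into subgraphs):

```python
import onnx

def assert_op_in_onnx_model_sketch(model: onnx.ModelProto, op_type: str) -> None:
    # Scan the top-level graph for a node of the requested op type.
    if not any(node.op_type == op_type for node in model.graph.node):
        raise AssertionError(f"Expected op {op_type!r} in the ONNX model")
```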

torch/onnx/_internal/fx/decomposition_skip.py

+54

@@ -18,6 +18,7 @@
 from typing import Callable, Sequence, Type

 from onnxscript.function_libs.torch_lib.ops import (  # type: ignore[import-not-found]
+    core as torchlib_core,
     nn as torchlib_nn,
 )

@@ -119,8 +120,61 @@ def abstract(cls, input, output_size, align_corners, scale_factors):
         )


+class InstanceNormDecompSkip(DecompSkip):
+    op_callable = torch.instance_norm  # type: ignore[attr-defined]
+    onnxscript_function = torchlib_core.aten_instance_norm  # type: ignore[attr-defined]
+    new_op_name = "instance_norm"
+    new_op_schema = (
+        "(Tensor input, Tensor? weight, Tensor? bias, "
+        "Tensor? running_mean, Tensor? running_var, "
+        "bool use_input_stats, float momentum, float eps, "
+        "bool cudnn_enabled) -> Tensor"
+    )
+
+    @classmethod
+    def register(cls, export_options: torch.onnx.ExportOptions):
+        if not hasattr(torch.ops, _NEW_OP_NAMESPACE) or not hasattr(
+            torch.ops.onnx_export, cls.new_op_name
+        ):
+            cls.register_custom_op()
+
+        torch.instance_norm = torch.ops.onnx_export.instance_norm  # type: ignore[attr-defined]
+        if export_options.onnx_registry is None:
+            export_options.onnx_registry = torch.onnx.OnnxRegistry()
+        registry = export_options.onnx_registry
+        registry.register_op(
+            function=cls.onnxscript_function,
+            namespace=_NEW_OP_NAMESPACE,
+            op_name=cls.new_op_name,
+        )
+
+    @classmethod
+    def unregister(cls):
+        torch.instance_norm = cls.op_callable  # type: ignore[attr-defined]
+
+    @classmethod
+    def abstract(
+        cls,
+        input,
+        weight,
+        bias,
+        running_mean,
+        running_var,
+        use_input_stats: bool,
+        momentum: float,
+        eps: float,
+        cudnn_enabled: bool,
+    ):
+        return torch.empty(
+            input.size(),
+            dtype=input.dtype,
+            device=input.device,
+        )
+
+
 _DEFAULT_SKIP_LIST = [
     UpsampleBilinear2DDecompSkip,
+    InstanceNormDecompSkip,
 ]
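
For orientation: `register()` temporarily rebinds `torch.instance_norm` to the custom `onnx_export::instance_norm` op so the ATen decomposition never fires, and `abstract()` supplies the meta (fake tensor) output needed during tracing. A hedged sketch of how a skip list like `_DEFAULT_SKIP_LIST` could be driven around an export call (this context manager mirrors, but is not guaranteed to match, the driver in `decomposition_skip.py`):

```python
import contextlib

@contextlib.contextmanager
def enable_decomposition_skips_sketch(export_options, skips=_DEFAULT_SKIP_LIST):
    # Register every skip before export, and always restore the original
    # callables afterwards so the patching never leaks past the export.
    try:
        for skip in skips:
            skip.register(export_options)
        yield
    finally:
        for skip in skips:
            skip.unregister()
```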
