diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index c4c9dc5c4b..6ac1345522 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -4168,20 +4168,58 @@ def aten_inner(self: TensorType, other: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
+@torch_op("aten::instance_norm", trace_only=True)
 def aten_instance_norm(
-    input: TensorType,
-    weight: Optional[TensorType],
-    bias: Optional[TensorType],
-    running_mean: Optional[TensorType],
-    running_var: Optional[TensorType],
-    use_input_stats: bool,
-    momentum: float,
-    eps: float,
-    cudnn_enabled: bool,
-) -> TensorType:
+    input: TFloat,
+    weight: Optional[TFloat] = None,
+    bias: Optional[TFloat] = None,
+    running_mean: Optional[TFloat] = None,
+    running_var: Optional[TFloat] = None,
+    use_input_stats: bool = True,
+    momentum: float = 0.1,
+    eps: float = 1e-05,
+    cudnn_enabled: bool = False,
+) -> TFloat:
     """instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor"""
+    del cudnn_enabled  # unused
+    if weight is None:  # Set to 1.0 as default
+        weight = op.CastLike(
+            op.Expand(op.Constant(value_floats=[1.0]), op.Shape(input, start=1, end=2)), input
+        )
 
-    raise NotImplementedError()
+    if bias is None:  # Set to 0.0 as default
+        bias = op.CastLike(
+            op.Expand(op.Constant(value_floats=[0.0]), op.Shape(input, start=1, end=2)), input
+        )
+
+    # If `use_input_stats` is set to True, ignore 'running_mean' and 'running_var' and
+    # compute using input statistics.
+    # Otherwise, compute using the running statistics.
+    if use_input_stats:
+        return op.InstanceNormalization(input, weight, bias, epsilon=eps)
+
+    assert (
+        running_mean is not None and running_var is not None
+    ), "running_mean and running_var must be provided when use_input_stats is False"
+
+    batch_size = op.Shape(input, start=0, end=1)
+    bn_input = op.Reshape(input, op.Concat([1, -1], op.Shape(input, start=2), axis=0))
+    weight = op.Tile(weight, batch_size)
+    bias = op.Tile(bias, batch_size)
+    running_mean = op.Tile(running_mean, batch_size)
+    running_var = op.Tile(running_var, batch_size)
+
+    norm = op.BatchNormalization(
+        bn_input,
+        weight,
+        bias,
+        running_mean,
+        running_var,
+        epsilon=eps,
+        momentum=1 - momentum,
+        training_mode=False,
+    )
+    return op.Reshape(norm, op.Shape(input))
 
 
 def aten_int_repr(self: TensorType) -> TensorType:
diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_common.py b/onnxscript/tests/function_libs/torch_lib/ops_test_common.py
index 172d183f28..44d3708872 100644
--- a/onnxscript/tests/function_libs/torch_lib/ops_test_common.py
+++ b/onnxscript/tests/function_libs/torch_lib/ops_test_common.py
@@ -292,7 +292,7 @@ def convert_kwargs_for_onnx(kwargs: dict[str, Any]) -> dict[str, Any]:
         if key == "dtype":
             value = TORCH_TYPE_TO_ONNX[value]
         if isinstance(value, torch.Tensor):
-            value = np.array(value)
+            value = np.array(value.cpu())
         new_kwargs[key] = value
     return new_kwargs
 
diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
index 3cfd4a1629..8373d784bc 100644
--- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
+++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -1920,6 +1920,12 @@ def _where_input_wrangler(
         matcher=lambda sample: isinstance(sample.kwargs.get("padding"), str),
         reason="String padding is not accepted by aten::conv2d",
     ),
+    TorchLibOpInfo(
+        "nn.functional.instance_norm",
+        core_ops.aten_instance_norm,
+        trace_only=True,
+        tolerance={torch.float16: (1e-2, 1e-3)},
+    ),
     TorchLibOpInfo(
         "ops.aten.conv3d",
         core_ops.aten_conv3d,
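
Note on the `use_input_stats=False` branch above: ONNX has no "instance norm with running stats" operator, so the function reshapes `(N, C, ...)` to `(1, N*C, ...)` and applies `BatchNormalization` in inference mode with the per-channel parameters and running statistics tiled `N` times; channel `i*C + j` of the reshaped tensor then picks up the stats of original channel `j`. The `momentum=1 - momentum` conversion reflects that ONNX weights the existing running estimate while PyTorch weights the new batch statistic (it has no effect with `training_mode=False`, but keeps the attribute faithful). A minimal PyTorch sketch of this equivalence, with illustrative shapes (not part of the PR):

```python
# Sketch: instance norm with running stats equals batch norm in eval mode
# on the (1, N*C, H, W) reshaped input with the stats tiled N times.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, c, h, w = 2, 3, 4, 4  # illustrative shapes
x = torch.randn(n, c, h, w)
weight, bias = torch.randn(c), torch.randn(c)
running_mean = torch.randn(c)
running_var = torch.rand(c) + 0.5  # variances must be positive

expected = F.instance_norm(
    x, running_mean, running_var, weight, bias, use_input_stats=False, eps=1e-5
)

# The decomposition used in the diff: reshape to (1, N*C, H, W), tile the
# per-channel tensors N times, then run batch norm with the running stats.
actual = F.batch_norm(
    x.reshape(1, n * c, h, w),
    running_mean.repeat(n),
    running_var.repeat(n),
    weight.repeat(n),
    bias.repeat(n),
    training=False,
    eps=1e-5,
).reshape(n, c, h, w)

torch.testing.assert_close(actual, expected)
```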
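
The `ops_test_common.py` change guards against CUDA inputs: `np.array` cannot consume a tensor on a CUDA device (PyTorch raises `TypeError: can't convert cuda:0 device type tensor to numpy`), so the tensor is copied to host memory first. A small repro, assuming a CUDA device is available:

```python
import numpy as np
import torch

if torch.cuda.is_available():
    t = torch.ones(2, device="cuda")
    # np.array(t) raises TypeError here; copy to host memory first:
    arr = np.array(t.cpu())
```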