diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 83e113154a..e9bb172173 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -2046,7 +2046,7 @@ def aten_convolution_overrideable(
     raise NotImplementedError()
 
 
-@torch_op("aten::copy")
+@torch_op(("aten::copy", "aten::_to_copy"))
 def aten_copy(
     self: TTensor, src: TTensor, non_blocking: bool = False  # pylint: disable=unused-argument
 ) -> TTensor:
@@ -5456,6 +5456,20 @@ def aten__native_batch_norm_no_training(
     )
 
 
+@torch_op("aten::_native_batch_norm_legit.no_stats", trace_only=True)
+def aten__native_batch_norm_no_stats(
+    input: TFloat,
+    weight: Optional[TFloat] = None,
+    bias: Optional[TFloat] = None,
+    training: bool = False,
+    momentum: float = 0.9,
+    eps: float = 1e-05,
+) -> Tuple[TFloat, TFloat, TFloat]:
+    """_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)"""
+
+    return aten_native_batch_norm(input, weight, bias, None, None, training, momentum, eps)
+
+
 @torch_op(("aten::native_batch_norm", "aten::_native_batch_norm_legit"), trace_only=True)
 def aten_native_batch_norm(
     input: TFloat,
@@ -5556,12 +5570,131 @@ def _aten_native_batch_norm_inference_onnx(
         momentum=momentum,
         training_mode=training,
     )
+    # NOTE: mean and var are omitted in inference mode
     # Cannot return 2 dup output, so have to do twice with different variable name
-    empty_mean = op.Cast(op.Shape(input, start=0, end=0), to=FLOAT.dtype)
-    empty_var = op.Cast(op.Shape(input, start=0, end=0), to=FLOAT.dtype)
+    empty_mean = op.CastLike(op.Shape(input, start=0, end=0), norm)
+    empty_var = op.CastLike(op.Shape(input, start=0, end=0), norm)
     return norm, empty_mean, empty_var
 
 
+# TODO: This op is using duplicated code from aten_native_batch_norm,
+# need to refactor it later. https://github.com/microsoft/onnxscript/issues/1125
+# NOTE: This op is invoked by PyTorch Functionalization and is not in
+# native_functions.yaml. It can be found in torch/_decomp/decompositions.py
+@torch_op("aten::_native_batch_norm_legit_functional", trace_only=True)
+def aten__native_batch_norm_legit_functional(
+    input: TFloat,
+    weight: Optional[TFloat] = None,
+    bias: Optional[TFloat] = None,
+    running_mean: Optional[TFloat] = None,
+    running_var: Optional[TFloat] = None,
+    training: bool = False,
+    momentum: float = 0.9,
+    eps: float = 1e-05,
+) -> Tuple[TFloat, TFloat, TFloat, TFloat, TFloat]:
+    if weight is None:  # Set to 1.0 as default
+        weight = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(input, start=1, end=2))
+
+    if bias is None:  # Set to 0.0 as default
+        bias = op.Expand(op.Constant(value_floats=[0.0]), op.Shape(input, start=1, end=2))
+
+    axes = list(range(len(input.shape)))
+    axes.pop(1)
+    axes = op.Constant(value_ints=axes)
+    if running_mean is None:  # Using input mean
+        running_mean = op.Squeeze(op.ReduceMean(input, axes))
+
+    if running_var is None:  # Using input var
+        mean = op.ReduceMean(input, axes)
+        input_sub_mean = op.Sub(input, mean)
+        sqr_input_sub_mean = op.Mul(input_sub_mean, input_sub_mean)
+        running_var = op.Squeeze(op.ReduceMean(sqr_input_sub_mean, axes))
+
+    # Have to split into 2 private functions, because the training function returns 3 outputs
+    # while the inference function returns 1 output
+    if training is True:
+        norm, mean, var, new_mean, new_var = _aten__native_batch_norm_training_functional_onnx(
+            input, weight, bias, running_mean, running_var, axes, training, momentum, eps
+        )
+    else:
+        (
+            norm,
+            mean,
+            var,
+            new_mean,
+            new_var,
+        ) = _aten__native_batch_norm_inference_functional_onnx(
+            input, weight, bias, running_mean, running_var, training, momentum, eps
+        )
+    return norm, mean, var, new_mean, new_var
+
+
+@torch_op("aten::_native_batch_norm_legit_functional", private=True)
+def _aten__native_batch_norm_training_functional_onnx(
+    input: TFloat,
+    weight: TFloat,
+    bias: TFloat,
+    running_mean: TFloat,
+    running_var: TFloat,
+    axes: INT64,
+    training: bool,
+    momentum: float,
+    eps: float,
+) -> Tuple[TFloat, TFloat, TFloat, TFloat, TFloat]:
+    # Assert(training is True)
+    norm, running_mean, running_var = op.BatchNormalization(
+        input,
+        weight,
+        bias,
+        running_mean,
+        running_var,
+        epsilon=eps,
+        momentum=momentum,
+        training_mode=training,
+    )
+    # Compute var and rstd
+    mean = op.ReduceMean(input, axes)
+    input_sub_mean = op.Sub(input, mean)
+    sqr = op.Mul(input_sub_mean, input_sub_mean)
+    var = op.ReduceMean(sqr, axes, keepdims=False)
+    rstd = op.Div(1.0, op.Sqrt(var + eps))
+    # Get mean again with size = [1, C]
+    mean = op.ReduceMean(input, axes, keepdims=False)
+    # NOTE: Fixed to be FLOAT dtype
+    running_mean = op.Cast(running_mean, to=FLOAT.dtype)
+    running_var = op.Cast(running_var, to=FLOAT.dtype)
+    return norm, mean, rstd, running_mean, running_var
+
+
+@torch_op("aten::_native_batch_norm_legit_functional", private=True)
+def _aten__native_batch_norm_inference_functional_onnx(
+    input: TFloat,
+    weight: TFloat,
+    bias: TFloat,
+    running_mean: TFloat,
+    running_var: TFloat,
+    training: bool,
+    momentum: float,
+    eps: float,
+) -> Tuple[TFloat, TFloat, TFloat, TFloat, TFloat]:
+    # Assert(training is False)
+    norm = op.BatchNormalization(
+        input,
+        weight,
+        bias,
+        running_mean,
+        running_var,
+        epsilon=eps,
+        momentum=momentum,
+        training_mode=training,
+    )
+    # NOTE: mean and var are omitted in inference mode
+    # Cannot return 2 dup output, so have to do twice with different variable name
+    empty_mean = op.CastLike(op.Shape(input, start=0, end=0), norm)
+    empty_var = op.CastLike(op.Shape(input, start=0, end=0), norm)
+    return norm, empty_mean, empty_var, running_mean, running_var
+
+
 def aten_native_batch_norm_backward(
     grad_out: TensorType,
     input: TensorType,
diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
index 9f26d75af6..253d98b9f2 100644
--- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
+++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
@@ -1289,6 +1289,52 @@ def sample_inputs_scaled_dot_product_flash_attention(
     yield from samples
 
 
+# NOTE: The `_native_batch_norm_legit` tests generate two kinds of args:
+# 1. (input, weight, bias, running_mean, running_var, training, momentum, eps)
+# 2. (input, weight, bias, training, momentum, eps)
+# which require two different function signatures to take the inputs; that is
+# why there are two sample_inputs functions here.
+def sample_inputs__native_batch_norm_legit(op_info, device, dtype, requires_grad, **kwargs):
+    samples = common_methods_invocations.sample_inputs_batch_norm(
+        op_info, device, dtype, requires_grad, **kwargs
+    )
+    for sample in samples:
+        # torch.native_batch_norm does not support 0 numel tensors
+        # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
+        if sample.input.numel() == 0:
+            continue
+        args = sample.args
+        training = sample.kwargs.get("training", True)
+        momentum = sample.kwargs.get("momentum", 0.5)
+        eps = sample.kwargs.get("eps", 1e-5)
+        if args[0] is not None and args[1] is not None:
+            yield opinfo_core.SampleInput(
+                sample.input,
+                args=(args[2], args[3], args[0], args[1], training, momentum, eps),
+            )
+
+
+def sample_inputs__native_batch_norm_legit_no_stats(
+    op_info, device, dtype, requires_grad, **kwargs
+):
+    samples = common_methods_invocations.sample_inputs_batch_norm(
+        op_info, device, dtype, requires_grad, **kwargs
+    )
+    for sample in samples:
+        # torch.native_batch_norm does not support 0 numel tensors
+        # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
+        if sample.input.numel() == 0:
+            continue
+        args = sample.args
+        training = sample.kwargs.get("training", True)
+        momentum = sample.kwargs.get("momentum", 0.5)
+        eps = sample.kwargs.get("eps", 1e-5)
+        if args[0] is not None and args[1] is None:
+            yield opinfo_core.SampleInput(
+                sample.input, args=(args[2], args[3], training, momentum, eps)
+            )
+
+
 # NOTE: How to create an OpInfo:
 # 1. Create a function that generates sample inputs for the op.
 #    This function should yield SampleInputs.
@@ -1633,4 +1679,34 @@ def sample_inputs_scaled_dot_product_flash_attention(
         supports_fwgrad_bwgrad=True,
         check_batched_forward_grad=False,
     ),
+    opinfo_core.OpInfo(
+        "ops.aten._native_batch_norm_legit",
+        aten_name="_native_batch_norm_legit",
+        dtypes=common_dtype.floating_types_and(torch.bfloat16),
+        dtypesIfCUDA=common_dtype.floating_types_and(torch.float16, torch.bfloat16),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        assert_jit_shape_analysis=True,
+        sample_inputs_func=sample_inputs__native_batch_norm_legit,
+    ),
+    opinfo_core.OpInfo(
+        "ops.aten._native_batch_norm_legit_functional",
+        aten_name="_native_batch_norm_legit_functional",
+        dtypes=common_dtype.floating_types_and(torch.bfloat16),
+        dtypesIfCUDA=common_dtype.floating_types_and(torch.float16, torch.bfloat16),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        assert_jit_shape_analysis=True,
+        sample_inputs_func=sample_inputs__native_batch_norm_legit,
+    ),
+    opinfo_core.OpInfo(
+        "ops.aten._native_batch_norm_legit.no_stats",
+        aten_name="_native_batch_norm_legit.no_stats",
+        dtypes=common_dtype.floating_types_and(torch.bfloat16),
+        dtypesIfCUDA=common_dtype.floating_types_and(torch.float16, torch.bfloat16),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        assert_jit_shape_analysis=True,
+        sample_inputs_func=sample_inputs__native_batch_norm_legit_no_stats,
+    ),
 ]
diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test.py b/onnxscript/tests/function_libs/torch_lib/ops_test.py
index 33c3602be1..47957e1b47 100644
--- a/onnxscript/tests/function_libs/torch_lib/ops_test.py
+++ b/onnxscript/tests/function_libs/torch_lib/ops_test.py
@@ -184,7 +184,6 @@ def run_test_output_match(
 
     # Obtain the tolerance for the op
     rtol, atol = torchlib_op_info.get_tolerance(dtype)
-
     for i, cpu_sample in enumerate(samples):
         inputs = (cpu_sample.input, *cpu_sample.args)
         # Provide the repr to subtest because tensors are not serializable in parallel test runs
diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
index b0e1e527d9..5510681957 100644
--- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
+++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -1680,6 +1680,20 @@ def _where_input_wrangler(
         reason="fixme: 'shape' do not match: torch.Size([2, 3, 4, 3]) != torch.Size([2, 3, 4, 2]). https://github.com/microsoft/onnxscript/issues/975",
     ),
     TorchLibOpInfo("native_batch_norm", core_ops.aten_native_batch_norm, trace_only=True),
+    TorchLibOpInfo(
+        "ops.aten._native_batch_norm_legit", core_ops.aten_native_batch_norm, trace_only=True
+    ),
+    TorchLibOpInfo(
+        "ops.aten._native_batch_norm_legit.no_stats",
+        core_ops.aten__native_batch_norm_no_stats,
+        trace_only=True,
+    ),
+    TorchLibOpInfo(
+        "ops.aten._native_batch_norm_legit_functional",
+        core_ops.aten__native_batch_norm_legit_functional,
+        trace_only=True,
+        compare_shape_only_for_output=(3, 4),
+    ),
     TorchLibOpInfo(
         "ops.aten.native_group_norm",
         core_ops.aten_native_group_norm,