Add Ops(_native_batch_norm_legit_functional) | feat(torchlib) #1143

Merged

139 changes: 136 additions & 3 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -2046,7 +2046,7 @@ def aten_convolution_overrideable(
raise NotImplementedError()


-@torch_op("aten::copy")
+@torch_op(("aten::copy", "aten::_to_copy"))
def aten_copy(
self: TTensor, src: TTensor, non_blocking: bool = False # pylint: disable=unused-argument
) -> TTensor:
@@ -5456,6 +5456,20 @@ def aten__native_batch_norm_no_training(
)


@torch_op("aten::_native_batch_norm_legit.no_stats", trace_only=True)
def aten__native_batch_norm_no_stats(
input: TFloat,
weight: Optional[TFloat] = None,
bias: Optional[TFloat] = None,
training: bool = False,
momentum: float = 0.9,
eps: float = 1e-05,
) -> Tuple[TFloat, TFloat, TFloat]:
"""_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)"""

return aten_native_batch_norm(input, weight, bias, None, None, training, momentum, eps)
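
As an aside (illustrative only, not part of this diff): the `.no_stats` overload corresponds to batch norm layers that keep no running statistics. A minimal PyTorch sketch, assuming standard BatchNorm2d semantics:

    # Illustrative only: a BatchNorm2d created with track_running_stats=False has no
    # running_mean/running_var buffers, which is the case the `.no_stats` overload
    # above models (the stats arguments are omitted from the schema entirely).
    import torch

    bn = torch.nn.BatchNorm2d(3, track_running_stats=False)
    x = torch.randn(2, 3, 4, 4)
    out = bn(x)              # normalized with batch statistics only
    print(bn.running_mean)   # None: there are no running stats to update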


@torch_op(("aten::native_batch_norm", "aten::_native_batch_norm_legit"), trace_only=True)
def aten_native_batch_norm(
input: TFloat,
@@ -5556,12 +5570,131 @@ def _aten_native_batch_norm_inference_onnx(
momentum=momentum,
training_mode=training,
)
# NOTE: mean and var are omitted in inference mode
# Cannot return the same value twice as two outputs, so compute the empty tensor twice under different names
-empty_mean = op.Cast(op.Shape(input, start=0, end=0), to=FLOAT.dtype)
-empty_var = op.Cast(op.Shape(input, start=0, end=0), to=FLOAT.dtype)
+empty_mean = op.CastLike(op.Shape(input, start=0, end=0), norm)
+empty_var = op.CastLike(op.Shape(input, start=0, end=0), norm)
return norm, empty_mean, empty_var


# TODO: This op is using duplicated code from aten_native_batch_norm,
# need to refactor it later. https://github.com/microsoft/onnxscript/issues/1125
# NOTE: This op is invoked by PyTorch Functionalization and is not in
# native_functions.yaml. It can be found in torch/_decomp/decompositions.py
Comment on lines +5582 to +5583

Collaborator: Should we file an issue in pytorch to include it in the yaml?

@torch_op("aten::_native_batch_norm_legit_functional", trace_only=True)
def aten__native_batch_norm_legit_functional(
input: TFloat,
weight: Optional[TFloat] = None,
bias: Optional[TFloat] = None,
running_mean: Optional[TFloat] = None,
running_var: Optional[TFloat] = None,
training: bool = False,
momentum: float = 0.9,
eps: float = 1e-05,
) -> Tuple[TFloat, TFloat, TFloat, TFloat, TFloat]:
if weight is None: # Set to 1.0 as default
weight = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(input, start=1, end=2))

if bias is None: # Set to 0.0 as default
bias = op.Expand(op.Constant(value_floats=[0.0]), op.Shape(input, start=1, end=2))

axes = list(range(len(input.shape)))
axes.pop(1)
Collaborator: Possible to add a comment on why we pop index 1?

Contributor Author: Hmm, this code is duplicated from aten_native_batch_norm; the two only differ in the number of outputs. Maybe @xiaowuhu can chime in and answer this?

Contributor Author: It might be confusing that I am duplicating the code. Alternatively, we could add a higher-level traced function to merge these ops together and use a functional: bool argument to differentiate them.

Collaborator: Feel free to add a todo for now and create an issue to track it.

Collaborator (replying to the above): I think it's fine for now.
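
To make the reviewer's question concrete, a minimal sketch (illustrative only, not part of this PR) of why index 1 is dropped: batch norm computes per-channel statistics, so the reduction runs over every axis except the channel axis, which is dim 1 in PyTorch's NCHW layout.

    # Illustrative only: reduction axes for batch norm on an NCHW tensor.
    # Dropping dim 1 leaves [0, 2, 3], i.e. reduce over the batch and spatial dims
    # so that one mean/variance value remains per channel.
    import torch

    x = torch.randn(2, 3, 4, 5)                 # N, C, H, W
    axes = list(range(x.dim()))
    axes.pop(1)                                 # -> [0, 2, 3]
    per_channel_mean = x.mean(dim=axes)         # shape: [3]
    per_channel_var = x.var(dim=axes, unbiased=False)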

axes = op.Constant(value_ints=axes)
if running_mean is None: # Using input mean
running_mean = op.Squeeze(op.ReduceMean(input, axes))

if running_var is None: # Using input var
mean = op.ReduceMean(input, axes)
input_sub_mean = op.Sub(input, mean)
sqr_input_sub_mean = op.Mul(input_sub_mean, input_sub_mean)
running_var = op.Squeeze(op.ReduceMean(sqr_input_sub_mean, axes))

# Have to split into 2 private functions, because op.BatchNormalization returns 3 outputs
# in training mode but only 1 output in inference mode
if training is True:
norm, mean, var, new_mean, new_var = _aten__native_batch_norm_training_functional_onnx(
input, weight, bias, running_mean, running_var, axes, training, momentum, eps
)
else:
(
norm,
mean,
var,
new_mean,
new_var,
) = _aten__native_batch_norm_inference_functional_onnx(
input, weight, bias, running_mean, running_var, training, momentum, eps
)
return norm, mean, var, new_mean, new_var
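
A hedged illustration of the contract implemented above (not part of the diff, and assuming the functional variant is exposed via torch.ops.aten as in recent PyTorch releases): functionalization replaces the in-place update of the running statistics with returned values, which is why this variant has five outputs instead of three. Argument values are placeholders.

    # Illustrative only: the stateful op mutates running_mean/running_var in place and
    # returns 3 tensors, while the functional variant returns the updated running stats
    # as two extra outputs and leaves its inputs untouched.
    import torch

    x = torch.randn(2, 3, 4, 4)
    w, b = torch.ones(3), torch.zeros(3)
    rm, rv = torch.zeros(3), torch.ones(3)

    out, save_mean, save_rstd = torch.ops.aten._native_batch_norm_legit(
        x, w, b, rm, rv, True, 0.1, 1e-5
    )  # rm and rv are updated in place

    out, save_mean, save_rstd, new_rm, new_rv = (
        torch.ops.aten._native_batch_norm_legit_functional(
            x, w, b, rm, rv, True, 0.1, 1e-5
        )
    )  # updated stats come back as outputs instead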


@torch_op("aten::_native_batch_norm_legit_functional", private=True)
def _aten__native_batch_norm_training_functional_onnx(
input: TFloat,
weight: TFloat,
bias: TFloat,
running_mean: TFloat,
running_var: TFloat,
axes: INT64,
training: bool,
momentum: float,
eps: float,
) -> Tuple[TFloat, TFloat, TFloat, TFloat, TFloat]:
# Assert(training is True)
norm, running_mean, running_var = op.BatchNormalization(
input,
weight,
bias,
running_mean,
running_var,
epsilon=eps,
momentum=momentum,
training_mode=training,
)
# Compute var and rstd
mean = op.ReduceMean(input, axes)
input_sub_mean = op.Sub(input, mean)
sqr = op.Mul(input_sub_mean, input_sub_mean)
var = op.ReduceMean(sqr, axes, keepdims=False)
rstd = op.Div(1.0, op.Sqrt(var + eps))
# Get mean again with size = [1, C]
mean = op.ReduceMean(input, axes, keepdims=False)
# NOTE: Fixed to be FLOAT dtype
running_mean = op.Cast(running_mean, to=FLOAT.dtype)
running_var = op.Cast(running_var, to=FLOAT.dtype)
return norm, mean, rstd, running_mean, running_var


@torch_op("aten::_native_batch_norm_legit_functional", private=True)
def _aten__native_batch_norm_inference_functional_onnx(
input: TFloat,
weight: TFloat,
bias: TFloat,
running_mean: TFloat,
running_var: TFloat,
training: bool,
momentum: float,
eps: float,
) -> Tuple[TFloat, TFloat, TFloat, TFloat, TFloat]:
# Assert(training is False)
norm = op.BatchNormalization(
input,
weight,
bias,
running_mean,
running_var,
epsilon=eps,
momentum=momentum,
training_mode=training,
)
# NOTE: mean and var are omitted in inference mode
# Cannot return the same value twice as two outputs, so compute the empty tensor twice under different names
empty_mean = op.CastLike(op.Shape(input, start=0, end=0), norm)
empty_var = op.CastLike(op.Shape(input, start=0, end=0), norm)
return norm, empty_mean, empty_var, running_mean, running_var
Collaborator: Add a note saying we omitted computing mean and var so readers know to implement them when needed?

Contributor Author: Done
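
For readers following the exchange above, a small sketch (illustrative only, not part of the PR) of the empty-tensor trick used in the inference path: Shape(input, start=0, end=0) yields a zero-element slice of the input's shape, and CastLike converts it to the dtype of norm, so the saved mean/var outputs come back as empty tensors. The numpy analogue below is an assumption for illustration.

    # Illustrative only: numpy analogue of op.CastLike(op.Shape(x, start=0, end=0), norm).
    import numpy as np

    x = np.random.randn(2, 3, 4, 5).astype(np.float32)
    empty_shape_slice = np.array(x.shape[0:0], dtype=np.int64)  # shape (0,): no elements
    empty_mean = empty_shape_slice.astype(x.dtype)              # cast to the output dtype
    print(empty_mean.shape)                                     # (0,)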



def aten_native_batch_norm_backward(
grad_out: TensorType,
input: TensorType,
76 changes: 76 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
@@ -1289,6 +1289,52 @@ def sample_inputs_scaled_dot_product_flash_attention(
yield from samples


# NOTE: The `_native_batch_norm_legit` tests generate two kinds of args:
# 1. (input, weight, bias, running_mean, running_var, training, momentum, eps)
# 2. (input, weight, bias, training, momentum, eps)
# These require two different function signatures to take the inputs, which is why
# there are two sample_inputs functions here (see the call sketch after these helpers).
def sample_inputs__native_batch_norm_legit(op_info, device, dtype, requires_grad, **kwargs):
samples = common_methods_invocations.sample_inputs_batch_norm(
op_info, device, dtype, requires_grad, **kwargs
)
for sample in samples:
# torch.native_batch_norm does not support 0 numel tensors
# IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
if sample.input.numel() == 0:
continue
args = sample.args
training = sample.kwargs.get("training", True)
momentum = sample.kwargs.get("momentum", 0.5)
eps = sample.kwargs.get("eps", 1e-5)
if args[0] is not None and args[1] is not None:
yield opinfo_core.SampleInput(
sample.input,
args=(args[2], args[3], args[0], args[1], training, momentum, eps),
)


def sample_inputs__native_batch_norm_legit_no_stats(
op_info, device, dtype, requires_grad, **kwargs
):
samples = common_methods_invocations.sample_inputs_batch_norm(
op_info, device, dtype, requires_grad, **kwargs
)
for sample in samples:
# torch.native_batch_norm does not support 0 numel tensors
# IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
if sample.input.numel() == 0:
continue
args = sample.args
training = sample.kwargs.get("training", True)
momentum = sample.kwargs.get("momentum", 0.5)
eps = sample.kwargs.get("eps", 1e-5)
if args[0] is not None and args[1] is None:
yield opinfo_core.SampleInput(
sample.input, args=(args[2], args[3], training, momentum, eps)
)
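
To make the note above concrete, a short sketch (illustrative only, not part of the PR) of the two call shapes the sample generators target; argument values are placeholders.

    # Illustrative only: the two overloads exercised by the sample generators above.
    import torch

    x = torch.randn(2, 3, 4, 4)
    w, b = torch.ones(3), torch.zeros(3)
    rm, rv = torch.zeros(3), torch.ones(3)

    # 1. Overload with running stats: 8 positional args.
    torch.ops.aten._native_batch_norm_legit(x, w, b, rm, rv, True, 0.5, 1e-5)

    # 2. `.no_stats` overload: running stats are omitted, 6 positional args.
    torch.ops.aten._native_batch_norm_legit.no_stats(x, w, b, True, 0.5, 1e-5)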


# NOTE: How to create an OpInfo:
# 1. Create a function that generates sample inputs for the op.
# This function should yield SampleInputs.
@@ -1633,4 +1679,34 @@ def sample_inputs_scaled_dot_product_flash_attention(
supports_fwgrad_bwgrad=True,
check_batched_forward_grad=False,
),
opinfo_core.OpInfo(
"ops.aten._native_batch_norm_legit",
aten_name="_native_batch_norm_legit",
dtypes=common_dtype.floating_types_and(torch.bfloat16),
dtypesIfCUDA=common_dtype.floating_types_and(torch.float16, torch.bfloat16),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
assert_jit_shape_analysis=True,
sample_inputs_func=sample_inputs__native_batch_norm_legit,
),
opinfo_core.OpInfo(
"ops.aten._native_batch_norm_legit_functional",
aten_name="_native_batch_norm_legit_functional",
dtypes=common_dtype.floating_types_and(torch.bfloat16),
dtypesIfCUDA=common_dtype.floating_types_and(torch.float16, torch.bfloat16),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
assert_jit_shape_analysis=True,
sample_inputs_func=sample_inputs__native_batch_norm_legit,
),
opinfo_core.OpInfo(
"ops.aten._native_batch_norm_legit.no_stats",
aten_name="_native_batch_norm_legit.no_stats",
dtypes=common_dtype.floating_types_and(torch.bfloat16),
dtypesIfCUDA=common_dtype.floating_types_and(torch.float16, torch.bfloat16),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
assert_jit_shape_analysis=True,
sample_inputs_func=sample_inputs__native_batch_norm_legit_no_stats,
),
]
1 change: 0 additions & 1 deletion onnxscript/tests/function_libs/torch_lib/ops_test.py
@@ -184,7 +184,6 @@ def run_test_output_match(

# Obtain the tolerance for the op
rtol, atol = torchlib_op_info.get_tolerance(dtype)

for i, cpu_sample in enumerate(samples):
inputs = (cpu_sample.input, *cpu_sample.args)
# Provide the repr to subtest because tensors are not serializable in parallel test runs
14 changes: 14 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -1680,6 +1680,20 @@ def _where_input_wrangler(
reason="fixme: 'shape' do not match: torch.Size([2, 3, 4, 3]) != torch.Size([2, 3, 4, 2]). https://github.com/microsoft/onnxscript/issues/975",
),
TorchLibOpInfo("native_batch_norm", core_ops.aten_native_batch_norm, trace_only=True),
TorchLibOpInfo(
"ops.aten._native_batch_norm_legit", core_ops.aten_native_batch_norm, trace_only=True
),
TorchLibOpInfo(
"ops.aten._native_batch_norm_legit.no_stats",
core_ops.aten__native_batch_norm_no_stats,
trace_only=True,
),
TorchLibOpInfo(
"ops.aten._native_batch_norm_legit_functional",
core_ops.aten__native_batch_norm_legit_functional,
trace_only=True,
compare_shape_only_for_output=(3, 4),
),
TorchLibOpInfo(
"ops.aten.native_group_norm",
core_ops.aten_native_group_norm,