
Commit 88ee668

Add Ops(_native_batch_norm_legit_functional) | feat(torchlib) (#1143)
Fix #1140

Add:
(1) `aten::_native_batch_norm_legit.no_stats`
(2) `aten::_to_copy`
(3) `aten::_native_batch_norm_legit_functional`

`aten::_native_batch_norm_legit_functional` is only invoked by the Functionalization pass, so it cannot be tested in op_test; it will be added to op_test on the converter side. The only difference between this op and `aten::_native_batch_norm_legit` is the number of outputs: `aten::_native_batch_norm_legit_functional` additionally returns running_mean and running_var, per https://github.com/pytorch/pytorch/blob/1488bafb274fcc82c8aac429bad61738bc3f950e/torch/_decomp/decompositions.py#L1804-L1826

`aten_native_batch_norm_legit` is split into two sample-input functions that feed the different ONNX variants separately, since they require different sets of arguments.

---------

Co-authored-by: Justin Chu <[email protected]>
1 parent fdef96c commit 88ee668
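
For reference, a minimal sketch of the calling conventions involved (not part of this commit; it assumes the standard `torch.ops.aten` overloads are callable eagerly):

import torch

x = torch.randn(2, 3, 4, 4)
weight, bias = torch.ones(3), torch.zeros(3)
running_mean, running_var = torch.zeros(3), torch.ones(3)

# _native_batch_norm_legit: (input, weight, bias, running_mean, running_var,
# training, momentum, eps) -> (output, save_mean, save_rstd)
out, save_mean, save_rstd = torch.ops.aten._native_batch_norm_legit(
    x, weight, bias, running_mean, running_var, True, 0.1, 1e-5
)

# _native_batch_norm_legit.no_stats: no running stats in the signature, same 3 outputs
out2, mean2, rstd2 = torch.ops.aten._native_batch_norm_legit.no_stats(
    x, weight, bias, True, 0.1, 1e-5
)

# _native_batch_norm_legit_functional (the variant emitted by the Functionalization
# pass): same inputs as _native_batch_norm_legit, but 5 outputs, the last two being
# the updated running_mean and running_var
out3, mean3, rstd3, new_mean, new_var = torch.ops.aten._native_batch_norm_legit_functional(
    x, weight, bias, running_mean, running_var, True, 0.1, 1e-5
)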

4 files changed: +226 −4 lines

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 136 additions & 3 deletions
@@ -2046,7 +2046,7 @@ def aten_convolution_overrideable(
     raise NotImplementedError()
 
 
-@torch_op("aten::copy")
+@torch_op(("aten::copy", "aten::_to_copy"))
 def aten_copy(
     self: TTensor, src: TTensor, non_blocking: bool = False  # pylint: disable=unused-argument
 ) -> TTensor:
@@ -5456,6 +5456,20 @@ def aten__native_batch_norm_no_training(
     )
 
 
+@torch_op("aten::_native_batch_norm_legit.no_stats", trace_only=True)
+def aten__native_batch_norm_no_stats(
+    input: TFloat,
+    weight: Optional[TFloat] = None,
+    bias: Optional[TFloat] = None,
+    training: bool = False,
+    momentum: float = 0.9,
+    eps: float = 1e-05,
+) -> Tuple[TFloat, TFloat, TFloat]:
+    """_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)"""
+
+    return aten_native_batch_norm(input, weight, bias, None, None, training, momentum, eps)
+
+
 @torch_op(("aten::native_batch_norm", "aten::_native_batch_norm_legit"), trace_only=True)
 def aten_native_batch_norm(
     input: TFloat,
@@ -5556,12 +5570,131 @@ def _aten_native_batch_norm_inference_onnx(
         momentum=momentum,
         training_mode=training,
     )
+    # NOTE: mean and var are omitted in inference mode
     # Cannot return 2 dup output, so have to do twice with different variable name
-    empty_mean = op.Cast(op.Shape(input, start=0, end=0), to=FLOAT.dtype)
-    empty_var = op.Cast(op.Shape(input, start=0, end=0), to=FLOAT.dtype)
+    empty_mean = op.CastLike(op.Shape(input, start=0, end=0), norm)
+    empty_var = op.CastLike(op.Shape(input, start=0, end=0), norm)
     return norm, empty_mean, empty_var
 
 
+# TODO: This op uses duplicated code from aten_native_batch_norm and
+# needs to be refactored later. https://github.com/microsoft/onnxscript/issues/1125
+# NOTE: This op is invoked by PyTorch Functionalization and is not in
+# native_functions.yaml; it can be found in torch/_decomp/decompositions.py
+@torch_op("aten::_native_batch_norm_legit_functional", trace_only=True)
+def aten__native_batch_norm_legit_functional(
+    input: TFloat,
+    weight: Optional[TFloat] = None,
+    bias: Optional[TFloat] = None,
+    running_mean: Optional[TFloat] = None,
+    running_var: Optional[TFloat] = None,
+    training: bool = False,
+    momentum: float = 0.9,
+    eps: float = 1e-05,
+) -> Tuple[TFloat, TFloat, TFloat, TFloat, TFloat]:
+    if weight is None:  # Set to 1.0 as default
+        weight = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(input, start=1, end=2))
+
+    if bias is None:  # Set to 0.0 as default
+        bias = op.Expand(op.Constant(value_floats=[0.0]), op.Shape(input, start=1, end=2))
+
+    axes = list(range(len(input.shape)))
+    axes.pop(1)
+    axes = op.Constant(value_ints=axes)
+    if running_mean is None:  # Using input mean
+        running_mean = op.Squeeze(op.ReduceMean(input, axes))
+
+    if running_var is None:  # Using input var
+        mean = op.ReduceMean(input, axes)
+        input_sub_mean = op.Sub(input, mean)
+        sqr_input_sub_mean = op.Mul(input_sub_mean, input_sub_mean)
+        running_var = op.Squeeze(op.ReduceMean(sqr_input_sub_mean, axes))
+
+    # Have to split into 2 private functions, because the training function returns 3 outputs
+    # while the inference function returns 1 output
+    if training is True:
+        norm, mean, var, new_mean, new_var = _aten__native_batch_norm_training_functional_onnx(
+            input, weight, bias, running_mean, running_var, axes, training, momentum, eps
+        )
+    else:
+        (
+            norm,
+            mean,
+            var,
+            new_mean,
+            new_var,
+        ) = _aten__native_batch_norm_inference_functional_onnx(
+            input, weight, bias, running_mean, running_var, training, momentum, eps
+        )
+    return norm, mean, var, new_mean, new_var
+
+
+@torch_op("aten::_native_batch_norm_legit_functional", private=True)
+def _aten__native_batch_norm_training_functional_onnx(
+    input: TFloat,
+    weight: TFloat,
+    bias: TFloat,
+    running_mean: TFloat,
+    running_var: TFloat,
+    axes: INT64,
+    training: bool,
+    momentum: float,
+    eps: float,
+) -> Tuple[TFloat, TFloat, TFloat, TFloat, TFloat]:
+    # Assert(training is True)
+    norm, running_mean, running_var = op.BatchNormalization(
+        input,
+        weight,
+        bias,
+        running_mean,
+        running_var,
+        epsilon=eps,
+        momentum=momentum,
+        training_mode=training,
+    )
+    # Compute var and rstd
+    mean = op.ReduceMean(input, axes)
+    input_sub_mean = op.Sub(input, mean)
+    sqr = op.Mul(input_sub_mean, input_sub_mean)
+    var = op.ReduceMean(sqr, axes, keepdims=False)
+    rstd = op.Div(1.0, op.Sqrt(var + eps))
+    # Get mean again with size = [1, C]
+    mean = op.ReduceMean(input, axes, keepdims=False)
+    # NOTE: Fixed to be FLOAT dtype
+    running_mean = op.Cast(running_mean, to=FLOAT.dtype)
+    running_var = op.Cast(running_var, to=FLOAT.dtype)
+    return norm, mean, rstd, running_mean, running_var
+
+
+@torch_op("aten::_native_batch_norm_legit_functional", private=True)
+def _aten__native_batch_norm_inference_functional_onnx(
+    input: TFloat,
+    weight: TFloat,
+    bias: TFloat,
+    running_mean: TFloat,
+    running_var: TFloat,
+    training: bool,
+    momentum: float,
+    eps: float,
+) -> Tuple[TFloat, TFloat, TFloat, TFloat, TFloat]:
+    # Assert(training is False)
+    norm = op.BatchNormalization(
+        input,
+        weight,
+        bias,
+        running_mean,
+        running_var,
+        epsilon=eps,
+        momentum=momentum,
+        training_mode=training,
+    )
+    # NOTE: mean and var are omitted in inference mode
+    # Cannot return 2 dup output, so have to do twice with different variable name
+    empty_mean = op.CastLike(op.Shape(input, start=0, end=0), norm)
+    empty_var = op.CastLike(op.Shape(input, start=0, end=0), norm)
+    return norm, empty_mean, empty_var, running_mean, running_var
+
+
 def aten_native_batch_norm_backward(
     grad_out: TensorType,
     input: TensorType,
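
As an aside (a hedged illustration, not part of the diff): when running stats are missing, the new functional op computes per-channel statistics by reducing over every axis except the channel axis (dim 1), which is what `axes = list(range(len(input.shape))); axes.pop(1)` sets up. The equivalent reduction in plain PyTorch:

import torch

x = torch.randn(2, 3, 4, 4)                    # NCHW input
axes = [d for d in range(x.dim()) if d != 1]   # [0, 2, 3]: every dim except channels
mean = x.mean(dim=axes)                        # shape [3], one value per channel
var = ((x - x.mean(dim=axes, keepdim=True)) ** 2).mean(dim=axes)  # biased per-channel variance
print(mean.shape, var.shape)                   # torch.Size([3]) torch.Size([3])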

onnxscript/tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 76 additions & 0 deletions
@@ -1289,6 +1289,52 @@ def sample_inputs_scaled_dot_product_flash_attention(
     yield from samples
 
 
+# NOTE: The `_native_batch_norm_legit` tests generate two kinds of args:
+# 1. (input, weight, bias, running_mean, running_var, training, momentum, eps)
+# 2. (input, weight, bias, training, momentum, eps)
+# which require two function signatures to take the inputs; that's why there are
+# two sample_inputs functions here instead of one.
+def sample_inputs__native_batch_norm_legit(op_info, device, dtype, requires_grad, **kwargs):
+    samples = common_methods_invocations.sample_inputs_batch_norm(
+        op_info, device, dtype, requires_grad, **kwargs
+    )
+    for sample in samples:
+        # torch.native_batch_norm does not support 0 numel tensors
+        # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
+        if sample.input.numel() == 0:
+            continue
+        args = sample.args
+        training = sample.kwargs.get("training", True)
+        momentum = sample.kwargs.get("momentum", 0.5)
+        eps = sample.kwargs.get("eps", 1e-5)
+        if args[0] is not None and args[1] is not None:
+            yield opinfo_core.SampleInput(
+                sample.input,
+                args=(args[2], args[3], args[0], args[1], training, momentum, eps),
+            )
+
+
+def sample_inputs__native_batch_norm_legit_no_stats(
+    op_info, device, dtype, requires_grad, **kwargs
+):
+    samples = common_methods_invocations.sample_inputs_batch_norm(
+        op_info, device, dtype, requires_grad, **kwargs
+    )
+    for sample in samples:
+        # torch.native_batch_norm does not support 0 numel tensors
+        # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
+        if sample.input.numel() == 0:
+            continue
+        args = sample.args
+        training = sample.kwargs.get("training", True)
+        momentum = sample.kwargs.get("momentum", 0.5)
+        eps = sample.kwargs.get("eps", 1e-5)
+        if args[0] is not None and args[1] is None:
+            yield opinfo_core.SampleInput(
+                sample.input, args=(args[2], args[3], training, momentum, eps)
+            )
+
+
 # NOTE: How to create an OpInfo:
 # 1. Create a function that generates sample inputs for the op.
 #    This function should yield SampleInputs.
@@ -1633,4 +1679,34 @@ def sample_inputs_scaled_dot_product_flash_attention(
         supports_fwgrad_bwgrad=True,
         check_batched_forward_grad=False,
     ),
+    opinfo_core.OpInfo(
+        "ops.aten._native_batch_norm_legit",
+        aten_name="_native_batch_norm_legit",
+        dtypes=common_dtype.floating_types_and(torch.bfloat16),
+        dtypesIfCUDA=common_dtype.floating_types_and(torch.float16, torch.bfloat16),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        assert_jit_shape_analysis=True,
+        sample_inputs_func=sample_inputs__native_batch_norm_legit,
+    ),
+    opinfo_core.OpInfo(
+        "ops.aten._native_batch_norm_legit_functional",
+        aten_name="_native_batch_norm_legit_functional",
+        dtypes=common_dtype.floating_types_and(torch.bfloat16),
+        dtypesIfCUDA=common_dtype.floating_types_and(torch.float16, torch.bfloat16),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        assert_jit_shape_analysis=True,
+        sample_inputs_func=sample_inputs__native_batch_norm_legit,
+    ),
+    opinfo_core.OpInfo(
+        "ops.aten._native_batch_norm_legit.no_stats",
+        aten_name="_native_batch_norm_legit.no_stats",
+        dtypes=common_dtype.floating_types_and(torch.bfloat16),
+        dtypesIfCUDA=common_dtype.floating_types_and(torch.float16, torch.bfloat16),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        assert_jit_shape_analysis=True,
+        sample_inputs_func=sample_inputs__native_batch_norm_legit_no_stats,
+    ),
 ]

onnxscript/tests/function_libs/torch_lib/ops_test.py

Lines changed: 0 additions & 1 deletion
@@ -184,7 +184,6 @@ def run_test_output_match(
 
     # Obtain the tolerance for the op
     rtol, atol = torchlib_op_info.get_tolerance(dtype)
-
     for i, cpu_sample in enumerate(samples):
        inputs = (cpu_sample.input, *cpu_sample.args)
        # Provide the repr to subtest because tensors are not serializable in parallel test runs

onnxscript/tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 14 additions & 0 deletions
@@ -1680,6 +1680,20 @@ def _where_input_wrangler(
         reason="fixme: 'shape' do not match: torch.Size([2, 3, 4, 3]) != torch.Size([2, 3, 4, 2]). https://github.com/microsoft/onnxscript/issues/975",
     ),
     TorchLibOpInfo("native_batch_norm", core_ops.aten_native_batch_norm, trace_only=True),
+    TorchLibOpInfo(
+        "ops.aten._native_batch_norm_legit", core_ops.aten_native_batch_norm, trace_only=True
+    ),
+    TorchLibOpInfo(
+        "ops.aten._native_batch_norm_legit.no_stats",
+        core_ops.aten__native_batch_norm_no_stats,
+        trace_only=True,
+    ),
+    TorchLibOpInfo(
+        "ops.aten._native_batch_norm_legit_functional",
+        core_ops.aten__native_batch_norm_legit_functional,
+        trace_only=True,
+        compare_shape_only_for_output=(3, 4),
+    ),
     TorchLibOpInfo(
         "ops.aten.native_group_norm",
         core_ops.aten_native_group_norm,
