Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
dcae05e
add ops
xiaowuhu Feb 22, 2023
92a0615
Update ops_correctness_test.py
xiaowuhu Feb 22, 2023
bcd41c9
Merge branch 'main' into xiaowu/addOps(layer_norm)
xiaowuhu Feb 23, 2023
9ccb944
Merge branch 'main' into xiaowu/addOps(layer_norm)
xiaowuhu Feb 24, 2023
7a741de
Merge branch 'main' into xiaowu/addOps(layer_norm)
xiaowuhu Feb 24, 2023
b57275a
Merge branch 'main' into xiaowu/addOps(layer_norm)
xiaowuhu Feb 27, 2023
537ad3c
Merge remote-tracking branch 'upstream/main' into xiaowu/addOps(layer…
xiaowuhu Feb 27, 2023
8bc048e
Merge branch 'xiaowu/addOps(layer_norm)' of https://github.com/xiaowu…
xiaowuhu Feb 27, 2023
4e77095
fix comment
xiaowuhu Feb 27, 2023
9bc68d5
fix comments
xiaowuhu Feb 27, 2023
bebab5c
Update extra_opinfo.py
xiaowuhu Feb 27, 2023
1c0e629
Update extra_opinfo.py
xiaowuhu Feb 27, 2023
90f41f5
Update extra_opinfo.py
xiaowuhu Feb 27, 2023
f879eb3
Merge branch 'main' into xiaowu/addOps(layer_norm)
xiaowuhu Feb 28, 2023
b692a25
Merge branch 'main' into xiaowu/addOps(layer_norm)
xiaowuhu Feb 28, 2023
3d8ab9a
fix comment
xiaowuhu Mar 1, 2023
fbbdeee
Merge branch 'main' into xiaowu/addOps(layer_norm)
xiaowuhu Mar 1, 2023
304324f
Merge remote-tracking branch 'upstream/main' into xiaowu/addOps(layer…
xiaowuhu Mar 1, 2023
0570b92
Merge branch 'xiaowu/addOps(layer_norm)' of https://github.com/xiaowu…
xiaowuhu Mar 1, 2023
58173d6
Merge remote-tracking branch 'upstream/main' into xiaowu/addOps(layer…
xiaowuhu Mar 1, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 54 additions & 22 deletions onnxscript/function_libs/torch_aten/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2893,17 +2893,48 @@ def aten_kthvalue(
raise NotImplementedError()


@torch_op("aten::layer_norm", trace_only=True)
def aten_layer_norm(
input: TensorType,
input: TReal,
normalized_shape: Sequence[int],
weight: Optional[TensorType] = None,
bias: Optional[TensorType] = None,
weight: Optional[TReal] = None,
bias: Optional[TReal] = None,
eps: float = 1e-05,
cudnn_enable: bool = True,
) -> TensorType:
) -> TReal:
"""layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor"""

raise NotImplementedError()
axes_list = [-i for i in range(len(normalized_shape), 0, -1)]
axes = op.Constant(value_ints=axes_list)
if not op.OptionalHasElement(weight):
weight = op.CastLike(1, input)
if not op.OptionalHasElement(bias):
bias = op.CastLike(0, input)

result = _aten_layer_norm_onnx(input, weight, bias, axes, eps)
return result


@torch_op("aten::layer_norm", overload=True)
def _aten_layer_norm_onnx(
    input: TReal,
    weight: TReal,
    bias: TReal,
    axes: Sequence[int],
    eps: float,
) -> TReal:
    """Layer normalization over the trailing `axes` of `input`.

    Computes (input - mean) / sqrt(variance + eps) * weight + bias, where
    mean and variance are reduced over `axes`.
    """
    # Mean and biased variance over the normalized axes.
    mean = op.ReduceMean(input, axes)
    centered = op.Sub(input, mean)
    sq = op.Pow(centered, 2.0)
    var = op.ReduceMean(sq, axes)
    # Normalize by the epsilon-stabilized standard deviation.
    var_eps = op.Add(var, eps)
    std = op.Sqrt(var_eps)
    normalized = op.Div(centered, std)
    # Affine transform; CastLike keeps weight/bias in the working dtype.
    weight_cast = op.CastLike(weight, normalized)
    scaled = op.Mul(normalized, weight_cast)
    bias_cast = op.CastLike(bias, scaled)
    result = op.Add(scaled, bias_cast)
    return result


def aten_lcm(self: TensorType, other: TensorType) -> TensorType:
Expand Down Expand Up @@ -3918,12 +3949,13 @@ def aten_native_layer_norm(
# where D is the dimension of normalized_shape. For example, if normalized_shape is
# (3, 5) (a 2-dimensional shape), the mean and standard-deviation are computed
# over the last 2 dimensions of the input (i.e. input.mean((-2, -1))).
axes = [-i for i in range(len(normalized_shape), 0, -1)]
if weight is None:
axes_list = [-i for i in range(len(normalized_shape), 0, -1)]
axes = op.Constant(value_ints=axes_list)
if not op.OptionalHasElement(weight):
weight = op.CastLike(1, input)
if bias is None:
if not op.OptionalHasElement(bias):
bias = op.CastLike(0, input)
return _aten_native_layer_norm_onnx(input, weight, bias, axes=axes, eps=eps)
return _aten_native_layer_norm_onnx(input, weight, bias, axes, eps)


@torch_op("aten::native_layer_norm", overload=True)
Expand All @@ -3936,18 +3968,18 @@ def _aten_native_layer_norm_onnx(
) -> Tuple[TReal, TReal, TReal]:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

axes should be input now.

axes: Sequence[INT64],

And other places using ReduceMax/ReduceMean/ReduceMin should all be updated, otherwise, the model would pop errors saying ReduceXXX having unexpected input/attribute axes depending on opset_verison 17/18.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not necessarily. The interface of an aten op/function shouldn't change because of its implementation. Since an attribute can be promoted to be an input, it should be okay to leave the interface as is. But it is important to ensure that the implementation works correctly. Eg., use the noop_with_empty_axes attribute as appropriate to ensure correct behavior for when axes is empty, etc.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you referring to the annotation (Sequence[INT64]) or usage (op.ReduceMean(input, axes=axes))? I think both of them needed to be modified to be compatible with opset18?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am saying that the annotation should ideally depend on the aten specification (and not on the onnx opset). The call to onnx ops like ReduceMean, of course, will depend on the onnx opset.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make sense, but for onnx function (aten interface), I think we still rely on function_ir to differentiate attrs/inputs. I guess I am not sure if function_ir changed due to the opset version bump 17 -> 18 in this case?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the idea might be related to #443

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am saying that the annotation should ideally depend on the aten specification (and not on the onnx opset). The call to onnx ops like ReduceMean, of course, will depend on the onnx opset.

I agree. We could change it to INT64 (instead of Sequence[INT64]) prepare for symint inputs. The evaluator should be able to handle this (maybe already)


# FIXME(justinchuby): Use opset18 when it is supported by onnxruntime
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Remove

mean = opset17.ReduceMean(input, axes=axes)
numerator = opset17.Sub(input, mean)
power_num = opset17.Pow(numerator, 2.0)
variance = opset17.ReduceMean(power_num, axes=axes)
variance_eps = opset17.Add(variance, eps)
denominator = opset17.Sqrt(variance_eps)
result = opset17.Div(numerator, denominator)
weight = opset17.CastLike(weight, result)
result = opset17.Mul(result, weight)
bias = opset17.CastLike(bias, result)
result = opset17.Add(result, bias)
rdenominator = opset17.Reciprocal(denominator)
mean = op.ReduceMean(input, axes)
numerator = op.Sub(input, mean)
power_num = op.Pow(numerator, 2.0)
variance = op.ReduceMean(power_num, axes)
variance_eps = op.Add(variance, eps)
denominator = op.Sqrt(variance_eps)
result = op.Div(numerator, denominator)
weight = op.CastLike(weight, result)
result = op.Mul(result, weight)
bias = op.CastLike(bias, result)
result = op.Add(result, bias)
rdenominator = op.Reciprocal(denominator)
return result, mean, rdenominator


Expand Down
50 changes: 50 additions & 0 deletions onnxscript/tests/function_libs/torch_aten/extra_opinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,44 @@ def sample_inputs_convolution(op_info, device, dtype, requires_grad, **kwargs):
)


def sample_inputs_layer_norm(
    op_info, device, dtype, requires_grad, **kwargs  # pylint: disable=unused-argument
):
    """Yield SampleInputs for aten::layer_norm.

    Each case is emitted four times to cover every combination of the
    optional weight and bias arguments (present/absent).
    """
    make_arg = functools.partial(
        torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )

    # Ordered as: input shape, normalized_shape, eps.
    # Annotation fixed to describe a tuple of triples (the original
    # tuple[tuple[int], tuple[int], float] described a single triple and
    # forced `type: ignore` on every use).
    cases: tuple[tuple[tuple[int, ...], tuple[int, ...], float], ...] = (
        ((1, 2, 3), (1, 2, 3), 0.5),
        ((2, 2, 3), (2, 3), -0.5),
        ((1,), (1,), 1e-5),
        ((1, 2), (2,), 1e-5),
        ((0, 1), (1,), 1e-5),
    )

    for input_shape, normalized_shape, eps in cases:
        # Shape of weight and bias should be the same as normalized_shape.
        weight = make_arg(normalized_shape)
        bias = make_arg(normalized_shape)
        for w, b in ((weight, bias), (None, bias), (weight, None), (None, None)):
            yield opinfo_core.SampleInput(
                make_arg(input_shape),
                args=(normalized_shape, w, b, eps),
            )


OP_DB: List[opinfo_core.OpInfo] = [
opinfo_core.OpInfo(
"convolution",
Expand All @@ -184,4 +222,16 @@ def sample_inputs_convolution(op_info, device, dtype, requires_grad, **kwargs):
skips=(),
supports_out=False,
),
opinfo_core.OpInfo(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's move this on top of "nn.functional.conv3d" for alphabetical order.

"layer_norm",
aliases=("layer_norm",),
aten_name="layer_norm",
dtypes=common_dtype.floating_and_complex_types_and(torch.int64, torch.bfloat16),
sample_inputs_func=sample_inputs_layer_norm,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
gradcheck_nondet_tol=common_utils.GRADCHECK_NONDET_TOL,
skips=(),
supports_out=False,
),
]
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,7 @@ def _where_input_wrangler(
"convolution": core_ops.aten_convolution,
"empty_like": core_ops.aten_empty_like,
"index_select": core_ops.aten_index_select,
"layer_norm": core_ops.aten_layer_norm,
"native_layer_norm": core_ops.aten_native_layer_norm,
"new_empty": core_ops.aten_new_empty,
"new_empty_strided": core_ops.aten_new_empty_strided,
Expand Down