Skip to content

Commit 690ed5d

Browse files
authored
feat(atenlib): add ops(layer_norm) (#459)
1 parent c5ca05b commit 690ed5d

File tree

3 files changed

+88
-22
lines changed

3 files changed

+88
-22
lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2893,17 +2893,27 @@ def aten_kthvalue(
28932893
raise NotImplementedError()
28942894

28952895

2896+
@torch_op("aten::layer_norm", trace_only=True)
28962897
def aten_layer_norm(
2897-
input: TensorType,
2898+
input: TReal,
28982899
normalized_shape: Sequence[int],
2899-
weight: Optional[TensorType] = None,
2900-
bias: Optional[TensorType] = None,
2900+
weight: Optional[TReal] = None,
2901+
bias: Optional[TReal] = None,
29012902
eps: float = 1e-05,
2902-
cudnn_enable: bool = True,
2903-
) -> TensorType:
2903+
) -> TReal:
29042904
"""layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor"""
29052905

2906-
raise NotImplementedError()
2906+
axes_list = [-i for i in range(len(normalized_shape), 0, -1)]
2907+
start_axis = axes_list[0]
2908+
if not op.OptionalHasElement(weight):
2909+
one = op.Constant(value_float=1.0)
2910+
weight = op.Expand(one, op.Shape(input, start=start_axis))
2911+
if not op.OptionalHasElement(bias):
2912+
zero = op.Constant(value_float=0.0)
2913+
bias = op.Expand(zero, op.Shape(input, start=start_axis))
2914+
2915+
result, _, _ = op.LayerNormalization(input, weight, bias, axis=start_axis, epsilon=eps)
2916+
return result
29072917

29082918

29092919
def aten_lcm(self: TensorType, other: TensorType) -> TensorType:
@@ -3966,12 +3976,13 @@ def aten_native_layer_norm(
39663976
# where D is the dimension of normalized_shape. For example, if normalized_shape is
39673977
# (3, 5) (a 2-dimensional shape), the mean and standard-deviation are computed
39683978
# over the last 2 dimensions of the input (i.e. input.mean((-2, -1))).
3969-
axes = [-i for i in range(len(normalized_shape), 0, -1)]
3970-
if weight is None:
3979+
axes_list = [-i for i in range(len(normalized_shape), 0, -1)]
3980+
axes = op.Constant(value_ints=axes_list)
3981+
if not op.OptionalHasElement(weight):
39713982
weight = op.CastLike(1, input)
3972-
if bias is None:
3983+
if not op.OptionalHasElement(bias):
39733984
bias = op.CastLike(0, input)
3974-
return _aten_native_layer_norm_onnx(input, weight, bias, axes=axes, eps=eps)
3985+
return _aten_native_layer_norm_onnx(input, weight, bias, axes, eps)
39753986

39763987

39773988
@torch_op("aten::native_layer_norm", overload=True)
@@ -3984,18 +3995,18 @@ def _aten_native_layer_norm_onnx(
39843995
) -> Tuple[TReal, TReal, TReal]:
39853996

39863997
# FIXME(justinchuby): Use opset18 when it is supported by onnxruntime
3987-
mean = opset17.ReduceMean(input, axes=axes)
3988-
numerator = opset17.Sub(input, mean)
3989-
power_num = opset17.Pow(numerator, 2.0)
3990-
variance = opset17.ReduceMean(power_num, axes=axes)
3991-
variance_eps = opset17.Add(variance, eps)
3992-
denominator = opset17.Sqrt(variance_eps)
3993-
result = opset17.Div(numerator, denominator)
3994-
weight = opset17.CastLike(weight, result)
3995-
result = opset17.Mul(result, weight)
3996-
bias = opset17.CastLike(bias, result)
3997-
result = opset17.Add(result, bias)
3998-
rdenominator = opset17.Reciprocal(denominator)
3998+
mean = op.ReduceMean(input, axes)
3999+
numerator = op.Sub(input, mean)
4000+
power_num = op.Pow(numerator, 2.0)
4001+
variance = op.ReduceMean(power_num, axes)
4002+
variance_eps = op.Add(variance, eps)
4003+
denominator = op.Sqrt(variance_eps)
4004+
result = op.Div(numerator, denominator)
4005+
weight = op.CastLike(weight, result)
4006+
result = op.Mul(result, weight)
4007+
bias = op.CastLike(bias, result)
4008+
result = op.Add(result, bias)
4009+
rdenominator = op.Reciprocal(denominator)
39994010
return result, mean, rdenominator
40004011

40014012

onnxscript/tests/function_libs/torch_aten/extra_opinfo.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,44 @@ def sample_inputs_convolution(op_info, device, dtype, requires_grad, **kwargs):
159159
)
160160

161161

162+
def sample_inputs_layer_norm(
163+
op_info, device, dtype, requires_grad, **kwargs # pylint: disable=unused-argument
164+
):
165+
make_arg = functools.partial(
166+
torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
167+
)
168+
169+
# Ordered as input shape, normalized_shape, eps
170+
cases: tuple[tuple[int], tuple[int], float] = ( # type: ignore[assignment]
171+
((1, 2, 3), (1, 2, 3), 0.5),
172+
((2, 2, 3), (2, 3), -0.5),
173+
((1,), (1,), 1e-5),
174+
((1, 2), (2,), 1e-5),
175+
((0, 1), (1,), 1e-5),
176+
)
177+
178+
for input_shape, normalized_shape, eps in cases: # type: ignore[misc]
179+
# Shape of weight and bias should be the same as normalized_shape
180+
weight = make_arg(normalized_shape) # type: ignore[has-type]
181+
bias = make_arg(normalized_shape) # type: ignore[has-type]
182+
yield opinfo_core.SampleInput(
183+
make_arg(input_shape), # type: ignore[has-type]
184+
args=(normalized_shape, weight, bias, eps), # type: ignore[has-type]
185+
)
186+
yield opinfo_core.SampleInput(
187+
make_arg(input_shape), # type: ignore[has-type]
188+
args=(normalized_shape, None, bias, eps), # type: ignore[has-type]
189+
)
190+
yield opinfo_core.SampleInput(
191+
make_arg(input_shape), # type: ignore[has-type]
192+
args=(normalized_shape, weight, None, eps), # type: ignore[has-type]
193+
)
194+
yield opinfo_core.SampleInput(
195+
make_arg(input_shape), # type: ignore[has-type]
196+
args=(normalized_shape, None, None, eps), # type: ignore[has-type]
197+
)
198+
199+
162200
OP_DB: List[opinfo_core.OpInfo] = [
163201
opinfo_core.OpInfo(
164202
"convolution",
@@ -184,4 +222,16 @@ def sample_inputs_convolution(op_info, device, dtype, requires_grad, **kwargs):
184222
skips=(),
185223
supports_out=False,
186224
),
225+
opinfo_core.OpInfo(
226+
"layer_norm",
227+
aliases=("layer_norm",),
228+
aten_name="layer_norm",
229+
dtypes=common_dtype.floating_and_complex_types_and(torch.int64, torch.bfloat16),
230+
sample_inputs_func=sample_inputs_layer_norm,
231+
supports_forward_ad=True,
232+
supports_fwgrad_bwgrad=True,
233+
gradcheck_nondet_tol=common_utils.GRADCHECK_NONDET_TOL,
234+
skips=(),
235+
supports_out=False,
236+
),
187237
]

onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,7 @@ def _where_input_wrangler(
398398
"convolution": core_ops.aten_convolution,
399399
"empty_like": core_ops.aten_empty_like,
400400
"index_select": core_ops.aten_index_select,
401+
"layer_norm": core_ops.aten_layer_norm,
401402
"max": core_ops.aten_max,
402403
"native_layer_norm": core_ops.aten_native_layer_norm,
403404
"new_empty": core_ops.aten_new_empty,
@@ -747,6 +748,10 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
747748
inputs=repr(inputs),
748749
kwargs=repr(cpu_sample.kwargs),
749750
):
751+
752+
if i == 5:
753+
print(i)
754+
750755
skip_reason = _should_skip_test_sample(op.name, cpu_sample)
751756
if skip_reason is not None:
752757
# Cannot use self.skip because pytest would skip the entire test

0 commit comments

Comments (0)