Commit 1da3b9c

[torchlib] Fix layer norm dtype (#2100)
Fix layer norm dtype mismatch errors. Fixes #2099
1 parent 882a442 commit 1da3b9c

File tree

1 file changed: +3 -20 lines

  • onnxscript/function_libs/torch_lib/ops/core.py

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 3 additions & 20 deletions
@@ -31,6 +31,7 @@
     UINT32,
     UINT64,
     graph,
+    ir,
 )
 from onnxscript.function_libs.torch_lib.ops import common as common_ops
 from onnxscript.function_libs.torch_lib.registration import torch_op
@@ -4749,28 +4750,10 @@ def aten_layer_norm(
     start_axis = -len(normalized_shape)
 
     if weight is None:
-        one = op.Constant(value_float=1.0)
+        one = op.Constant(value=ir.tensor(1, dtype=input.dtype))
         weight = op.Expand(one, op.Shape(input, start=start_axis))
 
-    if bias is None:
-        zero = op.Constant(value_float=0.0)
-        bias = op.Expand(zero, op.Shape(input, start=start_axis))
-
-    return _aten_layer_norm_onnx(input, weight, bias, axis=start_axis, eps=eps)
-
-
-@torch_op("aten::layer_norm", private=True)
-def _aten_layer_norm_onnx(
-    input: TReal,
-    weight: TReal,
-    bias: TReal,
-    axis: int,
-    eps: float = 1e-05,
-) -> TReal:
-    """layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor"""
-
-    # TODO(justinchuby): Use OptionalHasElement after onnx/onnx#4982
-    result, _, _ = op.LayerNormalization(input, weight, bias, axis=axis, epsilon=eps)
+    result, _, _ = op.LayerNormalization(input, weight, bias, axis=start_axis, epsilon=eps)
     return result
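Why the fix works: ONNX LayerNormalization constrains X, Scale, and B to the same element type, and its B input is optional, so the default-bias branch can be dropped rather than materializing a float32 zero tensor. The remaining issue was the default weight: op.Constant(value_float=1.0) always produces a float32 scalar, which mismatches float16 or double inputs, while ir.tensor(1, dtype=input.dtype) builds the constant in the input's own dtype. A minimal sketch of that idea in plain numpy (the default_weight_* helpers are illustrative only, not part of the commit or of torch_lib):

# Minimal sketch (not from the commit): a default layer-norm weight hard-coded
# as float32 mismatches a float16 input, while one built from the input's
# dtype matches it. The default_weight_* helpers are hypothetical.
import numpy as np


def default_weight_old(x: np.ndarray, normalized_shape: tuple) -> np.ndarray:
    # Old behaviour: op.Constant(value_float=1.0) is always float32,
    # regardless of the dtype of the layer_norm input.
    return np.ones(normalized_shape, dtype=np.float32)


def default_weight_new(x: np.ndarray, normalized_shape: tuple) -> np.ndarray:
    # New behaviour: mirrors op.Constant(value=ir.tensor(1, dtype=input.dtype))
    # by creating the constant with the same dtype as the input.
    return np.ones(normalized_shape, dtype=x.dtype)


x = np.zeros((2, 3, 4), dtype=np.float16)
print(default_weight_old(x, (4,)).dtype)  # float32 -> dtype mismatch against a float16 X
print(default_weight_new(x, (4,)).dtype)  # float16 -> matches X, as LayerNormalization requires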