Commit dcca04f

Merge branch 'main' into torchscriptgraph_initializer
2 parents c78058b + 5ba4fc7 commit dcca04f

3 files changed: +139 -44 lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 87 additions & 36 deletions
@@ -2039,6 +2039,8 @@ def aten_expand(self: TTensor, size: TInt) -> TTensor:
     """expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)"""

     size = op.Cast(size, to=INT64.dtype)
+    # To support -1 dim.
+    size = op.Abs(size)
     return op.Expand(self, size)

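A quick illustration (not part of this commit, assuming PyTorch expand semantics): in torch.Tensor.expand a size of -1 means "keep this dimension", and ONNX Expand broadcasts multidirectionally, so a target dimension of 1 (= Abs(-1)) also preserves the input dimension.

import torch

x = torch.rand(3, 1, 5)
# -1 keeps the last dimension; dim 1 is broadcast from 1 to 4.
assert x.expand(3, 4, -1).shape == (3, 4, 5)
# The exported graph instead passes Abs(size) = (3, 4, 1) to ONNX Expand;
# broadcasting against the input shape (3, 1, 5) restores the 5.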
@@ -2893,17 +2895,27 @@ def aten_kthvalue(
     raise NotImplementedError()


+@torch_op("aten::layer_norm", trace_only=True)
 def aten_layer_norm(
-    input: TensorType,
+    input: TReal,
     normalized_shape: Sequence[int],
-    weight: Optional[TensorType] = None,
-    bias: Optional[TensorType] = None,
+    weight: Optional[TReal] = None,
+    bias: Optional[TReal] = None,
     eps: float = 1e-05,
-    cudnn_enable: bool = True,
-) -> TensorType:
+) -> TReal:
     """layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor"""

-    raise NotImplementedError()
+    axes_list = [-i for i in range(len(normalized_shape), 0, -1)]
+    start_axis = axes_list[0]
+    if not op.OptionalHasElement(weight):
+        one = op.Constant(value_float=1.0)
+        weight = op.Expand(one, op.Shape(input, start=start_axis))
+    if not op.OptionalHasElement(bias):
+        zero = op.Constant(value_float=0.0)
+        bias = op.Expand(zero, op.Shape(input, start=start_axis))
+
+    result, _, _ = op.LayerNormalization(input, weight, bias, axis=start_axis, epsilon=eps)
+    return result


 def aten_lcm(self: TensorType, other: TensorType) -> TensorType:
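Illustrative check (not from this commit): when weight or bias is omitted, the function above synthesizes ones/zeros over the normalized dimensions, which matches layer_norm's default affine behaviour.

import torch

x = torch.rand(2, 3, 5)
normalized_shape = (3, 5)

no_affine = torch.nn.functional.layer_norm(x, normalized_shape)
default_affine = torch.nn.functional.layer_norm(
    x,
    normalized_shape,
    weight=torch.ones(normalized_shape),
    bias=torch.zeros(normalized_shape),
)
torch.testing.assert_close(no_affine, default_affine)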
@@ -3257,10 +3269,58 @@ def aten_matrix_power(self: TensorType, n: int) -> TensorType:
     raise NotImplementedError()


-def aten_max(self: TensorType) -> TensorType:
+@torch_op("aten::max", trace_only=True)
+def aten_max(
+    self: TReal, dim_or_other: Union[TReal, INT64] = None, keepdim: BOOL = None
+) -> TReal:
     """max(Tensor self) -> Tensor"""

-    raise NotImplementedError()
+    self_rank = op.Size(op.Shape(self))
+    if self_rank == 0:
+        self = op.Reshape(self, op.Constant(value_ints=[-1]))
+
+    output = 1
+
+    if op.OptionalHasElement(dim_or_other):
+        if isinstance(dim_or_other, int):
+            if not op.OptionalHasElement(keepdim):
+                keepdim = False
+            result, indices = _aten_max_with_dim(self, dim_or_other, keepdim)
+            output = 2
+        else:  # dim_or_other is tensor
+            result = _aten_max_with_other(self, dim_or_other)
+    else:
+        result = _aten_max_with_no_dim(self)
+
+    if self_rank == 0:
+        result = op.Squeeze(result)
+
+    if output == 2:
+        if self_rank == 0:
+            indices = op.Squeeze(indices)  # type: ignore[has-type]
+        return result, indices
+    return result
+
+
+@torch_op("aten::max", overload=True)
+def _aten_max_with_no_dim(self: TReal) -> TReal:
+    result = op.ReduceMax(self, keepdims=0)
+    return result
+
+
+@torch_op("aten::max", overload=True)
+def _aten_max_with_other(self: TReal, other: TReal) -> TReal:
+    result = op.Max(self, other)
+    return result
+
+
+@torch_op("aten::max", overload=True)
+# def _aten_max_with_dim(self: TReal, dim: int, keepdim: bool) -> tuple[TReal, TInt]:
+def _aten_max_with_dim(self: TReal, dim: int, keepdim: bool):
+    dims = op.Reshape(dim, op.Constant(value_ints=[-1]))
+    result = op.ReduceMax(self, dims, keepdims=keepdim)
+    indices = op.ArgMax(self, axis=dim, keepdims=keepdim)
+    return result, indices


 def aten_max_pool1d(
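For reference (illustrative only), these are the three torch.max call patterns that the trace-only aten_max above dispatches between: full reduction, reduction along a dim with indices, and elementwise maximum.

import torch

x = torch.tensor([[1.0, 5.0], [3.0, 2.0]])

full = torch.max(x)                             # reduce over all elements -> tensor(5.)
values, indices = torch.max(x, dim=1)           # per-dim reduce -> (values, argmax indices)
elementwise = torch.max(x, torch.ones_like(x))  # elementwise maximum with another tensor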
@@ -3918,12 +3978,13 @@ def aten_native_layer_norm(
     # where D is the dimension of normalized_shape. For example, if normalized_shape is
     # (3, 5) (a 2-dimensional shape), the mean and standard-deviation are computed
     # over the last 2 dimensions of the input (i.e. input.mean((-2, -1))).
-    axes = [-i for i in range(len(normalized_shape), 0, -1)]
-    if weight is None:
+    axes_list = [-i for i in range(len(normalized_shape), 0, -1)]
+    axes = op.Constant(value_ints=axes_list)
+    if not op.OptionalHasElement(weight):
         weight = op.CastLike(1, input)
-    if bias is None:
+    if not op.OptionalHasElement(bias):
         bias = op.CastLike(0, input)
-    return _aten_native_layer_norm_onnx(input, weight, bias, axes=axes, eps=eps)
+    return _aten_native_layer_norm_onnx(input, weight, bias, axes, eps)


 @torch_op("aten::native_layer_norm", overload=True)
@@ -3936,18 +3997,18 @@ def _aten_native_layer_norm_onnx(
 ) -> Tuple[TReal, TReal, TReal]:

     # FIXME(justinchuby): Use opset18 when it is supported by onnxruntime
-    mean = opset17.ReduceMean(input, axes=axes)
-    numerator = opset17.Sub(input, mean)
-    power_num = opset17.Pow(numerator, 2.0)
-    variance = opset17.ReduceMean(power_num, axes=axes)
-    variance_eps = opset17.Add(variance, eps)
-    denominator = opset17.Sqrt(variance_eps)
-    result = opset17.Div(numerator, denominator)
-    weight = opset17.CastLike(weight, result)
-    result = opset17.Mul(result, weight)
-    bias = opset17.CastLike(bias, result)
-    result = opset17.Add(result, bias)
-    rdenominator = opset17.Reciprocal(denominator)
+    mean = op.ReduceMean(input, axes)
+    numerator = op.Sub(input, mean)
+    power_num = op.Pow(numerator, 2.0)
+    variance = op.ReduceMean(power_num, axes)
+    variance_eps = op.Add(variance, eps)
+    denominator = op.Sqrt(variance_eps)
+    result = op.Div(numerator, denominator)
+    weight = op.CastLike(weight, result)
+    result = op.Mul(result, weight)
+    bias = op.CastLike(bias, result)
+    result = op.Add(result, bias)
+    rdenominator = op.Reciprocal(denominator)
     return result, mean, rdenominator

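An illustrative numerical check (not part of the diff) that the ReduceMean/Sub/Pow/Sqrt/Div chain above is the standard layer-norm decomposition, with the reciprocal of the denominator returned as the rstd output.

import torch

x = torch.rand(2, 3, 5)
eps = 1e-5
dims = (-2, -1)  # normalized_shape == (3, 5)

mean = x.mean(dim=dims, keepdim=True)
variance = ((x - mean) ** 2).mean(dim=dims, keepdim=True)
denominator = torch.sqrt(variance + eps)
result = (x - mean) / denominator

torch.testing.assert_close(result, torch.nn.functional.layer_norm(x, (3, 5), eps=eps))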
@@ -5053,20 +5114,10 @@ def aten_square(self: TensorType) -> TensorType:
     raise NotImplementedError()


-@torch_op("aten::squeeze", trace_only=True)
-def aten_squeeze(self: TTensor, dim: Optional[int] = None) -> TTensor:
+def aten_squeeze(self: TensorType) -> TensorType:
     """squeeze(Tensor(a) self) -> Tensor(a)"""

-    if op.OptionalHasElement(dim):
-        rank = op.Size(op.Shape(self))
-        if rank == 0:
-            self = op.Reshape(self, op.Constant(value_ints=[-1]))
-        dims = op.Reshape(dim, op.Constant(value_ints=[-1]))
-        result = op.Squeeze(self, dims)
-    else:
-        result = op.Squeeze(self)
-
-    return result
+    raise NotImplementedError()


 def aten_squeeze_copy(self: TensorType) -> TensorType:

onnxscript/tests/function_libs/torch_aten/extra_opinfo.py

Lines changed: 50 additions & 0 deletions
@@ -159,6 +159,44 @@ def sample_inputs_convolution(op_info, device, dtype, requires_grad, **kwargs):
     )


+def sample_inputs_layer_norm(
+    op_info, device, dtype, requires_grad, **kwargs  # pylint: disable=unused-argument
+):
+    make_arg = functools.partial(
+        torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+
+    # Ordered as input shape, normalized_shape, eps
+    cases: tuple[tuple[int], tuple[int], float] = (  # type: ignore[assignment]
+        ((1, 2, 3), (1, 2, 3), 0.5),
+        ((2, 2, 3), (2, 3), -0.5),
+        ((1,), (1,), 1e-5),
+        ((1, 2), (2,), 1e-5),
+        ((0, 1), (1,), 1e-5),
+    )
+
+    for input_shape, normalized_shape, eps in cases:  # type: ignore[misc]
+        # Shape of weight and bias should be the same as normalized_shape
+        weight = make_arg(normalized_shape)  # type: ignore[has-type]
+        bias = make_arg(normalized_shape)  # type: ignore[has-type]
+        yield opinfo_core.SampleInput(
+            make_arg(input_shape),  # type: ignore[has-type]
+            args=(normalized_shape, weight, bias, eps),  # type: ignore[has-type]
+        )
+        yield opinfo_core.SampleInput(
+            make_arg(input_shape),  # type: ignore[has-type]
+            args=(normalized_shape, None, bias, eps),  # type: ignore[has-type]
+        )
+        yield opinfo_core.SampleInput(
+            make_arg(input_shape),  # type: ignore[has-type]
+            args=(normalized_shape, weight, None, eps),  # type: ignore[has-type]
+        )
+        yield opinfo_core.SampleInput(
+            make_arg(input_shape),  # type: ignore[has-type]
+            args=(normalized_shape, None, None, eps),  # type: ignore[has-type]
+        )
+
+
 OP_DB: List[opinfo_core.OpInfo] = [
     opinfo_core.OpInfo(
         "convolution",
@@ -184,4 +222,16 @@ def sample_inputs_convolution(op_info, device, dtype, requires_grad, **kwargs):
         skips=(),
         supports_out=False,
     ),
+    opinfo_core.OpInfo(
+        "layer_norm",
+        aliases=("layer_norm",),
+        aten_name="layer_norm",
+        dtypes=common_dtype.floating_and_complex_types_and(torch.int64, torch.bfloat16),
+        sample_inputs_func=sample_inputs_layer_norm,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        gradcheck_nondet_tol=common_utils.GRADCHECK_NONDET_TOL,
+        skips=(),
+        supports_out=False,
+    ),
 ]
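Each generated SampleInput above corresponds to a call of the following shape (illustrative only; values taken from the first case), with weight and bias independently dropped so the default-initialization branches in aten_layer_norm are exercised.

import torch

input_shape, normalized_shape, eps = (1, 2, 3), (1, 2, 3), 0.5
x = torch.rand(input_shape)
weight = torch.rand(normalized_shape)
bias = torch.rand(normalized_shape)

for w, b in [(weight, bias), (None, bias), (weight, None), (None, None)]:
    torch.nn.functional.layer_norm(x, normalized_shape, w, b, eps)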

onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 2 additions & 8 deletions
@@ -398,6 +398,8 @@ def _where_input_wrangler(
     "convolution": core_ops.aten_convolution,
     "empty_like": core_ops.aten_empty_like,
     "index_select": core_ops.aten_index_select,
+    "layer_norm": core_ops.aten_layer_norm,
+    "max": core_ops.aten_max,
     "native_layer_norm": core_ops.aten_native_layer_norm,
     "new_empty": core_ops.aten_new_empty,
     "new_empty_strided": core_ops.aten_new_empty_strided,
@@ -412,7 +414,6 @@ def _where_input_wrangler(
     ),
     "ones_like": core_ops.aten_ones_like,
     "slice": core_ops.aten_slice,
-    "squeeze": core_ops.aten_squeeze,
     "sum": (core_ops.aten_sum_dim_IntList, _sum_input_wrangler),
     "transpose": core_ops.aten_transpose,
     "zeros_like": core_ops.aten_zeros_like,
@@ -557,13 +558,6 @@ def _where_input_wrangler(
         matcher=lambda sample: len(sample.args[0]) == 0,
         reason="Empty perm is not supported",
     ),
-    skip(
-        "squeeze",
-        matcher=lambda sample: len(sample.args) > 0
-        and len(sample.input.shape) > 0
-        and sample.input.shape[sample.args[0]] != 1,
-        reason="Cannot select an axis to squeeze out which has size not equal to one",
-    ),
 )

 duplicate_opinfo(
