Commit bf84f87

Merge branch 'main' into titaiwang/expand_support_minus1

2 parents: 2133898 + b0d4f24

3 files changed (+137, -44 lines)

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 85 additions & 36 deletions
@@ -2895,17 +2895,27 @@ def aten_kthvalue(
     raise NotImplementedError()
 
 
+@torch_op("aten::layer_norm", trace_only=True)
 def aten_layer_norm(
-    input: TensorType,
+    input: TReal,
     normalized_shape: Sequence[int],
-    weight: Optional[TensorType] = None,
-    bias: Optional[TensorType] = None,
+    weight: Optional[TReal] = None,
+    bias: Optional[TReal] = None,
     eps: float = 1e-05,
-    cudnn_enable: bool = True,
-) -> TensorType:
+) -> TReal:
     """layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor"""
 
-    raise NotImplementedError()
+    axes_list = [-i for i in range(len(normalized_shape), 0, -1)]
+    start_axis = axes_list[0]
+    if not op.OptionalHasElement(weight):
+        one = op.Constant(value_float=1.0)
+        weight = op.Expand(one, op.Shape(input, start=start_axis))
+    if not op.OptionalHasElement(bias):
+        zero = op.Constant(value_float=0.0)
+        bias = op.Expand(zero, op.Shape(input, start=start_axis))
+
+    result, _, _ = op.LayerNormalization(input, weight, bias, axis=start_axis, epsilon=eps)
+    return result
 
 
 def aten_lcm(self: TensorType, other: TensorType) -> TensorType:
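Note on the hunk above: the axes_list/start_axis arithmetic follows from layer norm's contract, where a normalized_shape of length D means the last D axes are normalized. A minimal plain-Python sketch of that index computation (illustrative only; the real function runs it at trace time):

    # Plain-Python restatement of the axis arithmetic in aten_layer_norm.
    def layer_norm_axes(normalized_shape):
        d = len(normalized_shape)
        return [-i for i in range(d, 0, -1)]  # the last d axes

    assert layer_norm_axes((2, 3)) == [-2, -1]         # start_axis == -2
    assert layer_norm_axes((3, 5, 7)) == [-3, -2, -1]  # start_axis == -3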
@@ -3259,10 +3269,58 @@ def aten_matrix_power(self: TensorType, n: int) -> TensorType:
     raise NotImplementedError()
 
 
-def aten_max(self: TensorType) -> TensorType:
+@torch_op("aten::max", trace_only=True)
+def aten_max(
+    self: TReal, dim_or_other: Union[TReal, INT64] = None, keepdim: BOOL = None
+) -> TReal:
     """max(Tensor self) -> Tensor"""
 
-    raise NotImplementedError()
+    self_rank = op.Size(op.Shape(self))
+    if self_rank == 0:
+        self = op.Reshape(self, op.Constant(value_int=[-1]))
+
+    output = 1
+
+    if op.OptionalHasElement(dim_or_other):
+        if isinstance(dim_or_other, int):
+            if not op.OptionalHasElement(keepdim):
+                keepdim = False
+            result, indices = _aten_max_with_dim(self, dim_or_other, keepdim)
+            output = 2
+        else:  # dim_or_other is tensor
+            result = _aten_max_with_other(self, dim_or_other)
+    else:
+        result = _aten_max_with_no_dim(self)
+
+    if self_rank == 0:
+        result = op.Squeeze(result)
+
+    if output == 2:
+        if self_rank == 0:
+            indices = op.Squeeze(indices)  # type: ignore[has-type]
+        return result, indices
+    return result
+
+
+@torch_op("aten::max", overload=True)
+def _aten_max_with_no_dim(self: TReal) -> TReal:
+    result = op.ReduceMax(self, keepdims=0)
+    return result
+
+
+@torch_op("aten::max", overload=True)
+def _aten_max_with_other(self: TReal, other: TReal) -> TReal:
+    result = op.Max(self, other)
+    return result
+
+
+@torch_op("aten::max", overload=True)
+# def _aten_max_with_dim(self: TReal, dim: int, keepdim: bool) -> tuple[TReal, TInt]:
+def _aten_max_with_dim(self: TReal, dim: int, keepdim: bool):
+    dims = op.Reshape(dim, op.Constant(value_int=[-1]))
+    result = op.ReduceMax(self, dims, keepdims=keepdim)
+    indices = op.ArgMax(self, axis=dim, keepdims=keepdim)
+    return result, indices
 
 
 def aten_max_pool1d(
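Because the wrapper above is trace_only, the dim_or_other branch is resolved in Python at trace time rather than inside the graph, folding the three aten::max overloads (max(Tensor), max.dim, max.other) into one entry point. A numpy sketch of the same three-way dispatch, with hypothetical helper names that are not part of the onnxscript API:

    import numpy as np

    def max_reduce(x):                   # max(Tensor) -> Tensor
        return np.max(x)

    def max_along_dim(x, dim, keepdim):  # max.dim -> (values, indices)
        values = np.max(x, axis=dim, keepdims=keepdim)
        indices = np.argmax(x, axis=dim)
        if keepdim:
            indices = np.expand_dims(indices, axis=dim)
        return values, indices

    def max_elementwise(x, other):       # max.other -> elementwise maximum
        return np.maximum(x, other)

    def dispatch_max(x, dim_or_other=None, keepdim=False):
        # Mirrors the trace-time branching in aten_max above.
        if dim_or_other is None:
            return max_reduce(x)
        if isinstance(dim_or_other, int):
            return max_along_dim(x, dim_or_other, keepdim)
        return max_elementwise(x, dim_or_other)

    x = np.array([[1.0, 5.0], [3.0, 2.0]])
    print(dispatch_max(x))                        # 5.0
    print(dispatch_max(x, 1))                     # (array([5., 3.]), array([1, 0]))
    print(dispatch_max(x, np.full_like(x, 2.5)))  # elementwise maximum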
@@ -3920,12 +3978,13 @@ def aten_native_layer_norm(
     # where D is the dimension of normalized_shape. For example, if normalized_shape is
     # (3, 5) (a 2-dimensional shape), the mean and standard-deviation are computed
     # over the last 2 dimensions of the input (i.e. input.mean((-2, -1))).
-    axes = [-i for i in range(len(normalized_shape), 0, -1)]
-    if weight is None:
+    axes_list = [-i for i in range(len(normalized_shape), 0, -1)]
+    axes = op.Constant(value_ints=axes_list)
+    if not op.OptionalHasElement(weight):
         weight = op.CastLike(1, input)
-    if bias is None:
+    if not op.OptionalHasElement(bias):
         bias = op.CastLike(0, input)
-    return _aten_native_layer_norm_onnx(input, weight, bias, axes=axes, eps=eps)
+    return _aten_native_layer_norm_onnx(input, weight, bias, axes, eps)
 
 
 @torch_op("aten::native_layer_norm", overload=True)
@@ -3938,18 +3997,18 @@ def _aten_native_layer_norm_onnx(
 ) -> Tuple[TReal, TReal, TReal]:
 
     # FIXME(justinchuby): Use opset18 when it is supported by onnxruntime
-    mean = opset17.ReduceMean(input, axes=axes)
-    numerator = opset17.Sub(input, mean)
-    power_num = opset17.Pow(numerator, 2.0)
-    variance = opset17.ReduceMean(power_num, axes=axes)
-    variance_eps = opset17.Add(variance, eps)
-    denominator = opset17.Sqrt(variance_eps)
-    result = opset17.Div(numerator, denominator)
-    weight = opset17.CastLike(weight, result)
-    result = opset17.Mul(result, weight)
-    bias = opset17.CastLike(bias, result)
-    result = opset17.Add(result, bias)
-    rdenominator = opset17.Reciprocal(denominator)
+    mean = op.ReduceMean(input, axes)
+    numerator = op.Sub(input, mean)
+    power_num = op.Pow(numerator, 2.0)
+    variance = op.ReduceMean(power_num, axes)
+    variance_eps = op.Add(variance, eps)
+    denominator = op.Sqrt(variance_eps)
+    result = op.Div(numerator, denominator)
+    weight = op.CastLike(weight, result)
+    result = op.Mul(result, weight)
+    bias = op.CastLike(bias, result)
+    result = op.Add(result, bias)
+    rdenominator = op.Reciprocal(denominator)
     return result, mean, rdenominator
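The decomposition above computes result = (input - mean) / sqrt(var + eps) * weight + bias from ONNX primitives, and returns the mean and the reciprocal of the denominator (rstd) as the extra outputs native_layer_norm expects. A minimal numpy check of that algebra (an illustration, not the onnxscript code path):

    import numpy as np

    def native_layer_norm_ref(x, weight, bias, axes, eps):
        # Same algebra as the ONNX decomposition above.
        mean = x.mean(axis=axes, keepdims=True)
        var = ((x - mean) ** 2).mean(axis=axes, keepdims=True)
        denom = np.sqrt(var + eps)
        result = (x - mean) / denom * weight + bias
        return result, mean, 1.0 / denom  # (normalized, mean, rstd)

    x = np.random.rand(2, 3, 5).astype(np.float32)
    w = np.ones((3, 5), np.float32)
    b = np.zeros((3, 5), np.float32)
    out, mean, rstd = native_layer_norm_ref(x, w, b, axes=(-2, -1), eps=1e-5)
    # With unit weight and zero bias, each sample is normalized to mean ~0:
    assert np.allclose(out.mean(axis=(-2, -1)), 0.0, atol=1e-5)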

@@ -5055,20 +5114,10 @@ def aten_square(self: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
-@torch_op("aten::squeeze", trace_only=True)
-def aten_squeeze(self: TTensor, dim: Optional[int] = None) -> TTensor:
+def aten_squeeze(self: TensorType) -> TensorType:
     """squeeze(Tensor(a) self) -> Tensor(a)"""
 
-    if op.OptionalHasElement(dim):
-        rank = op.Size(op.Shape(self))
-        if rank == 0:
-            self = op.Reshape(self, op.Constant(value_ints=[-1]))
-        dims = op.Reshape(dim, op.Constant(value_ints=[-1]))
-        result = op.Squeeze(self, dims)
-    else:
-        result = op.Squeeze(self)
-
-    return result
+    raise NotImplementedError()
 
 
 def aten_squeeze_copy(self: TensorType) -> TensorType:

onnxscript/tests/function_libs/torch_aten/extra_opinfo.py

Lines changed: 50 additions & 0 deletions
@@ -159,6 +159,44 @@ def sample_inputs_convolution(op_info, device, dtype, requires_grad, **kwargs):
     )
 
 
+def sample_inputs_layer_norm(
+    op_info, device, dtype, requires_grad, **kwargs  # pylint: disable=unused-argument
+):
+    make_arg = functools.partial(
+        torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+
+    # Ordered as input shape, normalized_shape, eps
+    cases: tuple[tuple[int], tuple[int], float] = (  # type: ignore[assignment]
+        ((1, 2, 3), (1, 2, 3), 0.5),
+        ((2, 2, 3), (2, 3), -0.5),
+        ((1,), (1,), 1e-5),
+        ((1, 2), (2,), 1e-5),
+        ((0, 1), (1,), 1e-5),
+    )
+
+    for input_shape, normalized_shape, eps in cases:  # type: ignore[misc]
+        # Shape of weight and bias should be the same as normalized_shape
+        weight = make_arg(normalized_shape)  # type: ignore[has-type]
+        bias = make_arg(normalized_shape)  # type: ignore[has-type]
+        yield opinfo_core.SampleInput(
+            make_arg(input_shape),  # type: ignore[has-type]
+            args=(normalized_shape, weight, bias, eps),  # type: ignore[has-type]
+        )
+        yield opinfo_core.SampleInput(
+            make_arg(input_shape),  # type: ignore[has-type]
+            args=(normalized_shape, None, bias, eps),  # type: ignore[has-type]
+        )
+        yield opinfo_core.SampleInput(
+            make_arg(input_shape),  # type: ignore[has-type]
+            args=(normalized_shape, weight, None, eps),  # type: ignore[has-type]
+        )
+        yield opinfo_core.SampleInput(
+            make_arg(input_shape),  # type: ignore[has-type]
+            args=(normalized_shape, None, None, eps),  # type: ignore[has-type]
+        )
+
+
 OP_DB: List[opinfo_core.OpInfo] = [
     opinfo_core.OpInfo(
         "convolution",
@@ -184,4 +222,16 @@ def sample_inputs_convolution(op_info, device, dtype, requires_grad, **kwargs):
         skips=(),
         supports_out=False,
     ),
+    opinfo_core.OpInfo(
+        "layer_norm",
+        aliases=("layer_norm",),
+        aten_name="layer_norm",
+        dtypes=common_dtype.floating_and_complex_types_and(torch.int64, torch.bfloat16),
+        sample_inputs_func=sample_inputs_layer_norm,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        gradcheck_nondet_tol=common_utils.GRADCHECK_NONDET_TOL,
+        skips=(),
+        supports_out=False,
+    ),
 ]
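Each of the five cases yields four SampleInputs (weight and bias each supplied or None), so the OpInfo above exercises twenty samples per dtype. A quick way to confirm the count, assuming the module is importable from a repo checkout:

    import torch
    from onnxscript.tests.function_libs.torch_aten import extra_opinfo

    # op_info is unused by the generator, so None is fine here.
    samples = list(
        extra_opinfo.sample_inputs_layer_norm(None, "cpu", torch.float32, False)
    )
    assert len(samples) == 5 * 4  # 5 (shape, normalized_shape, eps) cases x 4 combos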

onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 2 additions & 8 deletions
@@ -398,6 +398,8 @@ def _where_input_wrangler(
     "convolution": core_ops.aten_convolution,
     "empty_like": core_ops.aten_empty_like,
     "index_select": core_ops.aten_index_select,
+    "layer_norm": core_ops.aten_layer_norm,
+    "max": core_ops.aten_max,
     "native_layer_norm": core_ops.aten_native_layer_norm,
     "new_empty": core_ops.aten_new_empty,
     "new_empty_strided": core_ops.aten_new_empty_strided,
@@ -412,7 +414,6 @@ def _where_input_wrangler(
     ),
     "ones_like": core_ops.aten_ones_like,
     "slice": core_ops.aten_slice,
-    "squeeze": core_ops.aten_squeeze,
     "sum": (core_ops.aten_sum_dim_IntList, _sum_input_wrangler),
     "transpose": core_ops.aten_transpose,
     "zeros_like": core_ops.aten_zeros_like,
@@ -557,13 +558,6 @@ def _where_input_wrangler(
         matcher=lambda sample: len(sample.args[0]) == 0,
         reason="Empty perm is not supported",
     ),
-    skip(
-        "squeeze",
-        matcher=lambda sample: len(sample.args) > 0
-        and len(sample.input.shape) > 0
-        and sample.input.shape[sample.args[0]] != 1,
-        reason="Cannot select an axis to squeeze out which has size not equal to one",
-    ),
 )
 
 duplicate_opinfo(
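For orientation: the two entries added to the mapping above pair a torch op name with the onnxscript function under test, and the harness iterates that mapping against the OpInfo samples. A condensed, hypothetical miniature of the pattern (not the actual ops_correctness_test.py code):

    import numpy as np
    import torch

    # Stand-in for the name -> implementation mapping; np.max plays the role
    # of core_ops.aten_max here.
    IMPLS = {"max": lambda x: np.max(x)}

    def check(op_name: str, sample: torch.Tensor) -> None:
        expected = getattr(torch, op_name)(sample)  # eager torch result
        actual = IMPLS[op_name](sample.numpy())     # function under test
        np.testing.assert_allclose(actual, expected.numpy(), rtol=1e-5)

    check("max", torch.rand(3, 4))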
