Skip to content

Commit 8a6f406

Browse files
authored
Merge branch 'main' into new_ops_1
2 parents 37e1f11 + dfaf174 commit 8a6f406

File tree

9 files changed

+168
-65
lines changed

9 files changed

+168
-65
lines changed

azure-pipelines.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ steps:
2525
python -m pip install -r requirements-dev.txt
2626
displayName: 'Install dependencies'
2727

28-
# TODO(#249): Fix tests for onnx 1.13
2928
- script: |
3029
if [ '$(onnx.standard)' == '1' ]
3130
then

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 66 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -186,15 +186,15 @@ def aten_alpha_dropout(input: TensorType, p: float, train: bool) -> TensorType:
186186
raise NotImplementedError()
187187

188188

189-
# @torch_op("aten::amax") # FIXME(#249): Uncomment when CI uses onnx 1.13
189+
@torch_op("aten::amax")
190190
def aten_amax(self: TReal, dim: INT64, keepdim: int = 0) -> TReal:
191191
# amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
192192

193193
# TODO(justinchuby): Make dim optional, keepdim bool
194194
return op.ReduceMax(self, dim, keepdims=keepdim)
195195

196196

197-
# @torch_op("aten::amin") # FIXME(#249): Uncomment when CI uses onnx 1.13
197+
@torch_op("aten::amin")
198198
def aten_amin(self: TReal, dim: INT64, keepdim: int = 0) -> TReal:
199199
# amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
200200

@@ -710,10 +710,15 @@ def aten_cartesian_prod(tensors: Sequence[TensorType]) -> TensorType:
710710
raise NotImplementedError()
711711

712712

713-
def aten_cat(tensors: Sequence[TensorType], dim: int = 0) -> TensorType:
713+
@torch_op("aten::cat", trace_only=True)
714+
def aten_cat(tensors: Sequence[TTensor], dim: int = 0) -> TTensor:
714715
# cat(Tensor[] tensors, int dim=0) -> Tensor
715716

716-
raise NotImplementedError()
717+
num_of_input = len(tensors)  # len() function is not supported yet
718+
a = op.SequenceEmpty()
719+
for i in range(num_of_input):
720+
a = op.SequenceInsert(a, tensors[i])
721+
return op.ConcatFromSequence(a, axis=dim)
717722

718723

719724
def aten_ccol_indices(self: TensorType) -> TensorType:
@@ -1506,16 +1511,15 @@ def aten_einsum(
15061511
raise NotImplementedError()
15071512

15081513

1514+
@torch_op("aten::embedding")
15091515
def aten_embedding(
1510-
weight: TensorType,
1511-
indices: TensorType,
1512-
padding_idx: int = -1,
1513-
scale_grad_by_freq: bool = False,
1514-
sparse: bool = False,
1515-
) -> TensorType:
1516+
weight: TTensor,
1517+
indices: TTensor,
1518+
**_,
1519+
) -> TTensor:
15161520
# embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor
15171521

1518-
raise NotImplementedError()
1522+
return op.Gather(weight, indices)
15191523

15201524

15211525
def aten_embedding_backward(
@@ -1570,10 +1574,29 @@ def aten_embedding_sparse_backward(
15701574
raise NotImplementedError()
15711575

15721576

1573-
def aten_empty_like(self: TensorType, memory_format: Optional[str] = None) -> TensorType:
1577+
@torch_op("aten::empty")
1578+
def aten_empty(size: IntType, dtype: int = FLOAT.dtype) -> TTensor: # type: ignore[type-var]
1579+
# empty(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
1580+
1581+
# using Zeros to simulate np.empty()
1582+
size = op.Cast(size, to=INT64.dtype)
1583+
zero = op.Constant(value_float=0)
1584+
zero = op.Cast(zero, to=dtype)
1585+
1586+
return op.Expand(zero, size)
1587+
1588+
1589+
@torch_op("aten::empty_like")
1590+
def aten_empty_like(self: TTensor, dtype: int = -1) -> TTensor:
15741591
# empty_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
15751592

1576-
raise NotImplementedError()
1593+
shape = op.Shape(self)
1594+
if dtype == -1:
1595+
zero = op.CastLike(0, self)
1596+
else:
1597+
zero = op.Cast(0, to=dtype)
1598+
1599+
return op.Expand(zero, shape)
15771600

15781601

15791602
def aten_empty_quantized(
@@ -1957,10 +1980,11 @@ def aten_gcd(self: TensorType, other: TensorType) -> TensorType:
19571980
raise NotImplementedError()
19581981

19591982

1960-
def aten_ge(self: TensorType, other: TensorType) -> TensorType:
1983+
@torch_op("aten::ge")
1984+
def aten_ge(self: TReal, other: TReal) -> BOOL:
19611985
# ge.Tensor(Tensor self, Tensor other) -> Tensor
19621986

1963-
raise NotImplementedError()
1987+
return op.Greater(self, other)
19641988

19651989

19661990
def aten_geqrf(self: TensorType) -> tuple[TensorType, TensorType]:
@@ -2514,10 +2538,11 @@ def aten_ldexp(self: TensorType, other: TensorType) -> TensorType:
25142538
raise NotImplementedError()
25152539

25162540

2517-
def aten_le(self: TensorType, other: TensorType) -> TensorType:
2541+
@torch_op("aten::le")
2542+
def aten_le(self: TReal, other: TReal) -> BOOL:
25182543
# le.Tensor(Tensor self, Tensor other) -> Tensor
25192544

2520-
raise NotImplementedError()
2545+
return op.Less(self, other)
25212546

25222547

25232548
def aten_lerp(self: TensorType, end: TensorType, weight: TensorType) -> TensorType:
@@ -2680,7 +2705,7 @@ def aten_logspace(start: float, end: float, steps: int, base: float = 10.0) -> T
26802705
raise NotImplementedError()
26812706

26822707

2683-
@torch_op("aten::logsumexp", trace_only=True) # FIXME(#249): Script when CI uses onnx 1.13
2708+
@torch_op("aten::logsumexp")
26842709
def aten_logsumexp(self: TReal, dim: INT64, keepdim: int = False) -> TReal:
26852710
# logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor
26862711

@@ -2902,10 +2927,11 @@ def aten_max_pool3d(
29022927
raise NotImplementedError()
29032928

29042929

2905-
def aten_maximum(self: TensorType, other: TensorType) -> TensorType:
2930+
@torch_op("aten::maximum")
2931+
def aten_maximum(self: TReal, other: TReal) -> TReal:
29062932
# maximum(Tensor self, Tensor other) -> Tensor
29072933

2908-
raise NotImplementedError()
2934+
return op.Max(self, other)
29092935

29102936

29112937
def aten_mean(self: TensorType, dtype: Optional[int] = None) -> TensorType:
@@ -2932,10 +2958,11 @@ def aten_min(self: TensorType) -> TensorType:
29322958
raise NotImplementedError()
29332959

29342960

2935-
def aten_minimum(self: TensorType, other: TensorType) -> TensorType:
2961+
@torch_op("aten::minimum")
2962+
def aten_minimum(self: TReal, other: TReal) -> TReal:
29362963
# minimum(Tensor self, Tensor other) -> Tensor
29372964

2938-
raise NotImplementedError()
2965+
return op.Min(self, other)
29392966

29402967

29412968
def aten_miopen_batch_norm(
@@ -4393,16 +4420,30 @@ def aten_sinh(self: TFloat) -> TFloat:
43934420
return op.Sinh(self)
43944421

43954422

4423+
@torch_op("aten::slice")
43964424
def aten_slice(
4397-
self: TensorType,
4425+
self: TTensor,
43984426
dim: int = 0,
43994427
start: Optional[INT64] = None,
44004428
end: Optional[INT64] = None,
44014429
step: INT64 = 1,
4402-
) -> TensorType:
4430+
) -> TTensor:
44034431
# slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)
44044432

4405-
raise NotImplementedError()
4433+
# TODO: using OptionalHasElement() to check start/end value
4434+
start = op.Cast(start, to=INT64.dtype)
4435+
start = op.Reshape(start, op.Constant(value_ints=[-1]))
4436+
4437+
end = op.Cast(end, to=INT64.dtype)
4438+
end = op.Reshape(end, op.Constant(value_ints=[-1]))
4439+
4440+
dim = op.Cast(dim, to=INT64.dtype)
4441+
dim = op.Reshape(dim, op.Constant(value_ints=[-1]))
4442+
4443+
step = op.Cast(step, to=INT64.dtype)
4444+
step = op.Reshape(step, op.Constant(value_ints=[-1]))
4445+
4446+
return op.Slice(self, start, end, dim, step)
44064447

44074448

44084449
def aten_slice_backward(

onnxscript/function_libs/torch_aten/ops/special.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from typing import Optional, Sequence
1515

16+
from onnxscript import FLOAT
1617
from onnxscript.function_libs.torch_aten.registration import torch_op
1718
from onnxscript.function_libs.torch_aten.typing import TFloatOrBFloat16
1819
from onnxscript.onnx_opset import opset18 as op
@@ -205,10 +206,20 @@ def aten_special_log_ndtr(self: TensorType) -> TensorType:
205206
raise NotImplementedError()
206207

207208

208-
def aten_special_log_softmax(self: TensorType, dim: int, dtype: int = -1) -> TensorType:
209+
@torch_op("aten::log_softmax")
210+
def aten_special_log_softmax(
211+
self: TFloatOrBFloat16, dim: int, dtype: int = FLOAT.dtype
212+
) -> TFloatOrBFloat16:
209213
# special_log_softmax(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor
210214

211-
raise NotImplementedError()
215+
self_is_scalar = op.Size(op.Shape(self)) == 0
216+
if self_is_scalar:
217+
self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
218+
result = op.LogSoftmax(self, axis=dim)
219+
result = op.Cast(result, to=dtype)
220+
if self_is_scalar:  # squeeze back to a scalar because the input was a scalar
221+
result = op.Squeeze(result)
222+
return result
212223

213224

214225
def aten_special_logit(self: TensorType, eps: Optional[float] = None) -> TensorType:
@@ -327,12 +338,21 @@ def aten_special_sinc(self: TensorType) -> TensorType:
327338
raise NotImplementedError()
328339

329340

341+
@torch_op("aten::softmax")
330342
def aten_special_softmax(
331-
self: TensorType, dim: int, dtype: Optional[int] = None
332-
) -> TensorType:
343+
self: TFloatOrBFloat16, dim: int, dtype: int = FLOAT.dtype
344+
) -> TFloatOrBFloat16:
333345
# special_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
334346

335-
raise NotImplementedError()
347+
self_is_scalar = op.Size(op.Shape(self)) == 0
348+
if self_is_scalar:
349+
self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
350+
result = op.Softmax(self, axis=dim)
351+
result = op.Cast(result, to=dtype)
352+
if self_is_scalar: # squeeze to scalar due to input is scalar
353+
result = op.Squeeze(result)
354+
355+
return result
336356

337357

338358
def aten_special_spherical_bessel_j0(x: TensorType) -> TensorType:

0 commit comments

Comments (0)