
Commit 42be235

add empty/empty_like/log_softmax/cat function (#264)
add ops: empty, empty_like, log_softmax, cat

Signed-off-by: xiaowuhu <[email protected]>
1 parent 1b81421 commit 42be235

File tree

3 files changed: +55 −7 lines


onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 28 additions & 4 deletions
@@ -710,10 +710,15 @@ def aten_cartesian_prod(tensors: Sequence[TensorType]) -> TensorType:
     raise NotImplementedError()
 
 
-def aten_cat(tensors: Sequence[TensorType], dim: int = 0) -> TensorType:
+@torch_op("aten::cat", trace_only=True)
+def aten_cat(tensors: Sequence[TTensor], dim: int = 0) -> TTensor:
     # cat(Tensor[] tensors, int dim=0) -> Tensor
 
-    raise NotImplementedError()
+    num_of_input = len(tensors)  # len() is not supported by onnxscript yet
+    a = op.SequenceEmpty()
+    for i in range(num_of_input):
+        a = op.SequenceInsert(a, tensors[i])
+    return op.ConcatFromSequence(a, axis=dim)
 
 
 def aten_ccol_indices(self: TensorType) -> TensorType:
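
For context, the new aten_cat lowers torch.cat through ONNX tensor sequences: it starts from an empty sequence, inserts each input, and concatenates along dim. A minimal NumPy sketch of the same semantics (illustrative only; cat_reference is a made-up name, not part of the commit):

import numpy as np

def cat_reference(tensors, dim=0):
    seq = []                               # op.SequenceEmpty()
    for t in tensors:
        seq.append(t)                      # op.SequenceInsert(a, tensors[i])
    return np.concatenate(seq, axis=dim)   # op.ConcatFromSequence(a, axis=dim)

a, b = np.ones((2, 3)), np.zeros((2, 3))
assert cat_reference([a, b], dim=0).shape == (4, 3)
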
@@ -1570,10 +1575,29 @@ def aten_embedding_sparse_backward(
     raise NotImplementedError()
 
 
-def aten_empty_like(self: TensorType, memory_format: Optional[str] = None) -> TensorType:
+@torch_op("aten::empty")
+def aten_empty(size: IntType, dtype: int = FLOAT.dtype) -> TTensor:  # type: ignore[type-var]
+    # empty(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+
+    # use zeros to simulate np.empty()
+    size = op.Cast(size, to=INT64.dtype)
+    zero = op.Constant(value_float=0)
+    zero = op.Cast(zero, to=dtype)
+
+    return op.Expand(zero, size)
+
+
+@torch_op("aten::empty_like")
+def aten_empty_like(self: TTensor, dtype: int = -1) -> TTensor:
     # empty_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
 
-    raise NotImplementedError()
+    shape = op.Shape(self)
+    if dtype == -1:
+        zero = op.CastLike(0, self)
+    else:
+        zero = op.Cast(0, to=dtype)
+
+    return op.Expand(zero, shape)
 
 
 def aten_empty_quantized(
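
Both new functions return zeros rather than uninitialized memory: torch.empty leaves values unspecified, so zeros are one valid choice, which is also why the test changes below skip value comparison for empty and empty_like. A rough NumPy sketch of the behavior (illustrative only; the helper names are made up):

import numpy as np

def empty_reference(size, dtype=np.float32):
    zero = np.zeros((), dtype=dtype)    # op.Constant + op.Cast
    return np.broadcast_to(zero, size)  # op.Expand(zero, size)

def empty_like_reference(x, dtype=None):
    # dtype=None mirrors the dtype == -1 sentinel, which CastLikes to self
    return empty_reference(x.shape, dtype or x.dtype)

assert empty_reference([2, 3]).shape == (2, 3)
assert empty_like_reference(np.ones(4, dtype=np.int64)).dtype == np.int64
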

onnxscript/function_libs/torch_aten/ops/special.py

Lines changed: 13 additions & 2 deletions
@@ -13,6 +13,7 @@
 
 from typing import Optional, Sequence
 
+from onnxscript import FLOAT
 from onnxscript.function_libs.torch_aten.registration import torch_op
 from onnxscript.function_libs.torch_aten.typing import TFloatOrBFloat16
 from onnxscript.onnx_opset import opset18 as op
@@ -205,10 +206,20 @@ def aten_special_log_ndtr(self: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
-def aten_special_log_softmax(self: TensorType, dim: int, dtype: int = -1) -> TensorType:
+@torch_op("aten::log_softmax")
+def aten_special_log_softmax(
+    self: TFloatOrBFloat16, dim: int, dtype: int = FLOAT.dtype
+) -> TFloatOrBFloat16:
     # special_log_softmax(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor
 
-    raise NotImplementedError()
+    self_is_scalar = op.Size(op.Shape(self)) == 0
+    if self_is_scalar:
+        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
+    result = op.LogSoftmax(self, axis=dim)
+    result = op.Cast(result, to=dtype)
+    if self_is_scalar:  # squeeze back to a scalar since the input was a scalar
+        result = op.Squeeze(result)
+    return result
 
 
 def aten_special_logit(self: TensorType, eps: Optional[float] = None) -> TensorType:
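
The scalar round-trip (Unsqueeze before LogSoftmax, Squeeze after) is presumably needed because a rank-0 tensor has no axis for LogSoftmax to normalize over. A NumPy sketch of the computation (illustrative only; log_softmax_reference is a made-up name):

import numpy as np

def log_softmax_reference(x, dim, dtype=np.float32):
    x = np.asarray(x)
    is_scalar = x.ndim == 0
    if is_scalar:
        x = x[np.newaxis]                         # op.Unsqueeze(self, [0])
    shifted = x - x.max(axis=dim, keepdims=True)  # numerically stable form
    result = shifted - np.log(np.exp(shifted).sum(axis=dim, keepdims=True))
    result = result.astype(dtype)                 # op.Cast(result, to=dtype)
    if is_scalar:
        result = result.squeeze()                 # op.Squeeze(result)
    return result

assert log_softmax_reference(5.0, dim=0) == 0.0  # a lone element has probability 1
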

onnxscript/test/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 14 additions & 1 deletion
@@ -198,6 +198,13 @@ def _logcumsumexp_input_wrangler(
     return args, kwargs
 
 
+def _log_softmax_input_wrangler(
+    args: list[Any], kwargs: dict[str, Any]
+) -> tuple[list[Any], dict[str, Any]]:
+    kwargs["dim"] = args.pop()
+    return args, kwargs
+
+
 def _topk_input_wrangler(
     args: list[Any], kwargs: dict[str, Any]
 ) -> tuple[list[Any], dict[str, Any]]:
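
The wrangler adapts OpInfo sample inputs so that dim reaches the function under test as a keyword argument rather than a trailing positional one. A tiny illustration (the sample values are made up):

args, kwargs = ["<input tensor>", 1], {}
kwargs["dim"] = args.pop()  # what _log_softmax_input_wrangler does
assert args == ["<input tensor>"] and kwargs == {"dim": 1}
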
@@ -238,6 +245,7 @@ def _topk_input_wrangler(
     "atan": core_ops.aten_atan,
     "atanh": core_ops.aten_atanh,
     "bmm": core_ops.aten_bmm,
+    "cat": core_ops.aten_cat,
     "ceil": core_ops.aten_ceil,
     "clamp_max": core_ops.aten_clamp_max,
     "clamp_min": core_ops.aten_clamp_min,
@@ -247,6 +255,8 @@ def _topk_input_wrangler(
     "cosh": core_ops.aten_cosh,
     "div": core_ops.aten_div,
     "dot": core_ops.aten_dot,
+    "empty": core_ops.aten_empty,
+    "empty_like": core_ops.aten_empty_like,
     "eq": core_ops.aten_eq,
     "equal": core_ops.aten_equal,
     "exp": core_ops.aten_exp,
@@ -268,6 +278,7 @@ def _topk_input_wrangler(
     "logcumsumexp": core_ops.aten_logcumsumexp,
     "logdet": core_ops.aten_logdet,
     "logsumexp": (core_ops.aten_logsumexp, _logcumsumexp_input_wrangler),
+    "log_softmax": (special_ops.aten_special_log_softmax, _log_softmax_input_wrangler),
     "lt": core_ops.aten_lt,
     "matmul": core_ops.aten_matmul,
     "mm": core_ops.aten_mm,
@@ -326,7 +337,9 @@ def _topk_input_wrangler(
 EXPECTED_SKIPS_OR_FAILS = (
     xfail("amax", reason="ONNX Runtime 1.13 does not support ReduceMax-18"),
     xfail("amin", reason="ONNX Runtime 1.13 does not support ReduceMin-18"),
-    skip("clamp", reason="enable when onnxscript supports optional inputs"),
+    skip("clamp", reason="Enable when onnxscript supports optional inputs"),
+    skip("empty", reason="Using zeros to simulate empty"),
+    skip("empty_like", reason="Using zeros_like to simulate empty_like"),
     xfail("logcumsumexp", reason="naive implementation not numerically stable"),
     xfail("logsumexp", reason="ONNX Runtime 1.13 does not support ReduceLogSumExp-18"),
     xfail(
