Skip to content

add empty/empty_like/log_softmax/cat function #264

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 41 commits into from
Jan 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
89bdbb0
add aten::new_empty function
xiaowuhu Dec 25, 2022
3c92572
Update core.py
xiaowuhu Dec 26, 2022
0850518
Update core.py
xiaowuhu Dec 26, 2022
f0e3d00
Update core.py
xiaowuhu Dec 26, 2022
1496f3a
Update core.py
xiaowuhu Dec 26, 2022
8cce23b
Update core.py
xiaowuhu Dec 26, 2022
a75d3aa
Update core.py
xiaowuhu Dec 26, 2022
169a532
Update ops_correctness_test.py
xiaowuhu Dec 26, 2022
690f7bd
update files
xiaowuhu Dec 26, 2022
38c410c
Update core.py
xiaowuhu Dec 27, 2022
7ca5c2c
remove
xiaowuhu Dec 27, 2022
d9917e1
Update core.py
xiaowuhu Dec 27, 2022
9c00791
Update core.py
xiaowuhu Dec 27, 2022
e0ae716
update
xiaowuhu Dec 27, 2022
524836e
update
xiaowuhu Dec 27, 2022
f491426
Update core.py
xiaowuhu Dec 27, 2022
013ee8e
fix bug
xiaowuhu Dec 28, 2022
080ccd7
Update core.py
xiaowuhu Dec 28, 2022
817db54
Update ops_correctness_test.py
xiaowuhu Dec 28, 2022
00733d8
Update ops_correctness_test.py
xiaowuhu Dec 28, 2022
caa523f
Update ops_correctness_test.py
xiaowuhu Dec 28, 2022
c82e423
fix lint
xiaowuhu Dec 28, 2022
9667365
Update core.py
xiaowuhu Dec 28, 2022
1439404
Merge remote-tracking branch 'upstream/main' into xiaowu/trySome
xiaowuhu Jan 9, 2023
59d6ae1
fix bug
xiaowuhu Jan 9, 2023
40ed73b
fix bug
xiaowuhu Jan 9, 2023
1e747db
Update core.py
xiaowuhu Jan 9, 2023
b2b3b52
testing code, for draft
xiaowuhu Jan 9, 2023
d750f4b
Update core.py
xiaowuhu Jan 10, 2023
a8ddb2b
Merge branch 'main' into xiaowu/trySome
xiaowuhu Jan 10, 2023
40d5e0b
Merge remote-tracking branch 'upstream/main' into xiaowu/trySome
xiaowuhu Jan 10, 2023
3eb5f69
Merge remote-tracking branch 'upstream/main' into xiaowu/trySome
xiaowuhu Jan 10, 2023
e0beabd
add ops
xiaowuhu Jan 10, 2023
a205583
Update core.py
xiaowuhu Jan 10, 2023
1e64438
Update core.py
xiaowuhu Jan 10, 2023
3602b8a
Update core.py
xiaowuhu Jan 10, 2023
f833f06
fix bug
xiaowuhu Jan 11, 2023
c507698
Update ops_correctness_test.py
xiaowuhu Jan 11, 2023
48efe33
Merge branch 'main' into xiaowu/trySome
xiaowuhu Jan 11, 2023
e63f9a8
fix lint
xiaowuhu Jan 11, 2023
1018cd1
fix comments
xiaowuhu Jan 11, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions onnxscript/function_libs/torch_aten/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,10 +710,15 @@ def aten_cartesian_prod(tensors: Sequence[TensorType]) -> TensorType:
raise NotImplementedError()


@torch_op("aten::cat", trace_only=True)
def aten_cat(tensors: Sequence[TTensor], dim: int = 0) -> TTensor:
    # cat(Tensor[] tensors, int dim=0) -> Tensor

    # trace_only: the onnxscript compiler does not support len()/Python
    # iteration over graph inputs, so this runs as plain Python at trace
    # time, accumulating the inputs into an ONNX Sequence.
    result_seq = op.SequenceEmpty()
    for tensor in tensors:
        result_seq = op.SequenceInsert(result_seq, tensor)
    # Concatenate all collected tensors along the requested axis.
    return op.ConcatFromSequence(result_seq, axis=dim)


def aten_ccol_indices(self: TensorType) -> TensorType:
Expand Down Expand Up @@ -1570,10 +1575,29 @@ def aten_embedding_sparse_backward(
raise NotImplementedError()


@torch_op("aten::empty")
def aten_empty(size: IntType, dtype: int = FLOAT.dtype) -> TTensor:  # type: ignore[type-var]
    # empty(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor

    # ONNX has no "uninitialized memory" op, so empty() is emulated with
    # zeros: broadcast a zero scalar of the requested dtype to `size`.
    size = op.Cast(size, to=INT64.dtype)
    zero = op.Constant(value_float=0)
    zero = op.Cast(zero, to=dtype)

    return op.Expand(zero, size)


@torch_op("aten::empty_like")
def aten_empty_like(self: TTensor, dtype: int = -1) -> TTensor:
    # empty_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor

    # ONNX cannot produce uninitialized memory, so empty_like() is emulated
    # by expanding a zero scalar to the input's shape.
    shape = op.Shape(self)
    # dtype == -1 is the "no dtype requested" sentinel: keep the input's
    # dtype via CastLike instead of casting to an explicit type.
    if dtype == -1:
        zero = op.CastLike(0, self)
    else:
        zero = op.Cast(0, to=dtype)

    return op.Expand(zero, shape)


def aten_empty_quantized(
Expand Down
15 changes: 13 additions & 2 deletions onnxscript/function_libs/torch_aten/ops/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from typing import Optional, Sequence

from onnxscript import FLOAT
from onnxscript.function_libs.torch_aten.registration import torch_op
from onnxscript.function_libs.torch_aten.typing import TFloatOrBFloat16
from onnxscript.onnx_opset import opset18 as op
Expand Down Expand Up @@ -205,10 +206,20 @@ def aten_special_log_ndtr(self: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op("aten::log_softmax")
def aten_special_log_softmax(
    self: TFloatOrBFloat16, dim: int, dtype: int = FLOAT.dtype
) -> TFloatOrBFloat16:
    # special_log_softmax(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor

    # ONNX LogSoftmax needs a rank >= 1 input: promote a scalar (rank 0)
    # to shape [1] first, then squeeze back to a scalar at the end.
    self_is_scalar = op.Size(op.Shape(self)) == 0
    if self_is_scalar:
        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
    result = op.LogSoftmax(self, axis=dim)
    # Cast to the requested output dtype (defaults to float32).
    result = op.Cast(result, to=dtype)
    if self_is_scalar:  # squeeze back to a scalar because the input was one
        result = op.Squeeze(result)
    return result


def aten_special_logit(self: TensorType, eps: Optional[float] = None) -> TensorType:
Expand Down
15 changes: 14 additions & 1 deletion onnxscript/test/function_libs/torch_aten/ops_correctness_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,13 @@ def _logcumsumexp_input_wrangler(
return args, kwargs


def _log_softmax_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
kwargs["dim"] = args.pop()
return args, kwargs


def _topk_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
Expand Down Expand Up @@ -238,6 +245,7 @@ def _topk_input_wrangler(
"atan": core_ops.aten_atan,
"atanh": core_ops.aten_atanh,
"bmm": core_ops.aten_bmm,
"cat": core_ops.aten_cat,
"ceil": core_ops.aten_ceil,
"clamp_max": core_ops.aten_clamp_max,
"clamp_min": core_ops.aten_clamp_min,
Expand All @@ -247,6 +255,8 @@ def _topk_input_wrangler(
"cosh": core_ops.aten_cosh,
"div": core_ops.aten_div,
"dot": core_ops.aten_dot,
"empty": core_ops.aten_empty,
"empty_like": core_ops.aten_empty_like,
"eq": core_ops.aten_eq,
"equal": core_ops.aten_equal,
"exp": core_ops.aten_exp,
Expand All @@ -268,6 +278,7 @@ def _topk_input_wrangler(
"logcumsumexp": core_ops.aten_logcumsumexp,
"logdet": core_ops.aten_logdet,
"logsumexp": (core_ops.aten_logsumexp, _logcumsumexp_input_wrangler),
"log_softmax": (special_ops.aten_special_log_softmax, _log_softmax_input_wrangler),
"lt": core_ops.aten_lt,
"matmul": core_ops.aten_matmul,
"mm": core_ops.aten_mm,
Expand Down Expand Up @@ -326,7 +337,9 @@ def _topk_input_wrangler(
EXPECTED_SKIPS_OR_FAILS = (
xfail("amax", reason="ONNX Runtime 1.13 does not support ReduceMax-18"),
xfail("amin", reason="ONNX Runtime 1.13 does not support ReduceMin-18"),
skip("clamp", reason="enable when onnxscript supports optional inputs"),
skip("clamp", reason="Enable when onnxscript supports optional inputs"),
skip("empty", reason="Using zeros to simulate empty"),
skip("empty_like", reason="Using zeros_like to simulate empty_like"),
xfail("logcumsumexp", reason="naive implementation not numerically stable"),
xfail("logsumexp", reason="ONNX Runtime 1.13 does not support ReduceLogSumExp-18"),
xfail(
Expand Down