add ops: ge, le, maximum, minimum, softmax #313

Merged · 5 commits · Jan 13, 2023
20 changes: 12 additions & 8 deletions onnxscript/function_libs/torch_aten/ops/core.py
@@ -1980,10 +1980,11 @@ def aten_gcd(self: TensorType, other: TensorType) -> TensorType:
     raise NotImplementedError()


-def aten_ge(self: TensorType, other: TensorType) -> TensorType:
+@torch_op("aten::ge")
+def aten_ge(self: TReal, other: TReal) -> BOOL:
     # ge.Tensor(Tensor self, Tensor other) -> Tensor

-    raise NotImplementedError()
+    return op.GreaterOrEqual(self, other)


 def aten_geqrf(self: TensorType) -> tuple[TensorType, TensorType]:
@@ -2537,10 +2538,11 @@ def aten_ldexp(self: TensorType, other: TensorType) -> TensorType:
     raise NotImplementedError()


-def aten_le(self: TensorType, other: TensorType) -> TensorType:
+@torch_op("aten::le")
+def aten_le(self: TReal, other: TReal) -> BOOL:
     # le.Tensor(Tensor self, Tensor other) -> Tensor

-    raise NotImplementedError()
+    return op.LessOrEqual(self, other)


 def aten_lerp(self: TensorType, end: TensorType, weight: TensorType) -> TensorType:
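
Note: aten::ge and aten::le are inclusive comparisons, so they must lower to ONNX GreaterOrEqual/LessOrEqual; the strict Greater/Less ops would return the wrong answer wherever the two inputs are equal. A minimal sketch of the distinction (plain NumPy, not onnxscript):

    import numpy as np

    a = np.array([1.0, 2.0, 3.0])
    b = np.array([2.0, 2.0, 2.0])
    # Inclusive comparisons match aten::ge / aten::le semantics.
    assert list(np.greater_equal(a, b)) == [False, True, True]
    assert list(np.less_equal(a, b)) == [True, True, False]
    # The strict variant disagrees exactly where a == b.
    assert list(np.greater(a, b)) == [False, False, True]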
@@ -2925,10 +2927,11 @@ def aten_max_pool3d(
     raise NotImplementedError()


-def aten_maximum(self: TensorType, other: TensorType) -> TensorType:
+@torch_op("aten::maximum")
+def aten_maximum(self: TReal, other: TReal) -> TReal:
     # maximum(Tensor self, Tensor other) -> Tensor

-    raise NotImplementedError()
+    return op.Max(self, other)


 def aten_mean(self: TensorType, dtype: Optional[int] = None) -> TensorType:
@@ -2955,10 +2958,11 @@ def aten_min(self: TensorType) -> TensorType:
     raise NotImplementedError()


-def aten_minimum(self: TensorType, other: TensorType) -> TensorType:
+@torch_op("aten::minimum")
+def aten_minimum(self: TReal, other: TReal) -> TReal:
     # minimum(Tensor self, Tensor other) -> Tensor

-    raise NotImplementedError()
+    return op.Min(self, other)


 def aten_miopen_batch_norm(
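
Note: ONNX Max/Min are variadic elementwise ops with standard broadcasting; called with two inputs they line up with torch.maximum/torch.minimum. A NumPy sketch of the expected elementwise, broadcasting behavior:

    import numpy as np

    x = np.array([[1.0, 5.0, 3.0]])   # shape (1, 3)
    y = np.array([[4.0], [2.0]])      # shape (2, 1), broadcasts to (2, 3)
    np.testing.assert_array_equal(
        np.maximum(x, y), [[4.0, 5.0, 4.0], [2.0, 5.0, 3.0]])
    np.testing.assert_array_equal(
        np.minimum(x, y), [[1.0, 4.0, 3.0], [1.0, 2.0, 2.0]])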
15 changes: 12 additions & 3 deletions onnxscript/function_libs/torch_aten/ops/special.py
@@ -338,12 +338,21 @@ def aten_special_sinc(self: TensorType) -> TensorType:
     raise NotImplementedError()


+@torch_op("aten::softmax")
 def aten_special_softmax(
-    self: TensorType, dim: int, dtype: Optional[int] = None
-) -> TensorType:
+    self: TFloatOrBFloat16, dim: int, dtype: int = FLOAT.dtype
+) -> TFloatOrBFloat16:
     # special_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor

-    raise NotImplementedError()
+    self_is_scalar = op.Size(op.Shape(self)) == 0
+    if self_is_scalar:
+        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
+    result = op.Softmax(self, axis=dim)
+    result = op.Cast(result, to=dtype)
+    if self_is_scalar:  # the input was a scalar, so squeeze the result back to a scalar
+        result = op.Squeeze(result)
+
+    return result


 def aten_special_spherical_bessel_j0(x: TensorType) -> TensorType:
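
Note: ONNX Softmax normalizes along an axis, so a rank-0 input has no valid dim; the function therefore unsqueezes a scalar to shape [1], applies Softmax, and squeezes the result back. A NumPy sketch of that shape round trip (an illustration, not the onnxscript code itself):

    import numpy as np

    def softmax_like(x, dim):
        x = np.asarray(x, dtype=np.float32)
        is_scalar = x.ndim == 0
        if is_scalar:
            x = np.expand_dims(x, 0)                  # op.Unsqueeze(self, [0])
        e = np.exp(x - x.max(axis=dim, keepdims=True))
        out = e / e.sum(axis=dim, keepdims=True)      # op.Softmax(self, axis=dim)
        return out.squeeze(0) if is_scalar else out   # op.Squeeze back to rank 0

    assert softmax_like(3.0, 0) == np.float32(1.0)    # softmax of any scalar is 1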
12 changes: 12 additions & 0 deletions onnxscript/test/function_libs/torch_aten/ops_correctness_test.py
@@ -205,6 +205,13 @@ def _log_softmax_input_wrangler(
     return args, kwargs


+def _softmax_input_wrangler(
+    args: list[Any], kwargs: dict[str, Any]
+) -> tuple[list[Any], dict[str, Any]]:
+    kwargs["dim"] = args.pop()
+    return args, kwargs
+
+
 def _topk_input_wrangler(
     args: list[Any], kwargs: dict[str, Any]
 ) -> tuple[list[Any], dict[str, Any]]:
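
Note: the OpInfo samples for softmax pass dim as the last positional argument, while the onnxscript function takes it as a keyword, so the wrangler moves it over. A hypothetical call (the x placeholder is illustrative, not from the test suite):

    x = object()  # stand-in for the sample input tensor
    args, kwargs = _softmax_input_wrangler([x, 1], {})
    assert args == [x] and kwargs == {"dim": 1}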
@@ -267,9 +274,11 @@ def _topk_input_wrangler(
     "fmod": core_ops.aten_fmod,
     "full": (core_ops.aten_full, _full_input_wrangler),
     "full_like": core_ops.aten_full_like,
+    "ge": core_ops.aten_ge,
     "gt": core_ops.aten_gt,
     "isinf": core_ops.aten_isinf,
+    "le": core_ops.aten_le,
     "log": core_ops.aten_log,
     "log10": core_ops.aten_log10,
     "log1p": core_ops.aten_log1p,
     "log_softmax": (special_ops.aten_special_log_softmax, _log_softmax_input_wrangler),
@@ -281,6 +290,8 @@ def _topk_input_wrangler(
     "logsumexp": (core_ops.aten_logsumexp, _logcumsumexp_input_wrangler),
     "lt": core_ops.aten_lt,
     "matmul": core_ops.aten_matmul,
+    "maximum": core_ops.aten_maximum,
+    "minimum": core_ops.aten_minimum,
     "mm": core_ops.aten_mm,
     "mul": core_ops.aten_mul,
     "ne": core_ops.aten_ne,
@@ -316,6 +327,7 @@ def _topk_input_wrangler(
     "sin": core_ops.aten_sin,
     "sinh": core_ops.aten_sinh,
     "slice": core_ops.aten_slice,
+    "softmax": (special_ops.aten_special_softmax, _softmax_input_wrangler),
     "sqrt": core_ops.aten_sqrt,
     "sub": core_ops.aten_sub,
     "t": core_ops.aten_t,