feat(atenlib): logarithmic ops; test aten::full #281

Merged · 19 commits · Jan 10, 2023
63 changes: 37 additions & 26 deletions onnxscript/function_libs/torch_aten/ops/core.py
@@ -186,15 +186,15 @@ def aten_alpha_dropout(input: TensorType, p: float, train: bool) -> TensorType:
raise NotImplementedError()


# @torch_op("aten::amax") # FIXME: Uncomment when CI uses onnx 1.13
# @torch_op("aten::amax") # FIXME(#249): Uncomment when CI uses onnx 1.13
def aten_amax(self: TReal, dim: INT64, keepdim: int = 0) -> TReal:
# amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor

# TODO(justinchuby): Make dim optional, keepdim bool
return op.ReduceMax(self, dim, keepdims=keepdim)


# @torch_op("aten::amin") # FIXME: Uncomment when CI uses onnx 1.13
# @torch_op("aten::amin") # FIXME(#249): Uncomment when CI uses onnx 1.13
def aten_amin(self: TReal, dim: INT64, keepdim: int = 0) -> TReal:
# amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor

@@ -2575,52 +2575,68 @@ def aten_linspace(start: float, end: float, steps: int) -> TensorType:
raise NotImplementedError()


def aten_log(self: TensorType) -> TensorType:
@torch_op("aten::log")
def aten_log(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
# log(Tensor self) -> Tensor

raise NotImplementedError()
return op.Log(self)


def aten_log10(self: TensorType) -> TensorType:
@torch_op("aten::log10")
def aten_log10(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
# log10(Tensor self) -> Tensor

raise NotImplementedError()
return op.Div(op.Log(self), op.Log(10.0))


def aten_log1p(self: TensorType) -> TensorType:
@torch_op("aten::log1p")
def aten_log1p(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
# log1p(Tensor self) -> Tensor

raise NotImplementedError()
return op.Log(op.Add(self, 1.0))


def aten_log2(self: TensorType) -> TensorType:
@torch_op("aten::log2")
def aten_log2(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
# log2(Tensor self) -> Tensor

raise NotImplementedError()
return op.Div(op.Log(self), op.Log(2.0))
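
Reviewer note: log10 and log2 are lowered with the change-of-base identity log_b(x) = ln(x) / ln(b), since ONNX only provides a natural-log op. A quick numpy sanity check of that identity (a sketch for review purposes only, not part of the change):

import numpy as np

x = np.array([0.5, 1.0, 10.0, 100.0], dtype=np.float32)
# Change-of-base identity: log_b(x) = ln(x) / ln(b)
np.testing.assert_allclose(np.log(x) / np.log(10.0), np.log10(x), rtol=1e-5)
np.testing.assert_allclose(np.log(x) / np.log(2.0), np.log2(x), rtol=1e-5)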


def aten_logaddexp(self: TensorType, other: TensorType) -> TensorType:
@torch_op("aten::logaddexp")
def aten_logaddexp(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
# logaddexp(Tensor self, Tensor other) -> Tensor

raise NotImplementedError()
return op.Log(op.Add(op.Exp(self), op.Exp(other)))


def aten_logaddexp2(self: TensorType, other: TensorType) -> TensorType:
@torch_op("aten::logaddexp2")
def aten_logaddexp2(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
# logaddexp2(Tensor self, Tensor other) -> Tensor
summation = op.Add(op.Pow(2.0, self), op.Pow(2.0, other))

raise NotImplementedError()
return op.Div(op.Log(summation), op.Log(2.0))
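
Reviewer note: both logaddexp lowerings exponentiate directly, which overflows once the inputs exceed roughly 88 in float32. A numerically stable variant subtracts the elementwise max first; a plain-numpy sketch of that idea (illustrative only, not what this PR registers):

import numpy as np

def stable_logaddexp(a, b):
    # log(exp(a) + exp(b)) == m + log(exp(a - m) + exp(b - m)), with m = max(a, b)
    m = np.maximum(a, b)
    return m + np.log(np.exp(a - m) + np.exp(b - m))

a = np.array([1000.0, 1.0], dtype=np.float32)
b = np.array([1000.0, 2.0], dtype=np.float32)
print(stable_logaddexp(a, b))            # [1000.6931    2.3132617] -- stays finite
print(np.log(np.exp(a) + np.exp(b)))     # [inf         2.3132617] -- naive form overflows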


def aten_logcumsumexp(self: TensorType, dim: int) -> TensorType:
@torch_op("aten::logcumsumexp")
def aten_logcumsumexp(self: TFloatOrBFloat16, dim: INT64) -> TFloatOrBFloat16:
# logcumsumexp(Tensor self, int dim) -> Tensor

raise NotImplementedError()
if op.Size(op.Shape(self)) == 0:
# A scalar
result = self
else:
# FIXME(justinchuby): Ensure numerical stability
result = op.Log(op.CumSum(op.Exp(self), dim))

return result
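
Reviewer note on the FIXME: the usual way to stabilize logcumsumexp is to shift by a max before exponentiating. A rough numpy sketch of the simplest (global-max) variant, offered as a possible follow-up rather than what this PR implements:

import numpy as np

def stable_logcumsumexp(x, dim):
    # Shift by the global max so exp() stays in range, then undo the shift.
    m = np.max(x)
    return m + np.log(np.cumsum(np.exp(x - m), axis=dim))

x = np.array([100.0, 101.0, 102.0], dtype=np.float32)
print(stable_logcumsumexp(x, 0))      # [100.      101.313   102.408] -- finite
print(np.log(np.cumsum(np.exp(x))))   # [inf inf inf] -- exp(100) already overflows float32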


def aten_logdet(self: TensorType) -> TensorType:
@torch_op("aten::logdet")
def aten_logdet(self: TFloat) -> TFloat:
# logdet(Tensor self) -> Tensor

raise NotImplementedError()
return op.Log(op.Det(self))


@torch_op("aten::logical_and")
@@ -2663,10 +2679,11 @@ def aten_logspace(start: float, end: float, steps: int, base: float = 10.0) -> T
raise NotImplementedError()


def aten_logsumexp(self: TensorType, dim: Sequence[int], keepdim: bool = False) -> TensorType:
@torch_op("aten::logsumexp", trace_only=True) # FIXME(#249): Script when CI uses onnx 1.13
def aten_logsumexp(self: TReal, dim: INT64, keepdim: int = 0) -> TReal:
# logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor

raise NotImplementedError()
return op.ReduceLogSumExp(self, dim, keepdims=keepdim)


def aten_lshift(self: TensorType, other: TensorType) -> TensorType:
@@ -5035,12 +5052,6 @@ def aten_where(self: TTensor, condition: BOOL, other: TTensor) -> TTensor:
return op.Where(condition, self, other)


def aten_xlogy(self: TensorType, other: TensorType) -> TensorType:
# xlogy.Tensor(Tensor self, Tensor other) -> Tensor

raise NotImplementedError()


def aten_xor(self: TensorType, other: TensorType) -> TensorType:
# __xor__.Tensor(Tensor self, Tensor other) -> Tensor

5 changes: 3 additions & 2 deletions onnxscript/function_libs/torch_aten/ops/nn.py
@@ -481,10 +481,11 @@ def aten_linear(input: TFloat, weight: TFloat, bias: Optional[TFloat] = None) ->
return result


def aten_log_sigmoid(self: TensorType) -> TensorType:
@torch_op("aten::log_sigmoid")
def aten_log_sigmoid(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
# log_sigmoid(Tensor self) -> Tensor

raise NotImplementedError()
return op.Log(op.Sigmoid(self))
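
Reviewer note: Log(Sigmoid(x)) loses precision and eventually returns -inf for large negative x, because Sigmoid underflows to 0 first. The conventional stable formulation is -softplus(-x); a small numpy sketch of the difference (illustrative only, not what this PR registers):

import numpy as np

def log_sigmoid_stable(x):
    # log(sigmoid(x)) == min(x, 0) - log1p(exp(-|x|))
    return np.minimum(x, 0.0) - np.log1p(np.exp(-np.abs(x)))

x = np.array([-150.0, 0.0, 150.0], dtype=np.float32)
print(np.log(1.0 / (1.0 + np.exp(-x))))   # [-inf   -0.6931   0.] -- sigmoid underflows at -150
print(log_sigmoid_stable(x))              # [-150.  -0.6931   0.]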


def aten_log_sigmoid_backward(
19 changes: 17 additions & 2 deletions onnxscript/function_libs/torch_aten/ops/special.py
@@ -13,6 +13,9 @@

from typing import Optional, Sequence

from onnxscript.function_libs.torch_aten.registration import torch_op
from onnxscript.function_libs.torch_aten.typing import TFloatOrBFloat16
from onnxscript.onnx_opset import opset18 as op
from onnxscript.onnx_types import TensorType


@@ -344,10 +347,22 @@ def aten_special_xlog1py(self: TensorType, other: TensorType) -> TensorType:
raise NotImplementedError()


def aten_special_xlogy(self: TensorType, other: TensorType) -> TensorType:
@torch_op("aten::xlogy")
def aten_special_xlogy(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
# special_xlogy(Tensor self, Tensor other) -> Tensor

raise NotImplementedError()
# https://pytorch.org/docs/stable/special.html#torch.special.xlogy
# out := {
# NaN if other == NaN
# 0 if self == 0
# self * log(other) otherwise
# }

nans = op.IsNaN(other)
zeros = op.Equal(self, 0)
xlogy = op.Mul(self, op.Log(other))
xlogy_with_nans = op.Where(nans, other, xlogy)
return op.Where(zeros, self, xlogy_with_nans)
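
Reviewer note: the two nested Where calls reproduce the documented piecewise behaviour, with the self == 0 branch applied outermost. A small numpy check of the two special cases, using made-up sample values (each element exercises only one branch):

import numpy as np

self_ = np.array([0.0, 2.0, 3.0], dtype=np.float32)
other = np.array([5.0, np.nan, 4.0], dtype=np.float32)

nans = np.isnan(other)
zeros = self_ == 0
xlogy = self_ * np.log(other)
result = np.where(zeros, self_, np.where(nans, other, xlogy))
print(result)   # [0.   nan   4.158883]  i.e. [0, nan, 3 * log(4)]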


def aten_special_zeta(self: TensorType, other: TensorType) -> TensorType:
83 changes: 60 additions & 23 deletions onnxscript/test/function_libs/torch_aten/ops_correctness_test.py
@@ -16,6 +16,7 @@
import onnxscript
from onnxscript.function_libs.torch_aten.ops import core as core_ops
from onnxscript.function_libs.torch_aten.ops import nn as nn_ops
from onnxscript.function_libs.torch_aten.ops import special as special_ops

T = TypeVar("T")

@@ -161,20 +162,39 @@ def duplicate_opinfo(opinfos: list[opinfo_core.OpInfo], name: str, new_names: tu
# Modify this section ##########################################################


def _amax_amin_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
def _amax_amin_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
if "dim" not in kwargs:
kwargs["dim"] = None
return kwargs
return args, kwargs


def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
def _full_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
# Remove the self argument
args.pop(0)
return args, kwargs


def _upsample_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
if "scale_factor" in kwargs:
kwargs["scales_h"] = kwargs["scale_factor"]
kwargs["scales_w"] = kwargs["scale_factor"]
del kwargs["scale_factor"]
if "size" in kwargs:
kwargs["size"] = np.array(kwargs["size"])
return kwargs
return args, kwargs


def _logcumsumexp_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
kwargs["keepdim"] = args.pop()
return args, kwargs
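
Reviewer note: the wranglers were renamed from *_kwargs_wrangler to *_input_wrangler because they may now rewrite positional arguments as well as keyword arguments; each takes (args, kwargs) and returns the adjusted pair before the OpInfo sample reaches the ONNX function. A tiny illustration with made-up inputs:

import numpy as np

args, kwargs = _logcumsumexp_input_wrangler([np.ones(3, dtype=np.float32), 0, True], {})
# The trailing positional bool moved into kwargs:
# args == [array, 0], kwargs == {"keepdim": True}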


# Ops to be tested for numerical consistency between onnx and pytorch
@@ -185,16 +205,16 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
| Callable[..., Any]
| tuple[
onnxscript.OnnxFunction | Callable[..., Any],
Callable[[dict[str, Any]], dict[str, Any]],
Callable[[list[Any], dict[str, Any]], tuple[list[Any], dict[str, Any]]],
],
] = {
"abs": core_ops.aten_abs,
"acos": core_ops.aten_acos,
"acosh": core_ops.aten_acosh,
"add": core_ops.aten_add,
"addmm": core_ops.aten_addmm,
"amax": (core_ops.aten_amax, _amax_amin_kwargs_wrangler),
"amin": (core_ops.aten_amin, _amax_amin_kwargs_wrangler),
"amax": (core_ops.aten_amax, _amax_amin_input_wrangler),
"amin": (core_ops.aten_amin, _amax_amin_input_wrangler),
"arange_start_step": core_ops.aten_arange_start_step,
"arange_start": core_ops.aten_arange_start,
"arange": core_ops.aten_arange,
@@ -219,11 +239,20 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
"expand": core_ops.aten_expand,
"erf": core_ops.aten_erf,
"fmod": core_ops.aten_fmod,
# TODO(justinchuby): Test aten::full
"full": (core_ops.aten_full, _full_input_wrangler),
"full_like": core_ops.aten_full_like,
"gt": core_ops.aten_gt,
"index_select": core_ops.aten_index_select,
"isinf": core_ops.aten_isinf,
"log": core_ops.aten_log,
"log10": core_ops.aten_log10,
"log1p": core_ops.aten_log1p,
"log2": core_ops.aten_log2,
"logaddexp": core_ops.aten_logaddexp,
"logaddexp2": core_ops.aten_logaddexp2,
"logcumsumexp": core_ops.aten_logcumsumexp,
"logdet": core_ops.aten_logdet,
"logsumexp": (core_ops.aten_logsumexp, _logcumsumexp_input_wrangler),
"lt": core_ops.aten_lt,
"matmul": core_ops.aten_matmul,
"mm": core_ops.aten_mm,
@@ -237,12 +266,13 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
"nn.functional.elu": nn_ops.aten_elu,
"nn.functional.leaky_relu": nn_ops.aten_leaky_relu,
"nn.functional.linear": nn_ops.aten_linear,
"nn.functional.logsigmoid": nn_ops.aten_log_sigmoid,
"nn.functional.relu": nn_ops.aten_relu,
"nn.functional.relu6": nn_ops.aten_relu6,
"nn.functional.selu": core_ops.aten_selu,
"nn.functional.upsample_nearest2d": (
nn_ops.aten_upsample_nearest2d,
_upsample_kwargs_wrangler,
_upsample_input_wrangler,
),
"nonzero": core_ops.aten_nonzero,
"ones_like": core_ops.aten_ones_like,
@@ -267,6 +297,7 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
"unsqueeze": core_ops.aten_unsqueeze,
"view": core_ops.aten_view,
"where": core_ops.aten_where,
"xlogy": special_ops.aten_special_xlogy,
"zeros": core_ops.aten_zeros,
"zeros_like": core_ops.aten_zeros_like,
}
@@ -276,7 +307,9 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
EXPECTED_SKIPS_OR_FAILS = (
xfail("amax", reason="ONNX Runtime 1.13 does not support ReduceMax-18"),
xfail("amin", reason="ONNX Runtime 1.13 does not support ReduceMin-18"),
skip("clamp", reason="Enable when onnxscript supports optional inputs"),
skip("clamp", reason="enable when onnxscript supports optional inputs"),
xfail("logcumsumexp", reason="naive implementation not numerically stable"),
xfail("logsumexp", reason="ONNX Runtime 1.13 does not support ReduceLogSumExp-18"),
xfail(
"nn.functional.linear",
reason="ONNX Runtime thinks the graph is invalid",
@@ -358,23 +391,25 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:

duplicate_opinfo(
OPS_DB,
"nn.functional.upsample_nearest",
"arange",
(
"nn.functional.upsample_nearest1d",
"nn.functional.upsample_nearest2d",
"nn.functional.upsample_nearest3d",
"arange_start",
"arange_start_step",
),
)

duplicate_opinfo(
OPS_DB,
"arange",
"nn.functional.upsample_nearest",
(
"arange_start",
"arange_start_step",
"nn.functional.upsample_nearest1d",
"nn.functional.upsample_nearest2d",
"nn.functional.upsample_nearest3d",
),
)

duplicate_opinfo(OPS_DB, "new_full", ("full",))
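
Reviewer note: aten::full has no OpInfo of its own, so the new_full samples are reused under the name "full" (via duplicate_opinfo above) and _full_input_wrangler strips the leading self tensor so the remaining arguments line up with aten_full. A rough illustration with made-up sample values:

import numpy as np

# A new_full-style sample is (self, size, fill_value); aten_full only needs (size, fill_value).
args, kwargs = _full_input_wrangler([np.zeros((2, 3), dtype=np.float32), (4, 5), 1.5], {})
# args == [(4, 5), 1.5]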


# END OF SECTION TO MODIFY #####################################################

@@ -477,13 +512,13 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
)

onnx_function_and_wrangler = OPINFO_FUNCTION_MAPPING[op.name]
kwarg_wrangler = None
input_wrangler = None
if isinstance(onnx_function_and_wrangler, tuple):
# Obtain the kwarg_wrangler that manipulates the OpInfo inputs
# Obtain the input_wrangler that manipulates the OpInfo inputs
# to match the aten operator signature
# An example is nn.functional.upsample_nearest2d, which has a different signature
# than the aten operator upsample_nearest2d
onnx_function, kwarg_wrangler = onnx_function_and_wrangler
onnx_function, input_wrangler = onnx_function_and_wrangler
else:
assert callable(onnx_function_and_wrangler)
onnx_function = onnx_function_and_wrangler
@@ -503,8 +538,8 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
continue
input_onnx = [_convert_tensor_to_numpy(x) for x in inputs]
kwargs_onnx = _convert_kwargs_for_onnx(cpu_sample.kwargs)
if kwarg_wrangler:
kwargs_onnx = kwarg_wrangler(kwargs_onnx)
if input_wrangler:
input_onnx, kwargs_onnx = input_wrangler(input_onnx, kwargs_onnx)
torch_output = op(*inputs, **cpu_sample.kwargs)
function_output = onnx_function(*input_onnx, **kwargs_onnx)

@@ -524,7 +559,9 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
# Use torch.testing as opposed to np.testing to ensure dtypes and shapes match
torch.testing.assert_close(
torch.tensor(function_output),
torch.tensor(torch_output),
torch_output
if isinstance(torch_output, torch.Tensor)
else torch.tensor(torch_output),
rtol=rtol,
atol=atol,
)