Skip to content

feat(atenlib): ops 7/n #279

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jan 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 11 additions & 24 deletions onnxscript/function_libs/torch_aten/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,20 +48,6 @@ def aten_acosh(self: TFloat) -> TFloat:
return op.Acosh(self)


def aten_adaptive_avg_pool1d(self: TensorType, output_size: Sequence[int]) -> TensorType:
# adaptive_avg_pool1d(Tensor self, int[1] output_size) -> Tensor

raise NotImplementedError()


def aten_adaptive_max_pool1d(
self: TensorType, output_size: Sequence[int]
) -> tuple[TensorType, TensorType]:
# adaptive_max_pool1d(Tensor self, int[1] output_size) -> (Tensor, Tensor)

raise NotImplementedError()


@torch_op("aten::add")
def aten_add(self: TReal, other: TReal, alpha: float = 1) -> TReal:
# add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
Expand Down Expand Up @@ -198,20 +184,20 @@ def aten_alpha_dropout(input: TensorType, p: float, train: bool) -> TensorType:
raise NotImplementedError()


def aten_amax(
self: TensorType, dim: Optional[Sequence[int]] = None, keepdim: bool = False
) -> TensorType:
# @torch_op("aten::amax") # FIXME: Uncomment when CI uses onnx 1.13
def aten_amax(self: TReal, dim: INT64, keepdim: int = 0) -> TReal:
# amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor

raise NotImplementedError()
# TODO(justinchuby): Make dim optional, keepdim bool
return op.ReduceMax(self, dim, keepdims=keepdim)


def aten_amin(
self: TensorType, dim: Optional[Sequence[int]] = None, keepdim: bool = False
) -> TensorType:
# @torch_op("aten::amin") # FIXME: Uncomment when CI uses onnx 1.13
def aten_amin(self: TReal, dim: INT64, keepdim: int = 0) -> TReal:
# amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor

raise NotImplementedError()
# TODO(justinchuby): Make dim optional, keepdim bool
return op.ReduceMin(self, dim, keepdims=keepdim)


def aten_aminmax(
Expand Down Expand Up @@ -4181,10 +4167,11 @@ def aten_rsqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
return op.Reciprocal(op.Sqrt(self))


def aten_rsub(self: TensorType, other: TensorType, alpha: float = 1) -> TensorType:
@torch_op("aten::rsub")
def aten_rsub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
# rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor

raise NotImplementedError()
return op.Sub(other, op.Mul(self, alpha))


def aten_scalar_tensor(s: float) -> TensorType:
Expand Down
78 changes: 71 additions & 7 deletions onnxscript/function_libs/torch_aten/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,65 @@
from onnxscript.onnx_types import TensorType


def aten_adaptive_avg_pool2d(self: TensorType, output_size: INT64) -> TensorType:
@torch_op("aten::aten_adaptive_avg_pool1d")
def aten_adaptive_avg_pool1d(self: TFloat, output_size: INT64[1]) -> TFloat:
# adaptive_avg_pool1d(Tensor self, int[1] output_size) -> Tensor

# assert output_size == [1]
# TODO(justinchuby): Specify input constraints

if op.Size(op.Shape(self)) == 2:
# Unbatched case
self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
pooled = op.GlobalAveragePool(self)
result = op.Squeeze(pooled, op.Constant(value_ints=[0]))
else:
result = op.GlobalAveragePool(self)

return result


@torch_op("aten::aten_adaptive_avg_pool2d")
def aten_adaptive_avg_pool2d(self: TFloat, output_size: INT64[2]) -> TFloat:
# adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor

raise NotImplementedError()
# assert output_size == [1, 1]
# TODO(justinchuby): Specify input constraints

if op.Size(op.Shape(self)) == 3:
# Unbatched case
self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
pooled = op.GlobalAveragePool(self)
result = op.Squeeze(pooled, op.Constant(value_ints=[0]))
else:
result = op.GlobalAveragePool(self)

return result


def aten_adaptive_avg_pool3d(self: TensorType, output_size: INT64) -> TensorType:
@torch_op("aten::aten_adaptive_avg_pool3d")
def aten_adaptive_avg_pool3d(self: TFloat, output_size: INT64[3]) -> TFloat:
# adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor

# assert output_size == [1, 1, 1]
# TODO(justinchuby): Specify input constraints

if op.Size(op.Shape(self)) == 4:
# Unbatched case
self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
pooled = op.GlobalAveragePool(self)
result = op.Squeeze(pooled, op.Constant(value_ints=[0]))
else:
result = op.GlobalAveragePool(self)

return result


def aten_adaptive_max_pool1d(
self: TensorType, output_size: Sequence[int]
) -> tuple[TensorType, TensorType]:
# adaptive_max_pool1d(Tensor self, int[1] output_size) -> (Tensor, Tensor)

raise NotImplementedError()


Expand Down Expand Up @@ -1162,15 +1212,29 @@ def aten_upsample_nearest1d_backward(
raise NotImplementedError()


@torch_op("aten::upsample_nearest2d")
def aten_upsample_nearest2d(
self: TensorType,
output_size: INT64,
self: TReal,
size: INT64,
scales_h: Optional[float] = None,
scales_w: Optional[float] = None,
) -> TensorType:
) -> TReal:
# upsample_nearest2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor

raise NotImplementedError()
self_shape = op.Shape(self)
batch_channel = self_shape[:2] # type: ignore[index]
output_size = op.Concat(batch_channel, size, axis=0)

# TODO(justinchuby): Conditionally use scales

return op.Resize(
self,
None,
None,
output_size,
mode="nearest",
coordinate_transformation_mode="asymmetric",
)


def aten_upsample_nearest2d_backward(
Expand Down
Loading