
feat(atenlib): add ops(copy, fill, empty_strided, native_dropout) #432

Merged · 16 commits · Feb 16, 2023
Changes from 13 commits
36 changes: 26 additions & 10 deletions onnxscript/function_libs/torch_aten/ops/core.py
@@ -1392,10 +1392,14 @@ def aten_convolution_overrideable(
     raise NotImplementedError()


-def aten_copy(self: TensorType, src: TensorType, non_blocking: bool = False) -> TensorType:
+@torch_op("aten::copy")
+def aten_copy(
+    self: TTensor, src: TTensor, non_blocking: bool = False  # pylint: disable=unused-argument
+) -> TTensor:
     # copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor

-    raise NotImplementedError()
+    self = op.Identity(src)
+    return self


 def aten_copysign(self: TensorType, other: TensorType) -> TensorType:
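
For intuition: the new `aten_copy` just forwards `src` through ONNX `Identity`, so the output mirrors `src` and the contents of `self` never reach the result. A minimal NumPy sketch of that behavior — the `reference_copy` helper is illustrative only, not part of this PR:

```python
import numpy as np

def reference_copy(self_tensor: np.ndarray, src: np.ndarray) -> np.ndarray:
    # Mirrors the ONNX graph Identity(src): `self` only names the slot being
    # overwritten; its values do not affect the output.
    return src.copy()

dst = np.zeros((2, 3), dtype=np.float32)
src = np.arange(6, dtype=np.float32).reshape(2, 3)
assert np.array_equal(reference_copy(dst, src), src)
```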
@@ -1942,10 +1946,17 @@ def aten_empty_quantized(
     raise NotImplementedError()


-def aten_empty_strided(size: INT64, stride: INT64) -> TensorType:
+@torch_op("aten::empty_strided")
+def aten_empty_strided(
+    size: INT64, stride: INT64  # pylint: disable=unused-argument
+) -> TTensor:  # type: ignore[type-var]
     # empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

-    raise NotImplementedError()
+    # Use zeros to simulate empty(); the stride argument is ignored.
+    size = op.Cast(size, to=INT64.dtype)
+    zero = op.Constant(value_float=0.0)
+
+    return op.Expand(zero, size)


 @torch_op("aten::eq")
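
The trick here is that `empty_strided` only has to produce a tensor of the right shape — PyTorch leaves the element values unspecified — so expanding a scalar zero constant to `size` is sufficient, and `stride` can be ignored. A hedged NumPy equivalent (the `reference_empty_strided` helper is illustrative, not from the PR):

```python
import numpy as np

def reference_empty_strided(size) -> np.ndarray:
    # Expand a 0-d zero to `size`, as op.Expand(zero, size) does.
    # torch.empty_strided only guarantees shape, so a zero fill is valid.
    zero = np.array(0.0, dtype=np.float32)
    return np.broadcast_to(zero, size).copy()

out = reference_empty_strided([2, 3])
assert out.shape == (2, 3) and not out.any()
```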
@@ -2166,10 +2177,15 @@ def aten_feature_dropout(input: TensorType, p: float, train: bool) -> TensorType
     raise NotImplementedError()


-def aten_fill(self: TensorType, value: TensorType) -> TensorType:
+@torch_op("aten::fill")
+def aten_fill(self: TTensor, value: TTensor) -> TTensor:
     # fill.Tensor(Tensor self, Tensor value) -> Tensor

-    raise NotImplementedError()
+    # After fill, the result should keep the original dtype of self.
+    shape = op.Shape(self)
+    expanded = op.Expand(value, shape)
+    result = op.CastLike(expanded, self)
+    return result


 def aten_fix(self: TensorType) -> TensorType:
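
`aten_fill` broadcasts `value` to `self`'s shape and then casts back with `CastLike`, so the output follows `self`'s dtype even when `value` has a different one. A rough NumPy sketch of those semantics (the `reference_fill` helper is illustrative, not the PR's test code):

```python
import numpy as np

def reference_fill(self_tensor: np.ndarray, value) -> np.ndarray:
    # Expand(value, Shape(self)) then CastLike(..., self): the output takes
    # self's shape and dtype, and value's contents.
    expanded = np.broadcast_to(np.asarray(value), self_tensor.shape)
    return expanded.astype(self_tensor.dtype)

x = np.zeros((2, 2), dtype=np.int32)
assert reference_fill(x, 3.0).dtype == np.int32  # float value cast back to int32
```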
@@ -3792,12 +3808,12 @@ def aten_native_channel_shuffle(self: TensorType, groups: int) -> TensorType:
     raise NotImplementedError()


-def aten_native_dropout(
-    input: TensorType, p: float, train: Optional[bool]
-) -> tuple[TensorType, TensorType]:
+@torch_op("aten::native_dropout")
+def aten_native_dropout(input: TTensor, p: float, train: bool = True) -> Tuple[TTensor, BOOL]:
     # native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)

-    raise NotImplementedError()
+    result, mask = op.Dropout(input, p, train)
+    return result, mask


 def aten_native_dropout_backward(
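
ONNX `Dropout` already returns the `(output, mask)` pair that `native_dropout` expects: in training mode it zeroes elements with probability `p` and scales survivors by `1 / (1 - p)`; in eval mode it is the identity. A small NumPy sketch of those semantics (seeded for reproducibility; the `reference_native_dropout` helper is illustrative only):

```python
import numpy as np

def reference_native_dropout(x: np.ndarray, p: float, train: bool = True):
    # Matches ONNX Dropout semantics: identity in eval mode; in training,
    # drop with probability p and rescale kept values by 1 / (1 - p).
    if not train or p == 0.0:
        return x, np.ones_like(x, dtype=bool)
    mask = np.random.default_rng(0).random(x.shape) >= p
    return np.where(mask, x / (1.0 - p), 0.0), mask

out, mask = reference_native_dropout(np.ones((4, 4), dtype=np.float32), p=0.5)
assert np.allclose(out[mask], 2.0) and not out[~mask].any()
```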
onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
@@ -264,18 +264,21 @@ def _where_input_wrangler(
     "clamp_max": core_ops.aten_clamp_max,
     "clamp_min": core_ops.aten_clamp_min,
     "clone": core_ops.aten_clone,
+    # "copy": core_ops.aten_copy,  # copy is not in OPS_DB
     "cos": core_ops.aten_cos,
     "cosh": core_ops.aten_cosh,
     # "detach": core_ops.aten_detach,  # detach is not in OP-TEST-DB
     "div": core_ops.aten_div,
     "dot": core_ops.aten_dot,
     "empty": core_ops.aten_empty,
+    # "empty_strided": core_ops.aten_empty_strided,  # empty_strided is not in OPS_DB
     "eq": core_ops.aten_eq,
     "equal": core_ops.aten_equal,
     "exp": core_ops.aten_exp,
     "exp2": core_ops.aten_exp2,
     "expand": core_ops.aten_expand,
     "erf": core_ops.aten_erf,
+    "fill": core_ops.aten_fill,
     "fmod": core_ops.aten_fmod,
     "full": (core_ops.aten_full, _full_input_wrangler),
     "full_like": core_ops.aten_full_like,
@@ -300,6 +303,7 @@ def _where_input_wrangler(
     "minimum": core_ops.aten_minimum,
     "mm": core_ops.aten_mm,
     "mul": core_ops.aten_mul,
+    # "native_dropout": core_ops.aten_native_dropout,  # native_dropout is not in OPS_DB
     "ne": core_ops.aten_ne,
     "neg": core_ops.aten_neg,
     "new_full": core_ops.aten_new_full,