Skip to content

Add permute and pow ops and fix attribute value issue. #293

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jan 14, 2023
Merged
12 changes: 10 additions & 2 deletions onnxscript/function_libs/torch_aten/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3708,10 +3708,11 @@ def aten_pdist(self: TensorType, p: float = 2) -> TensorType:
raise NotImplementedError()


def aten_permute(self: TensorType, dims: Sequence[int]) -> TensorType:
@torch_op("aten::permute")
def aten_permute(self: TTensor, dims: Sequence[int]) -> TTensor:
# permute(Tensor(a) self, int[] dims) -> Tensor(a)

raise NotImplementedError()
return op.Transpose(self, perm=dims)


def aten_permute_copy(self: TensorType, dims: Sequence[int]) -> TensorType:
Expand Down Expand Up @@ -3781,6 +3782,13 @@ def aten_positive(self: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op("aten::pow")
def aten_pow(self: TReal, exponent: TTensor) -> TReal:
# pow(Tensor self, Tensor exponent) -> Tensor

return op.Pow(self, exponent)


def aten_prelu(self: TensorType, weight: TensorType) -> TensorType:
# prelu(Tensor self, Tensor weight) -> Tensor

Expand Down
12 changes: 12 additions & 0 deletions onnxscript/test/function_libs/torch_aten/ops_correctness_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,8 @@ def _topk_input_wrangler(
"nonzero": core_ops.aten_nonzero,
"ones_like": core_ops.aten_ones_like,
"ones": core_ops.aten_ones,
"permute": core_ops.aten_permute,
"pow": core_ops.aten_pow,
"reciprocal": core_ops.aten_reciprocal,
"remainder": core_ops.aten_remainder,
"repeat": core_ops.aten_repeat,
Expand Down Expand Up @@ -452,6 +454,16 @@ def _topk_input_wrangler(
matcher=lambda sample: "scale_factor" in sample.kwargs,
reason="fixme: the scale_factor tests",
),
skip(
"permute",
matcher=lambda sample: len(list(filter(lambda v: v < 0, sample.args[0]))) > 0,
reason="Negative value in perm is not supported",
),
skip(
"permute",
matcher=lambda sample: len(sample.args[0]) == 0,
reason="Empty perm is not supported",
),
skip(
"slice",
# kwargs {dim, start, end, step} is empty, we cannot give the default value
Expand Down
11 changes: 10 additions & 1 deletion onnxscript/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,19 @@ def adapt_kwargs(self, kwargs):
)
return kwargs, closure

def _convert_kwargs_to_numpy(self, kwargs: dict[str, Any]) -> dict[str, Any]:
new_kwargs = {}
for k, v in kwargs.items():
new_kwargs[k] = v
if isinstance(v, tensor.Tensor):
new_kwargs[k] = v.value

return new_kwargs

def __call__(self, *args, **kwargs):
from onnxscript import evaluator # pylint: disable=import-outside-toplevel

return evaluator.eval(self.opschema, args, kwargs)
return evaluator.eval(self.opschema, args, self._convert_kwargs_to_numpy(kwargs))


@dataclasses.dataclass(repr=False, eq=False)
Expand Down