AddOp(linalg_vector_norm) | feat(torchlib) #908

Merged · 7 commits · Jul 25, 2023
79 changes: 74 additions & 5 deletions onnxscript/function_libs/torch_lib/ops/linalg.py
@@ -13,6 +13,10 @@

from typing import Optional, Sequence

from onnxscript import BOOL, FLOAT, INT64
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import TFloat
from onnxscript.onnx_opset import opset18 as op
from onnxscript.onnx_types import TensorType


@@ -305,13 +309,78 @@ def aten_linalg_vecdot(x: TensorType, y: TensorType, dim: int = -1) -> TensorType:
raise NotImplementedError()


@torch_op("aten::linalg_vector_norm", trace_only=True)
def aten_linalg_vector_norm(
-    self: TensorType,
-    ord: float = 2,
    self: TFloat,
    ord: float = 2.0,
    dim: Optional[int] = None,
    keepdim: bool = False,
-    dtype: Optional[int] = None,
-) -> TensorType:
    dtype: int = -1,
) -> TFloat:
    """linalg_vector_norm(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"""

-    raise NotImplementedError()
    if dtype != -1:

Contributor: @justinchuby I am confused. Do we want the dispatcher to take care of the dtype overload? Should the dtype scenario follow aten_argmax or not?

Collaborator: We do. But there is also dim: Optional[int] = None, which could be split as well, in which case we would end up with four variants of the same function. Since the function is currently trace_only for reasons other than the dtype situation here, I didn't think it was worth having four copies of this thing.

        self = op.Cast(self, to=dtype)
    if dim is None or (isinstance(dim, tuple) and len(dim) == 0):
        self = op.Reshape(self, op.Constant(value_ints=[-1]))
        keepdim = False
        return _aten_linalg_vector_norm_no_dim_onnx(self, ord, keepdim)
    else:
        return _aten_linalg_vector_norm_onnx(self, ord, dim, keepdim)


@torch_op("aten::linalg_vector_norm", private=True)
def _aten_linalg_vector_norm_no_dim_onnx(self: TFloat, ord: float, keepdim: bool) -> TFloat:
self_rank = op.Size(op.Shape(self))
if self_rank == 0:
self = op.Unsqueeze(self, axes=[0])

self = op.Abs(self)
ord = op.Cast(ord, to=FLOAT.dtype) # Must be FLOAT, due to op.IsInf() needs FLOAT
if op.IsInf(ord, detect_negative=0, detect_positive=1):
result = op.ReduceMax(self, keepdims=keepdim)
elif op.IsInf(ord, detect_negative=1, detect_positive=0):
result = op.ReduceMin(self, keepdims=keepdim)
elif ord == 0.0: # sum(x!=0) means count non-zero elements
self_bool = op.Cast(self, to=BOOL.dtype)
self_0_1 = op.CastLike(self_bool, self)
result = op.ReduceSum(self_0_1, keepdims=False)
else:
ord_float = op.CastLike(ord, self)
self_pow = op.Pow(self, ord_float)
result = op.Pow(op.ReduceSum(self_pow, keepdims=keepdim), op.Div(1.0, ord_float))

if self_rank == 0:
result = op.Squeeze(result)

return result


@torch_op("aten::linalg_vector_norm", private=True)
def _aten_linalg_vector_norm_onnx(
self: TFloat, ord: float, dim: INT64, keepdim: bool
) -> TFloat:
self_rank = op.Size(op.Shape(self))
if self_rank == 0:
self = op.Unsqueeze(self, axes=[0])

dim = op.Reshape(dim, op.Constant(value_ints=[-1]))
self = op.Abs(self)
ord = op.Cast(ord, to=FLOAT.dtype) # Must be FLOAT, due to op.IsInf() needs FLOAT
if op.IsInf(ord, detect_negative=0, detect_positive=1):
result = op.ReduceMax(self, dim, keepdims=keepdim)
elif op.IsInf(ord, detect_negative=1, detect_positive=0):
result = op.ReduceMin(self, dim, keepdims=keepdim)
elif ord == 0.0: # sum(x!=0) means count non-zero elements
self_bool = op.Cast(self, to=BOOL.dtype)
self_0_1 = op.CastLike(self_bool, self)
result = op.ReduceSum(self_0_1, dim, keepdims=keepdim)
else:
ord_float = op.CastLike(ord, self)
self_pow = op.Pow(self, ord_float)
result = op.Pow(op.ReduceSum(self_pow, dim, keepdims=keepdim), op.Div(1.0, ord_float))

if self_rank == 0:
result = op.Squeeze(result)

return result
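
For orientation (not part of the diff): a minimal NumPy sketch of the norm semantics the two private helpers above implement, assuming the same inf / -inf / 0 / p branches. The name reference_vector_norm is illustrative only.

import numpy as np


def reference_vector_norm(x, ord: float = 2.0, dim=None, keepdim: bool = False):
    """NumPy reference for the branch logic above (illustration only, not the torchlib code)."""
    x = np.abs(np.asarray(x, dtype=np.float64))
    if dim is None:
        x = x.reshape(-1)  # mirror the no-dim helper: flatten first
    if np.isposinf(ord):
        return np.max(x, axis=dim, keepdims=keepdim)  # ord=inf -> max(|x|)
    if np.isneginf(ord):
        return np.min(x, axis=dim, keepdims=keepdim)  # ord=-inf -> min(|x|)
    if ord == 0.0:
        return np.sum(x != 0, axis=dim, keepdims=keepdim).astype(x.dtype)  # count non-zeros
    return np.sum(x**ord, axis=dim, keepdims=keepdim) ** (1.0 / ord)  # (sum |x|^p)^(1/p)


x = np.array([3.0, -4.0])
print(reference_vector_norm(x))              # 5.0  (ord=2)
print(reference_vector_norm(x, ord=np.inf))  # 4.0
print(reference_vector_norm(x, ord=0.0))     # 2.0  (two non-zero elements)
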
7 changes: 5 additions & 2 deletions onnxscript/tests/function_libs/torch_lib/ops_test.py
@@ -58,7 +58,7 @@ def dtypes_except(*dtypes: torch.dtype) -> Sequence[torch.dtype]:


def _should_skip_xfail_test_sample(
-    op_name: str, sample
    op_name: str, sample, dtype: torch.dtype
) -> Tuple[Optional[str], Optional[str]]:
    """Returns a reason if a test sample should be skipped."""
    if op_name not in ops_test_data.OP_WITH_SKIPPED_XFAIL_SUBTESTS:
@@ -67,6 +67,9 @@ def _should_skip_xfail_test_sample(
        # Linear search on ops_test_data.SKIP_XFAIL_SUBTESTS. That's fine because the list is small.
        if decorator_meta.op_name == op_name:
            assert decorator_meta.matcher is not None, "Matcher must be defined"
            if decorator_meta.dtypes is not None and dtype not in decorator_meta.dtypes:
                # Not applicable for this dtype
                continue
            if decorator_meta.matcher(sample):
                return decorator_meta.test_behavior, decorator_meta.reason
    return None, None
@@ -184,7 +187,7 @@ def run_test_output_match(
            ),
            kwargs=repr(cpu_sample.kwargs),
        ):
-            test_behavior, reason = _should_skip_xfail_test_sample(op.name, cpu_sample)
            test_behavior, reason = _should_skip_xfail_test_sample(op.name, cpu_sample, dtype)

            with ops_test_common.normal_xfail_skip_test_behaviors(test_behavior, reason):
                input_onnx = [ops_test_common.convert_tensor_to_numpy(x) for x in inputs]
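
To illustrate how the new dtype argument is meant to interact with dtype-scoped skips, here is a self-contained sketch; _StubDecoratorMeta and lookup_skip are hypothetical stand-ins, not the project's actual DecoratorMeta class or helper.

from dataclasses import dataclass
from typing import Any, Callable, Optional, Sequence, Tuple

import torch


@dataclass
class _StubDecoratorMeta:
    """Hypothetical stand-in for the test suite's decorator metadata."""
    op_name: str
    matcher: Callable[[Any], bool]
    dtypes: Optional[Sequence[torch.dtype]]  # None means the skip applies to every dtype
    test_behavior: str
    reason: str


def lookup_skip(metas, op_name, sample, dtype) -> Tuple[Optional[str], Optional[str]]:
    for meta in metas:
        if meta.op_name != op_name:
            continue
        if meta.dtypes is not None and dtype not in meta.dtypes:
            continue  # the new gate: skip entries scoped to other dtypes are ignored
        if meta.matcher(sample):
            return meta.test_behavior, meta.reason
    return None, None

With the linalg.vector_norm skip added below (float16 only, ord == 6), a float16 sample with ord=6 is skipped while the same sample in float32 still runs.
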
21 changes: 21 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -46,6 +46,7 @@
from typing_extensions import Self

from onnxscript.function_libs.torch_lib.ops import core as core_ops
from onnxscript.function_libs.torch_lib.ops import linalg as linalg_ops
from onnxscript.function_libs.torch_lib.ops import nn as nn_ops
from onnxscript.function_libs.torch_lib.ops import special as special_ops
from onnxscript.tests.function_libs.torch_lib import extra_opinfo, ops_test_common
@@ -270,6 +271,15 @@ def _grid_sample_input_wrangler(
    return args, kwargs


def _linalg_vector_norm_input_wrangler(
    args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
    # Convert the dim kwarg into an int64 tensor, matching the INT64 `dim` input of the ONNX function
    if "dim" in kwargs:
        kwargs["dim"] = np.array(kwargs["dim"], dtype=np.int64)
    return args, kwargs


def _max_pool_input_wrangler(
    args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
@@ -630,6 +640,17 @@ def _where_input_wrangler(
TorchLibOpInfo("isnan", core_ops.aten_isnan),
TorchLibOpInfo("isneginf", core_ops.aten_isneginf),
TorchLibOpInfo("isposinf", core_ops.aten_isposinf),
TorchLibOpInfo(
"linalg.vector_norm",
linalg_ops.aten_linalg_vector_norm,
trace_only=True,
tolerance={torch.float16: (2e-3, 2e-3)},
input_wrangler=_linalg_vector_norm_input_wrangler,
).skip(
matcher=lambda sample: sample.kwargs.get("ord") == 6,
dtypes=[torch.float16],
reason="ORT returns a more accurate value for float16 with ord=6 (expected=Inf, actual=9.48).",
),
TorchLibOpInfo(
"linspace",
core_ops.aten_linspace,
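
As a quick illustration of the wrangler above (a hypothetical call, assuming _linalg_vector_norm_input_wrangler is importable; this snippet is not part of the test suite):

import numpy as np

# Hypothetical kwargs, shaped like what the OpInfo sample machinery produces.
args, kwargs = _linalg_vector_norm_input_wrangler([], {"ord": 2.0, "dim": (0, 1), "keepdim": True})
print(kwargs["dim"], kwargs["dim"].dtype)  # [0 1] int64 -- now matches the INT64 `dim` input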