
Commit 241260f

AddOp(linalg_vector_norm) | feat(torchlib) (#908)
Also updated the should_skip_xfail logic in the tests to account for data types.

## Notes

One kind of test case has to be skipped: ord=6 with dtype=float16. In PyTorch:

```
>>> b
tensor([[2.3730, 0.9316, 0.6240, 6.1523, 0.1758],
        [5.3984, 7.9375, 4.9062, 1.8809, 6.1016],
        [4.0234, 8.2344, 3.2695, 0.8701, 1.3447]], dtype=torch.float16)
>>> la.vector_norm(a, dim=0, ord=6)
tensor([5.5483, 9.0838, 4.9754, 6.1532, 6.1017])
>>> la.vector_norm(b, dim=0, ord=6)
tensor([5.5469, inf, 4.9727, 6.1523, 6.0977], dtype=torch.float16)
```

But in ORT the result is:

```
tensor([5.5469, 9.08, 4.9727, 6.1523, 6.0977], dtype=float16)
```

The second element, `9.08`, should be `inf` to match PyTorch's float16 result. The function works in eager mode and fails only in FullGraph mode.

---------

Co-authored-by: Justin Chu <[email protected]>
1 parent f51abd2 · commit 241260f
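For context, here is a minimal sketch that reproduces the float16 behavior described above. It assumes `a` in the quoted session is the float32 counterpart of `b`, and uses `torch.linalg.vector_norm`, the public API that `aten::linalg_vector_norm` backs:

```python
import torch

# The float16 tensor quoted in the commit message.
b = torch.tensor(
    [
        [2.3730, 0.9316, 0.6240, 6.1523, 0.1758],
        [5.3984, 7.9375, 4.9062, 1.8809, 6.1016],
        [4.0234, 8.2344, 3.2695, 0.8701, 1.3447],
    ],
    dtype=torch.float16,
)

# float32 reference: every column norm is finite (second column is ~9.08).
print(torch.linalg.vector_norm(b.float(), dim=0, ord=6))

# float16: 8.2344 ** 6 is roughly 312,000, well past the float16 maximum of
# 65504, so the intermediate power overflows and the column norm becomes inf.
print(torch.linalg.vector_norm(b, dim=0, ord=6))
```

ORT apparently carries the reduction in higher precision for this case, which is why it returns the finite (and mathematically more accurate) `9.08` instead of `inf`.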

File tree

3 files changed: +100 −7 lines


onnxscript/function_libs/torch_lib/ops/linalg.py

Lines changed: 74 additions & 5 deletions
```diff
@@ -13,6 +13,10 @@
 
 from typing import Optional, Sequence
 
+from onnxscript import BOOL, FLOAT, INT64
+from onnxscript.function_libs.torch_lib.registration import torch_op
+from onnxscript.function_libs.torch_lib.tensor_typing import TFloat
+from onnxscript.onnx_opset import opset18 as op
 from onnxscript.onnx_types import TensorType
 
 
@@ -305,13 +309,78 @@ def aten_linalg_vecdot(x: TensorType, y: TensorType, dim: int = -1) -> TensorType:
     raise NotImplementedError()
 
 
+@torch_op("aten::linalg_vector_norm", trace_only=True)
 def aten_linalg_vector_norm(
-    self: TensorType,
-    ord: float = 2,
+    self: TFloat,
+    ord: float = 2.0,
     dim: Optional[int] = None,
     keepdim: bool = False,
-    dtype: Optional[int] = None,
-) -> TensorType:
+    dtype: int = -1,
+) -> TFloat:
     """linalg_vector_norm(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"""
 
-    raise NotImplementedError()
+    if dtype != -1:
+        self = op.Cast(self, to=dtype)
+    if dim is None or (isinstance(dim, tuple) and len(dim) == 0):
+        self = op.Reshape(self, op.Constant(value_ints=[-1]))
+        keepdim = False
+        return _aten_linalg_vector_norm_no_dim_onnx(self, ord, keepdim)
+    else:
+        return _aten_linalg_vector_norm_onnx(self, ord, dim, keepdim)
+
+
+@torch_op("aten::linalg_vector_norm", private=True)
+def _aten_linalg_vector_norm_no_dim_onnx(self: TFloat, ord: float, keepdim: bool) -> TFloat:
+    self_rank = op.Size(op.Shape(self))
+    if self_rank == 0:
+        self = op.Unsqueeze(self, axes=[0])
+
+    self = op.Abs(self)
+    ord = op.Cast(ord, to=FLOAT.dtype)  # Must be FLOAT, due to op.IsInf() needs FLOAT
+    if op.IsInf(ord, detect_negative=0, detect_positive=1):
+        result = op.ReduceMax(self, keepdims=keepdim)
+    elif op.IsInf(ord, detect_negative=1, detect_positive=0):
+        result = op.ReduceMin(self, keepdims=keepdim)
+    elif ord == 0.0:  # sum(x!=0) means count non-zero elements
+        self_bool = op.Cast(self, to=BOOL.dtype)
+        self_0_1 = op.CastLike(self_bool, self)
+        result = op.ReduceSum(self_0_1, keepdims=False)
+    else:
+        ord_float = op.CastLike(ord, self)
+        self_pow = op.Pow(self, ord_float)
+        result = op.Pow(op.ReduceSum(self_pow, keepdims=keepdim), op.Div(1.0, ord_float))
+
+    if self_rank == 0:
+        result = op.Squeeze(result)
+
+    return result
+
+
+@torch_op("aten::linalg_vector_norm", private=True)
+def _aten_linalg_vector_norm_onnx(
+    self: TFloat, ord: float, dim: INT64, keepdim: bool
+) -> TFloat:
+    self_rank = op.Size(op.Shape(self))
+    if self_rank == 0:
+        self = op.Unsqueeze(self, axes=[0])
+
+    dim = op.Reshape(dim, op.Constant(value_ints=[-1]))
+    self = op.Abs(self)
+    ord = op.Cast(ord, to=FLOAT.dtype)  # Must be FLOAT, due to op.IsInf() needs FLOAT
+    if op.IsInf(ord, detect_negative=0, detect_positive=1):
+        result = op.ReduceMax(self, dim, keepdims=keepdim)
+    elif op.IsInf(ord, detect_negative=1, detect_positive=0):
+        result = op.ReduceMin(self, dim, keepdims=keepdim)
+    elif ord == 0.0:  # sum(x!=0) means count non-zero elements
+        self_bool = op.Cast(self, to=BOOL.dtype)
+        self_0_1 = op.CastLike(self_bool, self)
+        result = op.ReduceSum(self_0_1, dim, keepdims=keepdim)
+    else:
+        ord_float = op.CastLike(ord, self)
+        self_pow = op.Pow(self, ord_float)
+        result = op.Pow(op.ReduceSum(self_pow, dim, keepdims=keepdim), op.Div(1.0, ord_float))
+
+    if self_rank == 0:
+        result = op.Squeeze(result)
+
+    return result
```
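The new op decomposes the norm by `ord`: +inf maps to `ReduceMax` over |x|, -inf to `ReduceMin`, 0 to a count of the non-zero elements, and any other order to the generic `(Σ|x|^ord)^(1/ord)` chain of Pow → ReduceSum → Pow. A rough NumPy sketch of that same dispatch, purely illustrative and not the torchlib code itself:

```python
import numpy as np


def vector_norm_reference(x: np.ndarray, ord: float, dim=None, keepdim: bool = False):
    """Reference for the ord dispatch used by aten_linalg_vector_norm (illustrative only)."""
    x = np.abs(x)
    if ord == np.inf:
        return x.max(axis=dim, keepdims=keepdim)  # ReduceMax branch
    if ord == -np.inf:
        return x.min(axis=dim, keepdims=keepdim)  # ReduceMin branch
    if ord == 0.0:
        return (x != 0).astype(x.dtype).sum(axis=dim, keepdims=keepdim)  # count non-zeros
    # Generic p-norm: Pow -> ReduceSum -> Pow(1/ord)
    return (x ** ord).sum(axis=dim, keepdims=keepdim) ** (1.0 / ord)


x = np.array([[3.0, -4.0], [0.0, 12.0]], dtype=np.float32)
print(vector_norm_reference(x, ord=2))              # 13.0: 2-norm over all flattened elements
print(vector_norm_reference(x, ord=np.inf, dim=0))  # [3., 12.]: per-column max of |x|
```

The trace-only wrapper handles the `dim is None` case by flattening the input and calling the no-dim helper, which mirrors the `dim=None` branch above.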

onnxscript/tests/function_libs/torch_lib/ops_test.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -58,7 +58,7 @@ def dtypes_except(*dtypes: torch.dtype) -> Sequence[torch.dtype]:
 
 
 def _should_skip_xfail_test_sample(
-    op_name: str, sample
+    op_name: str, sample, dtype: torch.dtype
 ) -> Tuple[Optional[str], Optional[str]]:
     """Returns a reason if a test sample should be skipped."""
     if op_name not in ops_test_data.OP_WITH_SKIPPED_XFAIL_SUBTESTS:
@@ -67,6 +67,9 @@ def _should_skip_xfail_test_sample(
         # Linear search on ops_test_data.SKIP_XFAIL_SUBTESTS. That's fine because the list is small.
         if decorator_meta.op_name == op_name:
             assert decorator_meta.matcher is not None, "Matcher must be defined"
+            if decorator_meta.dtypes is not None and dtype not in decorator_meta.dtypes:
+                # Not applicable for this dtype
+                continue
             if decorator_meta.matcher(sample):
                 return decorator_meta.test_behavior, decorator_meta.reason
     return None, None
@@ -184,7 +187,7 @@ def run_test_output_match(
             ),
            kwargs=repr(cpu_sample.kwargs),
        ):
-            test_behavior, reason = _should_skip_xfail_test_sample(op.name, cpu_sample)
+            test_behavior, reason = _should_skip_xfail_test_sample(op.name, cpu_sample, dtype)
 
             with ops_test_common.normal_xfail_skip_test_behaviors(test_behavior, reason):
                 input_onnx = [ops_test_common.convert_tensor_to_numpy(x) for x in inputs]
```
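This change makes skip/xfail entries dtype-aware: when an entry carries a `dtypes` list, its matcher is only consulted for samples of those dtypes. A stripped-down sketch of that filtering logic, where `Meta` and `Sample` are hypothetical stand-ins rather than the real test fixtures:

```python
from dataclasses import dataclass, field
from typing import Callable, Optional, Sequence

import torch


@dataclass
class Meta:  # hypothetical stand-in for the real decorator metadata
    op_name: str
    matcher: Callable
    dtypes: Optional[Sequence[torch.dtype]] = None
    reason: str = ""


@dataclass
class Sample:  # hypothetical stand-in for an OpInfo sample
    kwargs: dict = field(default_factory=dict)


def skip_reason(metas, op_name, sample, dtype):
    for meta in metas:
        if meta.op_name != op_name:
            continue
        if meta.dtypes is not None and dtype not in meta.dtypes:
            continue  # entry does not apply to this dtype
        if meta.matcher(sample):
            return meta.reason
    return None


metas = [Meta("linalg.vector_norm", lambda s: s.kwargs.get("ord") == 6, [torch.float16], "fp16 ord=6 mismatch")]
sample = Sample(kwargs={"ord": 6})
print(skip_reason(metas, "linalg.vector_norm", sample, torch.float16))  # "fp16 ord=6 mismatch"
print(skip_reason(metas, "linalg.vector_norm", sample, torch.float32))  # None: dtype not listed
```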

onnxscript/tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 21 additions & 0 deletions
```diff
@@ -46,6 +46,7 @@
 from typing_extensions import Self
 
 from onnxscript.function_libs.torch_lib.ops import core as core_ops
+from onnxscript.function_libs.torch_lib.ops import linalg as linalg_ops
 from onnxscript.function_libs.torch_lib.ops import nn as nn_ops
 from onnxscript.function_libs.torch_lib.ops import special as special_ops
 from onnxscript.tests.function_libs.torch_lib import extra_opinfo, ops_test_common
@@ -270,6 +271,15 @@ def _grid_sample_input_wrangler(
     return args, kwargs
 
 
+def _linalg_vector_norm_input_wrangler(
+    args: list[Any], kwargs: dict[str, Any]
+) -> tuple[list[Any], dict[str, Any]]:
+    # Make the dims as tensor
+    if "dim" in kwargs:
+        kwargs["dim"] = np.array(kwargs["dim"], dtype=np.int64)
+    return args, kwargs
+
+
 def _max_pool_input_wrangler(
     args: list[Any], kwargs: dict[str, Any]
 ) -> tuple[list[Any], dict[str, Any]]:
@@ -630,6 +640,17 @@ def _where_input_wrangler(
     TorchLibOpInfo("isnan", core_ops.aten_isnan),
     TorchLibOpInfo("isneginf", core_ops.aten_isneginf),
     TorchLibOpInfo("isposinf", core_ops.aten_isposinf),
+    TorchLibOpInfo(
+        "linalg.vector_norm",
+        linalg_ops.aten_linalg_vector_norm,
+        trace_only=True,
+        tolerance={torch.float16: (2e-3, 2e-3)},
+        input_wrangler=_linalg_vector_norm_input_wrangler,
+    ).skip(
+        matcher=lambda sample: sample.kwargs.get("ord") == 6,
+        dtypes=[torch.float16],
+        reason="ORT returns a more accurate value for float16 with ord=6 (expected=Inf, actual=9.48).",
+    ),
     TorchLibOpInfo(
         "linspace",
         core_ops.aten_linspace,
```
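The input wrangler is needed because the graph-mode helper consumes `dim` as an INT64 tensor input rather than a Python int, so the OpInfo sample's `dim` kwarg is rewritten before the op is invoked. A quick, self-contained illustration of the transformation (with the wrangler body inlined):

```python
import numpy as np

# An OpInfo-style sample: dim arrives as a plain Python value or tuple.
kwargs = {"dim": (0, 1), "ord": 6}

# What _linalg_vector_norm_input_wrangler does before the op is invoked.
if "dim" in kwargs:
    kwargs["dim"] = np.array(kwargs["dim"], dtype=np.int64)

print(kwargs)  # {'dim': array([0, 1]), 'ord': 6}
```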
