@@ -13,6 +13,10 @@
 
 from typing import Optional, Sequence
 
+from onnxscript import BOOL, FLOAT, INT64
+from onnxscript.function_libs.torch_lib.registration import torch_op
+from onnxscript.function_libs.torch_lib.tensor_typing import TFloat
+from onnxscript.onnx_opset import opset18 as op
 from onnxscript.onnx_types import TensorType
 
 
@@ -305,13 +309,78 @@ def aten_linalg_vecdot(x: TensorType, y: TensorType, dim: int = -1) -> TensorType:
     raise NotImplementedError()
 
 
+@torch_op("aten::linalg_vector_norm", trace_only=True)
 def aten_linalg_vector_norm(
-    self: TensorType,
-    ord: float = 2,
+    self: TFloat,
+    ord: float = 2.0,
     dim: Optional[int] = None,
     keepdim: bool = False,
-    dtype: Optional[int] = None,
-) -> TensorType:
+    dtype: int = -1,
+) -> TFloat:
     """linalg_vector_norm(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"""
 
-    raise NotImplementedError()
+    if dtype != -1:
+        self = op.Cast(self, to=dtype)
+    if dim is None or (isinstance(dim, tuple) and len(dim) == 0):
+        self = op.Reshape(self, op.Constant(value_ints=[-1]))
+        keepdim = False
+        return _aten_linalg_vector_norm_no_dim_onnx(self, ord, keepdim)
+    else:
+        return _aten_linalg_vector_norm_onnx(self, ord, dim, keepdim)
+
+
+@torch_op("aten::linalg_vector_norm", private=True)
+def _aten_linalg_vector_norm_no_dim_onnx(self: TFloat, ord: float, keepdim: bool) -> TFloat:
+    self_rank = op.Size(op.Shape(self))
+    if self_rank == 0:
+        self = op.Unsqueeze(self, axes=[0])
+
+    self = op.Abs(self)
+    ord = op.Cast(ord, to=FLOAT.dtype)  # Must be FLOAT because op.IsInf() requires FLOAT input
+    if op.IsInf(ord, detect_negative=0, detect_positive=1):
+        result = op.ReduceMax(self, keepdims=keepdim)
+    elif op.IsInf(ord, detect_negative=1, detect_positive=0):
+        result = op.ReduceMin(self, keepdims=keepdim)
+    elif ord == 0.0:  # sum(x != 0), i.e. count the non-zero elements
+        self_bool = op.Cast(self, to=BOOL.dtype)
+        self_0_1 = op.CastLike(self_bool, self)
+        result = op.ReduceSum(self_0_1, keepdims=False)
+    else:
+        ord_float = op.CastLike(ord, self)
+        self_pow = op.Pow(self, ord_float)
+        result = op.Pow(op.ReduceSum(self_pow, keepdims=keepdim), op.Div(1.0, ord_float))
+
+    if self_rank == 0:
+        result = op.Squeeze(result)
+
+    return result
+
+
+@torch_op("aten::linalg_vector_norm", private=True)
+def _aten_linalg_vector_norm_onnx(
+    self: TFloat, ord: float, dim: INT64, keepdim: bool
+) -> TFloat:
+    self_rank = op.Size(op.Shape(self))
+    if self_rank == 0:
+        self = op.Unsqueeze(self, axes=[0])
+
+    dim = op.Reshape(dim, op.Constant(value_ints=[-1]))
+    self = op.Abs(self)
+    ord = op.Cast(ord, to=FLOAT.dtype)  # Must be FLOAT because op.IsInf() requires FLOAT input
+    if op.IsInf(ord, detect_negative=0, detect_positive=1):
+        result = op.ReduceMax(self, dim, keepdims=keepdim)
+    elif op.IsInf(ord, detect_negative=1, detect_positive=0):
+        result = op.ReduceMin(self, dim, keepdims=keepdim)
+    elif ord == 0.0:  # sum(x != 0), i.e. count the non-zero elements
+        self_bool = op.Cast(self, to=BOOL.dtype)
+        self_0_1 = op.CastLike(self_bool, self)
+        result = op.ReduceSum(self_0_1, dim, keepdims=keepdim)
+    else:
+        ord_float = op.CastLike(ord, self)
+        self_pow = op.Pow(self, ord_float)
+        result = op.Pow(op.ReduceSum(self_pow, dim, keepdims=keepdim), op.Div(1.0, ord_float))
+
+    if self_rank == 0:
+        result = op.Squeeze(result)
+
+    return result
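
Reviewer note (not part of the diff): the branch structure above maps ord=+inf to ReduceMax, ord=-inf to ReduceMin, ord=0 to a count of non-zero elements, and any other ord to the usual (sum |x|^ord)^(1/ord). The sketch below mirrors that logic in plain NumPy and checks it against torch.linalg.vector_norm; the helper name vector_norm_reference is hypothetical and only serves as an illustration, not part of the implementation.

import math

import numpy as np
import torch


def vector_norm_reference(x, ord=2.0, dim=None, keepdim=False):
    # Hypothetical reference helper mirroring the converter's branches in NumPy.
    if dim is None:
        # No dim given: flatten first, matching the Reshape(self, [-1]) path above.
        x = x.reshape(-1)
        keepdim = False
    x = np.abs(x)
    if math.isinf(ord) and ord > 0:   # ord=+inf -> ReduceMax
        return x.max(axis=dim, keepdims=keepdim)
    if math.isinf(ord) and ord < 0:   # ord=-inf -> ReduceMin
        return x.min(axis=dim, keepdims=keepdim)
    if ord == 0.0:                    # ord=0 -> number of non-zero elements
        return (x != 0).astype(x.dtype).sum(axis=dim, keepdims=keepdim)
    return (x ** ord).sum(axis=dim, keepdims=keepdim) ** (1.0 / ord)


x = np.random.randn(3, 4).astype(np.float32)
for ord_ in (2.0, 1.0, 0.0, 3.5, math.inf, -math.inf):
    expected = torch.linalg.vector_norm(torch.from_numpy(x), ord=ord_).numpy()
    np.testing.assert_allclose(vector_norm_reference(x, ord=ord_), expected, rtol=1e-4)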