Skip to content

Commit b448bc0

Browse files
committed
Simplify
1 parent 0b29830 commit b448bc0

File tree

1 file changed

+5
-22
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+5
-22
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
12 12   from __future__ import annotations
13 13
14 14   import math
15    -  from typing import Any, Optional, Sequence, Tuple, TypeVar, Union
   15 +  from typing import Any, Optional, Sequence, Tuple, Union
16 16
17 17   from onnxscript import (
18 18       BFLOAT16,
@@ -50,19 +50,6 @@
50 50   from onnxscript.onnx_opset import opset18 as op
51 51   from onnxscript.onnx_types import TensorType
52 52
53    -  TRealUnlessFloat32 = TypeVar(
54    -      "TRealUnlessFloat32",
55    -      bound=Union[
56    -          BFLOAT16,
57    -          FLOAT16,
58    -          DOUBLE,
59    -          INT8,
60    -          INT16,
61    -          INT32,
62    -          INT64,
63    -      ],
64    -  )
65    -
66 53   _INT64_MAX = 9223372036854775807
67 54   _INT64_MIN = -9223372036854775808
68 55   _MATH_PI = math.pi
@@ -235,12 +222,8 @@ def aten_addcmul(
235 222
236 223  @torch_op("aten::addmm")
237 224  def aten_addmm(
238     -     self: TRealUnlessFloat32,
239     -     mat1: TRealUnlessFloat32,
240     -     mat2: TRealUnlessFloat32,
241     -     beta: float = 1.0,
242     -     alpha: float = 1.0,
243     - ) -> TRealUnlessFloat32:
    225 +     self: TInt, mat1: TInt, mat2: TInt, beta: float = 1.0, alpha: float = 1.0
    226 + ) -> TInt:
244 227      """addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"""
245 228
246 229      mat1_mat2 = op.MatMul(mat1, mat2)
@@ -251,8 +234,8 @@ def aten_addmm(
251 234
252 235  @torch_op("aten::addmm")
253 236  def aten_addmm_gemm(
254     -     self: FLOAT, mat1: FLOAT, mat2: FLOAT, beta: float = 1.0, alpha: float = 1.0
255     - ) -> FLOAT:
    237 +     self: TFloat, mat1: TFloat, mat2: TFloat, beta: float = 1.0, alpha: float = 1.0
    238 + ) -> TFloat:
256 239      """addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"""
257 240
258 241      # A special case when rank of mat1 and mat2 are 2, we can use Gemm instead of MatMul

0 commit comments

Comments
 (0)