12
12
from __future__ import annotations
13
13
14
14
import math
15
- from typing import Any , Optional , Sequence , Tuple , TypeVar , Union
15
+ from typing import Any , Optional , Sequence , Tuple , Union
16
16
17
17
from onnxscript import (
18
18
BFLOAT16 ,
50
50
from onnxscript .onnx_opset import opset18 as op
51
51
from onnxscript .onnx_types import TensorType
52
52
53
- TRealUnlessFloat32 = TypeVar (
54
- "TRealUnlessFloat32" ,
55
- bound = Union [
56
- BFLOAT16 ,
57
- FLOAT16 ,
58
- DOUBLE ,
59
- INT8 ,
60
- INT16 ,
61
- INT32 ,
62
- INT64 ,
63
- ],
64
- )
65
-
66
53
_INT64_MAX = 9223372036854775807
67
54
_INT64_MIN = - 9223372036854775808
68
55
_MATH_PI = math .pi
@@ -235,12 +222,8 @@ def aten_addcmul(
235
222
236
223
@torch_op ("aten::addmm" )
237
224
def aten_addmm (
238
- self : TRealUnlessFloat32 ,
239
- mat1 : TRealUnlessFloat32 ,
240
- mat2 : TRealUnlessFloat32 ,
241
- beta : float = 1.0 ,
242
- alpha : float = 1.0 ,
243
- ) -> TRealUnlessFloat32 :
225
+ self : TInt , mat1 : TInt , mat2 : TInt , beta : float = 1.0 , alpha : float = 1.0
226
+ ) -> TInt :
244
227
"""addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"""
245
228
246
229
mat1_mat2 = op .MatMul (mat1 , mat2 )
@@ -251,8 +234,8 @@ def aten_addmm(
251
234
252
235
@torch_op ("aten::addmm" )
253
236
def aten_addmm_gemm (
254
- self : FLOAT , mat1 : FLOAT , mat2 : FLOAT , beta : float = 1.0 , alpha : float = 1.0
255
- ) -> FLOAT :
237
+ self : TFloat , mat1 : TFloat , mat2 : TFloat , beta : float = 1.0 , alpha : float = 1.0
238
+ ) -> TFloat :
256
239
"""addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"""
257
240
258
241
# A special case when rank of mat1 and mat2 are 2, we can use Gemm instead of MatMul
0 commit comments