from typing import Any, Optional, Sequence

from onnxscript import BOOL, INT64
+from onnxscript.function_libs.torch_aten.registration import torch_op
from onnxscript.onnx_opset import opset18 as op
from onnxscript.onnx_types import TensorType


+@torch_op("aten::abs")
def aten_abs(self):
    # abs(Tensor self) -> Tensor

    return op.Abs(self)


+@torch_op("aten::acos")
def aten_acos(self):
    # acos(Tensor self) -> Tensor

    return op.Acos(self)


+@torch_op("aten::acosh")
def aten_acosh(self):
    # acosh(Tensor self) -> Tensor

@@ -54,6 +58,7 @@ def aten_adaptive_max_pool1d(
    raise NotImplementedError()


+@torch_op("aten::add")
def aten_add(self, other, alpha: float = 1) -> TensorType:
    # add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
    if alpha != 1:
@@ -85,6 +90,7 @@ def aten_addcmul(
    raise NotImplementedError()


+@torch_op("aten::addmm")
def aten_addmm(self, mat1, mat2, beta: float = 1, alpha: float = 1):
    # addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor

@@ -332,18 +338,21 @@ def aten_as_strided_scatter(
    raise NotImplementedError()


+@torch_op("aten::asin")
def aten_asin(self):
    # asin(Tensor self) -> Tensor

    return op.Asin(self)


+@torch_op("aten::asinh")
def aten_asinh(self):
    # asinh(Tensor self) -> Tensor

    return op.Asinh(self)


+@torch_op("aten::atan")
def aten_atan(self):
    # atan(Tensor self) -> Tensor

@@ -356,6 +365,7 @@ def aten_atan2(self: TensorType, other: TensorType) -> TensorType:
    raise NotImplementedError()


+@torch_op("aten::atanh")
def aten_atanh(self):
    # atanh(Tensor self) -> Tensor

@@ -606,6 +616,7 @@ def aten_block_diag(tensors: Sequence[TensorType]) -> TensorType:
    raise NotImplementedError()


+@torch_op("aten::bmm")
def aten_bmm(self, mat2):
    # bmm(Tensor self, Tensor mat2) -> Tensor

@@ -670,6 +681,7 @@ def aten_cdist(
    raise NotImplementedError()


+@torch_op("aten::ceil")
def aten_ceil(self):
    # ceil(Tensor self) -> Tensor

@@ -728,6 +740,7 @@ def aten_chunk(self: TensorType, chunks: int, dim: int = 0) -> TensorType:
    raise NotImplementedError()


+@torch_op("aten::clamp")
def aten_clamp(self: TensorType, min_=None, max_=None) -> TensorType:
    # clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor

@@ -752,19 +765,22 @@ def aten_clamp(self: TensorType, min_=None, max_=None) -> TensorType:
    return clamped


+@torch_op("aten::clamp_max.Scalar", overload=True)
def aten_clamp_max_scalar(self, max_):
    # clamp_max(Tensor self, Scalar max) -> Tensor

    max_ = op.CastLike(max_, self)
    return op.Clip(self, None, max_)


+@torch_op("aten::clamp_max.Tensor")
def aten_clamp_max_tensor(self, max_):
-    # clamp_max(Tensor self, Scalar max) -> Tensor
+    # clamp_max(Tensor self, Tensor max) -> Tensor

    return op.Min(self, max_)


+@torch_op("aten::clamp_min.Scalar", overload=True)
def aten_clamp_min_scalar(self, min_):
    # clamp_min(Tensor self, Scalar min) -> Tensor
    # NOTE: min_ is a rank 0 tensor.
@@ -773,6 +789,7 @@ def aten_clamp_min_scalar(self, min_):
    return op.Clip(self, min_, None)


+@torch_op("aten::clamp_min.Tensor")
def aten_clamp_min_tensor(self, min_):
    # clamp_min(Tensor self, Tensor min) -> Tensor
    # TODO(justinchuby): Specify the type constraints.
@@ -1017,12 +1034,14 @@ def aten_corrcoef(self: TensorType) -> TensorType:
    raise NotImplementedError()


+@torch_op("aten::cos")
def aten_cos(self):
    # cos(Tensor self) -> Tensor

    return op.Cos(self)


+@torch_op("aten::cosh")
def aten_cosh(self):
    # cosh(Tensor self) -> Tensor

@@ -1392,6 +1411,7 @@ def aten_divide(self: TensorType, other: TensorType) -> TensorType:
    raise NotImplementedError()


+@torch_op("aten::dot")
def aten_dot(self, tensor):
    # dot(Tensor self, Tensor tensor) -> Tensor

@@ -1532,12 +1552,14 @@ def aten_erfinv(self: TensorType) -> TensorType:
    raise NotImplementedError()


+@torch_op("aten::exp")
def aten_exp(self):
    # exp(Tensor self) -> Tensor

    return op.Exp(self)


+@torch_op("aten::exp2")
def aten_exp2(self):
    # exp2(Tensor self) -> Tensor

@@ -1972,6 +1994,7 @@ def aten_gru_cell(
    raise NotImplementedError()


+@torch_op("aten::gt")
def aten_gt(self, other):
    # gt.Tensor(Tensor self, Tensor other) -> Tensor

@@ -2588,6 +2611,7 @@ def aten_lstm_mps_backward(
    raise NotImplementedError()


+@torch_op("aten::lt")
def aten_lt(self, other):
    # lt.Tensor(Tensor self, Tensor other) -> Tensor

@@ -2663,6 +2687,7 @@ def aten_masked_select_backward(
    raise NotImplementedError()


+@torch_op("aten::matmul")
def aten_matmul(self, other):
    # matmul(Tensor self, Tensor other) -> Tensor

@@ -3063,6 +3088,7 @@ def aten_mkldnn_max_pool3d_backward(
    raise NotImplementedError()


+@torch_op("aten::mm")
def aten_mm(self, mat2):
    # mm(Tensor self, Tensor mat2) -> Tensor

@@ -3129,12 +3155,14 @@ def aten_msort(self: TensorType) -> TensorType:
    raise NotImplementedError()


-def aten_mul(self, other) -> TensorType:
+@torch_op("aten::mul")
+def aten_mul(self, other):
    # mul.Tensor(Tensor self, Tensor other) -> Tensor

    return op.Mul(self, other)


+@torch_op("aten::mul", overload=True)
def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL:
    """ONNX Mul doesn't support Boolean, so use And as an equivalent operator."""

@@ -3447,6 +3475,7 @@ def aten_nuclear_norm(self: TensorType, keepdim: bool = False) -> TensorType:
    raise NotImplementedError()


+@torch_op("aten::ones")
def aten_ones(size: INT64, dtype: int = -1) -> TensorType:
    # ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

@@ -3456,6 +3485,7 @@ def aten_ones(size: INT64, dtype: int = -1) -> TensorType:
    return op.Expand(one, size)  # type: ignore[arg-type]


+@torch_op("aten::ones_like")
def aten_ones_like(self, dtype: int = -1):
    """ones_like.

@@ -3942,6 +3972,7 @@ def aten_renorm(self: TensorType, p: float, dim: int, maxnorm: float) -> TensorT
    raise NotImplementedError()


+@torch_op("aten::repeat")
def aten_repeat(self, repeats: INT64):
    # repeat(Tensor self, SymInt[] repeats) -> Tensor

@@ -4047,6 +4078,7 @@ def aten_rot90(self: TensorType, k: int = 1, dims: Sequence[int] = (0, 1)) -> Te
    raise NotImplementedError()


+@torch_op("aten::round")
def aten_round(self):
    # round(Tensor self) -> Tensor

@@ -4157,6 +4189,7 @@ def aten_select_scatter(self: TensorType, src: TensorType, dim: int, index: int)
    raise NotImplementedError()


+@torch_op("aten::selu")
def aten_selu(self):
    # selu(Tensor self) -> Tensor

@@ -4193,12 +4226,14 @@ def aten_signbit(self: TensorType) -> TensorType:
    raise NotImplementedError()


+@torch_op("aten::sin")
def aten_sin(self):
    # sin(Tensor self) -> Tensor

    return op.Sin(self)


+@torch_op("aten::sinh")
def aten_sinh(self):
    # sinh(Tensor self) -> Tensor

@@ -4378,6 +4413,7 @@ def aten_stft(
    raise NotImplementedError()


+@torch_op("aten::sub")
def aten_sub(self, other, alpha: float = 1) -> TensorType:
    # sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor

@@ -4433,6 +4469,7 @@ def aten_symeig(
    raise NotImplementedError()


+@torch_op("aten::t")
def aten_t(self: TensorType) -> TensorType:
    # t(Tensor(a) self) -> Tensor(a)

@@ -4465,12 +4502,14 @@ def aten_take_along_dim(
    raise NotImplementedError()


+@torch_op("aten::tan")
def aten_tan(self):
    # tan(Tensor self) -> Tensor

    return op.Tan(self)


+@torch_op("aten::tanh")
def aten_tanh(self):
    # tanh(Tensor self) -> Tensor

@@ -4858,6 +4897,7 @@ def aten_xor(self: TensorType, other: TensorType) -> TensorType:
    raise NotImplementedError()


+@torch_op("aten::zeros")
def aten_zeros(size, dtype: int = -1):
    # zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

@@ -4868,6 +4908,7 @@ def aten_zeros(size, dtype: int = -1):
    return op.Expand(zero, size)  # type: ignore[arg-type]


+@torch_op("aten::zeros_like")
def aten_zeros_like(self, dtype: int = -1):
    # zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
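The mechanical change in this diff is a registration pass: each implemented function gains a `@torch_op("aten::<name>")` decorator that binds it to the ATen operator it translates, with `overload=True` marking secondary variants such as the Scalar overloads of `clamp_max`/`clamp_min`. As a rough mental model, here is a minimal sketch of how such a decorator-backed registry could work. The `Registry` class and `default_registry` below are hypothetical stand-ins, not the actual contents of `onnxscript.function_libs.torch_aten.registration` (the real decorator presumably also compiles each function into an onnxscript function object, which is omitted here).

```python
from collections import defaultdict
from typing import Any, Callable, Dict, List, Tuple

OpFunc = Callable[..., Any]


class Registry:
    """Hypothetical registry mapping ATen op names (e.g. "aten::add")
    to their ONNX implementations."""

    def __init__(self) -> None:
        # Each name maps to a list of (function, is_overload) pairs.
        self._registry: Dict[str, List[Tuple[OpFunc, bool]]] = defaultdict(list)

    def register(self, func: OpFunc, name: str, *, overload: bool = False) -> None:
        self._registry[name].append((func, overload))

    def __contains__(self, name: str) -> bool:
        return name in self._registry


default_registry = Registry()  # hypothetical module-level default


def torch_op(name: str, *, overload: bool = False,
             registry: Registry = default_registry):
    """Decorator that registers `func` as an implementation of ATen op `name`.

    `overload=True` marks a non-default variant; a dispatcher could later
    choose between the default implementation and its overloads by
    argument types.
    """

    def wrapper(func: OpFunc) -> OpFunc:
        registry.register(func, name, overload=overload)
        return func  # returned unchanged in this sketch

    return wrapper


# Usage mirrors the diff above:
@torch_op("aten::abs")
def aten_abs(self):
    ...  # the real implementation returns op.Abs(self)


assert "aten::abs" in default_registry
```

Whatever the real implementation looks like, a flat name-to-function table of this shape is what lets an exporter look up the ONNX translation of each ATen node by name at export time, rather than hard-coding the dispatch.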