Commit c4a655e

feat(atenlib): atenlib function registry (#260)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

* __->__ #260

Registry for all atenlib functions, which supports overloading. This change implements the `torch_op` decorator for registering all functions in the ATen lib so they are discoverable. It is not designed to be used by end users directly, but rather as an info container so code gen knows which ops are implemented. `torch_op` also compiles the functions into `OnnxFunction`.
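As a rough illustration of the mechanism, here is a minimal sketch of what such a registry and decorator could look like. The names `Registry`, `OverloadedFunctions`, and `default_registry`, as well as the use of `onnxscript.script()` as the compilation step, are assumptions for illustration and may not match the actual contents of `onnxscript/function_libs/torch_aten/registration.py`.

```python
# Hypothetical sketch of an atenlib function registry; not the actual
# registration.py. onnxscript.script() as the compilation entry point is an
# assumption.
from typing import Callable, Dict, List, Optional

import onnxscript


class OverloadedFunctions:
    """Default implementation plus any overloads registered for one ATen op."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.default: Optional[onnxscript.OnnxFunction] = None
        self.overloads: List[onnxscript.OnnxFunction] = []


class Registry:
    """Maps ATen op names such as "aten::add" to their ONNX Script functions."""

    def __init__(self) -> None:
        self._registry: Dict[str, OverloadedFunctions] = {}

    def register(
        self, func: onnxscript.OnnxFunction, name: str, *, overload: bool = False
    ) -> None:
        entry = self._registry.setdefault(name, OverloadedFunctions(name))
        if overload:
            entry.overloads.append(func)
        else:
            entry.default = func

    def __contains__(self, name: str) -> bool:
        return name in self._registry

    def __getitem__(self, name: str) -> OverloadedFunctions:
        return self._registry[name]


default_registry = Registry()


def torch_op(name: str, *, overload: bool = False) -> Callable:
    """Register a function as the implementation (or an overload) of an ATen op."""

    def wrapper(func: Callable) -> onnxscript.OnnxFunction:
        # Compile the Python function into an OnnxFunction so code gen can
        # inspect it later, then record it under the ATen op name.
        compiled = onnxscript.script()(func)
        default_registry.register(compiled, name, overload=overload)
        return compiled

    return wrapper
```

The decorators added in the diffs below would then populate such a registry as the modules are imported, which is what makes the implemented ops discoverable.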
1 parent 8bd6dbd commit c4a655e

4 files changed (+115, -17 lines)


onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 43 additions & 2 deletions
@@ -18,22 +18,26 @@
 from typing import Any, Optional, Sequence

 from onnxscript import BOOL, INT64
+from onnxscript.function_libs.torch_aten.registration import torch_op
 from onnxscript.onnx_opset import opset18 as op
 from onnxscript.onnx_types import TensorType


+@torch_op("aten::abs")
 def aten_abs(self):
     # abs(Tensor self) -> Tensor

     return op.Abs(self)


+@torch_op("aten::acos")
 def aten_acos(self):
     # acos(Tensor self) -> Tensor

     return op.Acos(self)


+@torch_op("aten::acosh")
 def aten_acosh(self):
     # acosh(Tensor self) -> Tensor

@@ -54,6 +58,7 @@ def aten_adaptive_max_pool1d(
     raise NotImplementedError()


+@torch_op("aten::add")
 def aten_add(self, other, alpha: float = 1) -> TensorType:
     # add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
     if alpha != 1:
@@ -85,6 +90,7 @@ def aten_addcmul(
     raise NotImplementedError()


+@torch_op("aten::addmm")
 def aten_addmm(self, mat1, mat2, beta: float = 1, alpha: float = 1):
     # addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor

@@ -332,18 +338,21 @@ def aten_as_strided_scatter(
     raise NotImplementedError()


+@torch_op("aten::asin")
 def aten_asin(self):
     # asin(Tensor self) -> Tensor

     return op.Asin(self)


+@torch_op("aten::asinh")
 def aten_asinh(self):
     # asinh(Tensor self) -> Tensor

     return op.Asinh(self)


+@torch_op("aten::atan")
 def aten_atan(self):
     # atan(Tensor self) -> Tensor

@@ -356,6 +365,7 @@ def aten_atan2(self: TensorType, other: TensorType) -> TensorType:
     raise NotImplementedError()


+@torch_op("aten::atanh")
 def aten_atanh(self):
     # atanh(Tensor self) -> Tensor

@@ -606,6 +616,7 @@ def aten_block_diag(tensors: Sequence[TensorType]) -> TensorType:
     raise NotImplementedError()


+@torch_op("aten::bmm")
 def aten_bmm(self, mat2):
     # bmm(Tensor self, Tensor mat2) -> Tensor

@@ -670,6 +681,7 @@ def aten_cdist(
     raise NotImplementedError()


+@torch_op("aten::ceil")
 def aten_ceil(self):
     # ceil(Tensor self) -> Tensor

@@ -728,6 +740,7 @@ def aten_chunk(self: TensorType, chunks: int, dim: int = 0) -> TensorType:
     raise NotImplementedError()


+@torch_op("aten::clamp")
 def aten_clamp(self: TensorType, min_=None, max_=None) -> TensorType:
     # clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor

@@ -752,19 +765,22 @@ def aten_clamp(self: TensorType, min_=None, max_=None) -> TensorType:
     return clamped


+@torch_op("aten::clamp_max.Scalar", overload=True)
 def aten_clamp_max_scalar(self, max_):
     # clamp_max(Tensor self, Scalar max) -> Tensor

     max_ = op.CastLike(max_, self)
     return op.Clip(self, None, max_)


+@torch_op("aten::clamp_max.Tensor")
 def aten_clamp_max_tensor(self, max_):
-    # clamp_max(Tensor self, Scalar max) -> Tensor
+    # clamp_max(Tensor self, Tensor max) -> Tensor

     return op.Min(self, max_)


+@torch_op("aten::clamp_min.Scalar", overload=True)
 def aten_clamp_min_scalar(self, min_):
     # clamp_min(Tensor self, Scalar min) -> Tensor
     # NOTE: min_ is a rank 0 tensor.
@@ -773,6 +789,7 @@ def aten_clamp_min_scalar(self, min_):
     return op.Clip(self, min_, None)


+@torch_op("aten::clamp_min.Tensor")
 def aten_clamp_min_tensor(self, min_):
     # clamp_min(Tensor self, Tensor min) -> Tensor
     # TODO(justinchuby): Specify the type constraints.
@@ -1017,12 +1034,14 @@ def aten_corrcoef(self: TensorType) -> TensorType:
     raise NotImplementedError()


+@torch_op("aten::cos")
 def aten_cos(self):
     # cos(Tensor self) -> Tensor

     return op.Cos(self)


+@torch_op("aten::cosh")
 def aten_cosh(self):
     # cosh(Tensor self) -> Tensor

@@ -1392,6 +1411,7 @@ def aten_divide(self: TensorType, other: TensorType) -> TensorType:
     raise NotImplementedError()


+@torch_op("aten::dot")
 def aten_dot(self, tensor):
     # dot(Tensor self, Tensor tensor) -> Tensor

@@ -1532,12 +1552,14 @@ def aten_erfinv(self: TensorType) -> TensorType:
     raise NotImplementedError()


+@torch_op("aten::exp")
 def aten_exp(self):
     # exp(Tensor self) -> Tensor

     return op.Exp(self)


+@torch_op("aten::exp2")
 def aten_exp2(self):
     # exp2(Tensor self) -> Tensor

@@ -1972,6 +1994,7 @@ def aten_gru_cell(
     raise NotImplementedError()


+@torch_op("aten::gt")
 def aten_gt(self, other):
     # gt.Tensor(Tensor self, Tensor other) -> Tensor

@@ -2588,6 +2611,7 @@ def aten_lstm_mps_backward(
     raise NotImplementedError()


+@torch_op("aten::lt")
 def aten_lt(self, other):
     # lt.Tensor(Tensor self, Tensor other) -> Tensor

@@ -2663,6 +2687,7 @@ def aten_masked_select_backward(
     raise NotImplementedError()


+@torch_op("aten::matmul")
 def aten_matmul(self, other):
     # matmul(Tensor self, Tensor other) -> Tensor

@@ -3063,6 +3088,7 @@ def aten_mkldnn_max_pool3d_backward(
     raise NotImplementedError()


+@torch_op("aten::mm")
 def aten_mm(self, mat2):
     # mm(Tensor self, Tensor mat2) -> Tensor

@@ -3129,12 +3155,14 @@ def aten_msort(self: TensorType) -> TensorType:
     raise NotImplementedError()


-def aten_mul(self, other) -> TensorType:
+@torch_op("aten::mul")
+def aten_mul(self, other):
     # mul.Tensor(Tensor self, Tensor other) -> Tensor

     return op.Mul(self, other)


+@torch_op("aten::mul", overload=True)
 def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL:
     """ONNX Mul doesn't support Boolean, so use And as an equivalent operator."""

@@ -3447,6 +3475,7 @@ def aten_nuclear_norm(self: TensorType, keepdim: bool = False) -> TensorType:
     raise NotImplementedError()


+@torch_op("aten::ones")
 def aten_ones(size: INT64, dtype: int = -1) -> TensorType:
     # ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

@@ -3456,6 +3485,7 @@ def aten_ones(size: INT64, dtype: int = -1) -> TensorType:
     return op.Expand(one, size) # type: ignore[arg-type]


+@torch_op("aten::ones_like")
 def aten_ones_like(self, dtype: int = -1):
     """ones_like.
@@ -3942,6 +3972,7 @@ def aten_renorm(self: TensorType, p: float, dim: int, maxnorm: float) -> TensorType:
     raise NotImplementedError()


+@torch_op("aten::repeat")
 def aten_repeat(self, repeats: INT64):
     # repeat(Tensor self, SymInt[] repeats) -> Tensor

@@ -4047,6 +4078,7 @@ def aten_rot90(self: TensorType, k: int = 1, dims: Sequence[int] = (0, 1)) -> TensorType:
     raise NotImplementedError()


+@torch_op("aten::round")
 def aten_round(self):
     # round(Tensor self) -> Tensor

@@ -4157,6 +4189,7 @@ def aten_select_scatter(self: TensorType, src: TensorType, dim: int, index: int)
     raise NotImplementedError()


+@torch_op("aten::selu")
 def aten_selu(self):
     # selu(Tensor self) -> Tensor

@@ -4193,12 +4226,14 @@ def aten_signbit(self: TensorType) -> TensorType:
     raise NotImplementedError()


+@torch_op("aten::sin")
 def aten_sin(self):
     # sin(Tensor self) -> Tensor

     return op.Sin(self)


+@torch_op("aten::sinh")
 def aten_sinh(self):
     # sinh(Tensor self) -> Tensor

@@ -4378,6 +4413,7 @@ def aten_stft(
     raise NotImplementedError()


+@torch_op("aten::sub")
 def aten_sub(self, other, alpha: float = 1) -> TensorType:
     # sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor

@@ -4433,6 +4469,7 @@ def aten_symeig(
     raise NotImplementedError()


+@torch_op("aten::t")
 def aten_t(self: TensorType) -> TensorType:
     # t(Tensor(a) self) -> Tensor(a)

@@ -4465,12 +4502,14 @@ def aten_take_along_dim(
     raise NotImplementedError()


+@torch_op("aten::tan")
 def aten_tan(self):
     # tan(Tensor self) -> Tensor

     return op.Tan(self)


+@torch_op("aten::tanh")
 def aten_tanh(self):
     # tanh(Tensor self) -> Tensor

@@ -4858,6 +4897,7 @@ def aten_xor(self: TensorType, other: TensorType) -> TensorType:
     raise NotImplementedError()


+@torch_op("aten::zeros")
 def aten_zeros(size, dtype: int = -1):
     # zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

@@ -4868,6 +4908,7 @@ def aten_zeros(size, dtype: int = -1):
     return op.Expand(zero, size) # type: ignore[arg-type]


+@torch_op("aten::zeros_like")
 def aten_zeros_like(self, dtype: int = -1):
     # zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor

onnxscript/function_libs/torch_aten/ops/nn.py

Lines changed: 4 additions & 0 deletions
@@ -21,6 +21,7 @@
 from typing import Optional, Sequence

 from onnxscript import INT64
+from onnxscript.function_libs.torch_aten.registration import torch_op
 from onnxscript.onnx_opset import opset18 as op
 from onnxscript.onnx_types import TensorType

@@ -196,6 +197,7 @@ def aten_cross_entropy_loss(
     raise NotImplementedError()


+@torch_op("aten::elu")
 def aten_elu(
     self,
     alpha: float = 1.0,
@@ -413,6 +415,7 @@ def aten_leaky_relu_backward(
     raise NotImplementedError()


+@torch_op("aten::linear")
 def aten_linear(input, weight, bias=None) -> TensorType:
     # linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor

@@ -800,6 +803,7 @@ def aten_reflection_pad3d_backward(


 # TODO(justinchuby): Use TFloat as return type
+@torch_op("aten::relu6")
 def aten_relu6(self):
     # relu6(Tensor self) -> Tensor

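Continuing the hypothetical sketch from above, code gen could query the populated registry to find out which ATen ops already have implementations and which are still stubs. The `default_registry` import below is an assumed export of that sketch, not a documented API of the real registration module.

```python
# Usage sketch only; default_registry is the hypothetical object from the
# sketch above, not a confirmed export of the real registration module.
from onnxscript.function_libs.torch_aten.registration import default_registry


def is_implemented(op_name: str) -> bool:
    """Return True if the op has a default implementation or any overload."""
    if op_name not in default_registry:
        return False
    entry = default_registry[op_name]
    return entry.default is not None or bool(entry.overloads)


# With the decorators added in this commit:
print(is_implemented("aten::abs"))    # True: decorated in core.py above
print(is_implemented("aten::atan2"))  # False: still raises NotImplementedError
```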