Skip to content

Commit 2512120

Browse files
authored
feat(atenlib): ops 4/n (#256)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * __->__ #256 * #255 * #252 Some more math and matrix ops
1 parent f7ac851 commit 2512120

File tree

3 files changed

+122
-49
lines changed

3 files changed

+122
-49
lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 30 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -332,22 +332,22 @@ def aten_as_strided_scatter(
332332
raise NotImplementedError()
333333

334334

335-
def aten_asin(self: TensorType) -> TensorType:
335+
def aten_asin(self):
336336
# asin(Tensor self) -> Tensor
337337

338-
raise NotImplementedError()
338+
return op.Asin(self)
339339

340340

341-
def aten_asinh(self: TensorType) -> TensorType:
341+
def aten_asinh(self):
342342
# asinh(Tensor self) -> Tensor
343343

344-
raise NotImplementedError()
344+
return op.Asinh(self)
345345

346346

347-
def aten_atan(self: TensorType) -> TensorType:
347+
def aten_atan(self):
348348
# atan(Tensor self) -> Tensor
349349

350-
raise NotImplementedError()
350+
return op.Atan(self)
351351

352352

353353
def aten_atan2(self: TensorType, other: TensorType) -> TensorType:
@@ -356,10 +356,10 @@ def aten_atan2(self: TensorType, other: TensorType) -> TensorType:
356356
raise NotImplementedError()
357357

358358

359-
def aten_atanh(self: TensorType) -> TensorType:
359+
def aten_atanh(self):
360360
# atanh(Tensor self) -> Tensor
361361

362-
raise NotImplementedError()
362+
return op.Atanh(self)
363363

364364

365365
def aten_atleast_1d(self: TensorType) -> TensorType:
@@ -670,16 +670,10 @@ def aten_cdist(
670670
raise NotImplementedError()
671671

672672

673-
def aten_ceil(self: TensorType) -> TensorType:
673+
def aten_ceil(self):
674674
# ceil(Tensor self) -> Tensor
675675

676-
raise NotImplementedError()
677-
678-
679-
def aten_celu(self: TensorType, alpha: float = 1.0) -> TensorType:
680-
# celu(Tensor self, Scalar alpha=1.0) -> Tensor
681-
682-
raise NotImplementedError()
676+
return op.Ceil(self)
683677

684678

685679
def aten_chain_matmul(matrices: Sequence[TensorType]) -> TensorType:
@@ -785,14 +779,6 @@ def aten_clamp_min_tensor(self, min_):
785779
return op.Max(self, min_)
786780

787781

788-
def aten_clip(
789-
self: TensorType, min: Optional[float] = None, max: Optional[float] = None
790-
) -> TensorType:
791-
# clip(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor
792-
793-
raise NotImplementedError()
794-
795-
796782
def aten_clone(self: TensorType, memory_format: Optional[str] = None) -> TensorType:
797783
# clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor
798784

@@ -1031,16 +1017,16 @@ def aten_corrcoef(self: TensorType) -> TensorType:
10311017
raise NotImplementedError()
10321018

10331019

1034-
def aten_cos(self: TensorType) -> TensorType:
1020+
def aten_cos(self):
10351021
# cos(Tensor self) -> Tensor
10361022

1037-
raise NotImplementedError()
1023+
return op.Cos(self)
10381024

10391025

1040-
def aten_cosh(self: TensorType) -> TensorType:
1026+
def aten_cosh(self):
10411027
# cosh(Tensor self) -> Tensor
10421028

1043-
raise NotImplementedError()
1029+
return op.Cosh(self)
10441030

10451031

10461032
def aten_cosine_embedding_loss(
@@ -1406,10 +1392,10 @@ def aten_divide(self: TensorType, other: TensorType) -> TensorType:
14061392
raise NotImplementedError()
14071393

14081394

1409-
def aten_dot(self: TensorType, tensor: TensorType) -> TensorType:
1395+
def aten_dot(self, tensor):
14101396
# dot(Tensor self, Tensor tensor) -> Tensor
14111397

1412-
raise NotImplementedError()
1398+
return op.MatMul(self, tensor)
14131399

14141400

14151401
def aten_dropout(input: TensorType, p: float, train: bool) -> TensorType:
@@ -1546,16 +1532,18 @@ def aten_erfinv(self: TensorType) -> TensorType:
15461532
raise NotImplementedError()
15471533

15481534

1549-
def aten_exp(self: TensorType) -> TensorType:
1535+
def aten_exp(self):
15501536
# exp(Tensor self) -> Tensor
15511537

1552-
raise NotImplementedError()
1538+
return op.Exp(self)
15531539

15541540

1555-
def aten_exp2(self: TensorType) -> TensorType:
1541+
def aten_exp2(self):
15561542
# exp2(Tensor self) -> Tensor
15571543

1558-
raise NotImplementedError()
1544+
two = op.Constant(value_int=2)
1545+
two = op.CastLike(two, self) # type: ignore[arg-type]
1546+
return op.Pow(two, self) # type: ignore[arg-type]
15591547

15601548

15611549
def aten_expand(self: TensorType, size: INT64, implicit: bool = False) -> TensorType:
@@ -4205,22 +4193,16 @@ def aten_signbit(self: TensorType) -> TensorType:
42054193
raise NotImplementedError()
42064194

42074195

4208-
def aten_sin(self: TensorType) -> TensorType:
4196+
def aten_sin(self):
42094197
# sin(Tensor self) -> Tensor
42104198

4211-
raise NotImplementedError()
4212-
4199+
return op.Sin(self)
42134200

4214-
def aten_sinc(self: TensorType) -> TensorType:
4215-
# sinc(Tensor self) -> Tensor
42164201

4217-
raise NotImplementedError()
4218-
4219-
4220-
def aten_sinh(self: TensorType) -> TensorType:
4202+
def aten_sinh(self):
42214203
# sinh(Tensor self) -> Tensor
42224204

4223-
raise NotImplementedError()
4205+
return op.Sinh(self)
42244206

42254207

42264208
def aten_slice(
@@ -4483,16 +4465,16 @@ def aten_take_along_dim(
44834465
raise NotImplementedError()
44844466

44854467

4486-
def aten_tan(self: TensorType) -> TensorType:
4468+
def aten_tan(self):
44874469
# tan(Tensor self) -> Tensor
44884470

4489-
raise NotImplementedError()
4471+
return op.Tan(self)
44904472

44914473

4492-
def aten_tanh(self: TensorType) -> TensorType:
4474+
def aten_tanh(self):
44934475
# tanh(Tensor self) -> Tensor
44944476

4495-
raise NotImplementedError()
4477+
return op.Tanh(self)
44964478

44974479

44984480
def aten_tensordot(

onnxscript/function_libs/torch_aten/ops/nn.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,12 @@ def aten_binary_cross_entropy_backward(
150150
raise NotImplementedError()
151151

152152

153+
def aten_celu(self, alpha: float = 1.0):
154+
# celu(Tensor self, Scalar alpha=1.0) -> Tensor
155+
156+
raise NotImplementedError()
157+
158+
153159
def aten_col2im(
154160
self: TensorType,
155161
output_size: INT64,

onnxscript/test/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def skip(
119119
return DecorateMeta(
120120
op_name=op_name,
121121
variant_name=variant_name,
122-
decorator=unittest.skip(f"Don't care: {reason}"),
122+
decorator=unittest.skip(f"Skip: {reason}"),
123123
dtypes=dtypes,
124124
reason=reason,
125125
matcher=matcher,
@@ -166,10 +166,20 @@ def wrapped(fn):
166166
"acosh": core_ops.aten_acosh,
167167
"add": core_ops.aten_add,
168168
"addmm": core_ops.aten_addmm,
169+
"asin": core_ops.aten_asin,
170+
"asinh": core_ops.aten_asinh,
171+
"atan": core_ops.aten_atan,
172+
"atanh": core_ops.aten_atanh,
169173
"bmm": core_ops.aten_bmm,
174+
"ceil": core_ops.aten_ceil,
170175
"clamp_max": core_ops.aten_clamp_max_tensor,
171176
"clamp_min": core_ops.aten_clamp_min_tensor,
172177
"clamp": core_ops.aten_clamp,
178+
"cos": core_ops.aten_cos,
179+
"cosh": core_ops.aten_cosh,
180+
"dot": core_ops.aten_dot,
181+
"exp": core_ops.aten_exp,
182+
"exp2": core_ops.aten_exp2,
173183
"gt": core_ops.aten_gt,
174184
"lt": core_ops.aten_lt,
175185
"matmul": core_ops.aten_matmul,
@@ -183,8 +193,12 @@ def wrapped(fn):
183193
"ones": core_ops.aten_ones,
184194
"repeat": core_ops.aten_repeat,
185195
"round": core_ops.aten_round,
196+
"sin": core_ops.aten_sin,
197+
"sinh": core_ops.aten_sinh,
186198
"sub": core_ops.aten_sub,
187199
"t": core_ops.aten_t,
200+
"tan": core_ops.aten_tan,
201+
"tanh": core_ops.aten_tanh,
188202
# "transpose": core_ops.aten_transpose, # TODO(justinchuby): Enable when onnxscript errors are fixed
189203
}
190204

@@ -206,21 +220,72 @@ def wrapped(fn):
206220
"addmm",
207221
dtypes=[torch.uint8, torch.int8, torch.int16],
208222
reason="MatMul is not defined on int16/int8/uint8 tensors",
223+
# TODO(justinchuby): Use MatMulInteger
209224
),
210225
xfail(
211226
"addmm",
212227
variant_name="decomposed",
213228
dtypes=[torch.uint8, torch.int8, torch.int16],
214229
reason="MatMul is not defined on int16/int8/uint8 tensors",
215230
),
231+
xfail(
232+
"asin",
233+
dtypes=BOOL_TYPES + INT_TYPES,
234+
reason="Asin is not defined on bool or int tensors",
235+
),
236+
xfail(
237+
"asinh",
238+
dtypes=BOOL_TYPES + INT_TYPES,
239+
reason="Asinh is not defined on bool or int tensors",
240+
),
241+
xfail(
242+
"atan",
243+
dtypes=BOOL_TYPES + INT_TYPES,
244+
reason="Atan is not defined on bool or int tensors",
245+
),
246+
xfail(
247+
"atanh",
248+
dtypes=BOOL_TYPES + INT_TYPES,
249+
reason="Atanh is not defined on bool or int tensors",
250+
),
216251
xfail(
217252
"bmm",
218253
dtypes=[torch.uint8, torch.int8, torch.int16],
219254
reason="MatMul is not defined on int16/int8/uint8 tensors",
220255
),
256+
xfail(
257+
"ceil",
258+
dtypes=BOOL_TYPES + INT_TYPES,
259+
reason="Ceil is not defined on bool or int tensors",
260+
),
221261
skip("clamp", reason="Enable when onnxscript errors are fixed"),
222262
xfail("clamp_max", dtypes=BOOL_TYPES, reason="Min is not defined on bool tensors"),
223263
xfail("clamp_min", dtypes=BOOL_TYPES, reason="Max is not defined on bool tensors"),
264+
xfail(
265+
"cos",
266+
dtypes=BOOL_TYPES + INT_TYPES,
267+
reason="Cos is not defined on bool or int tensors",
268+
),
269+
xfail(
270+
"cosh",
271+
dtypes=BOOL_TYPES + INT_TYPES,
272+
reason="Cosh is not defined on bool or int tensors",
273+
),
274+
xfail(
275+
"dot",
276+
dtypes=[torch.uint8, torch.int8, torch.int16],
277+
reason="MatMul is not defined on int16/int8/uint8 tensors",
278+
),
279+
xfail(
280+
"exp",
281+
dtypes=BOOL_TYPES + INT_TYPES,
282+
reason="Exp is not defined on bool or int tensors",
283+
),
284+
xfail(
285+
"exp2",
286+
dtypes=BOOL_TYPES + INT_TYPES,
287+
reason="Pow is not defined on bool or int tensors",
288+
),
224289
xfail("gt", dtypes=BOOL_TYPES, reason="Greater is not defined on bool tensors"),
225290
xfail("lt", dtypes=BOOL_TYPES, reason="Less is not defined on bool tensors"),
226291
xfail(
@@ -264,7 +329,27 @@ def wrapped(fn):
264329
xfail(
265330
"round", variant_name="decimals_neg_3", reason="The ATen op does not support decimals"
266331
),
332+
xfail(
333+
"sin",
334+
dtypes=BOOL_TYPES + INT_TYPES,
335+
reason="Sin is not defined on bool or int tensors",
336+
),
337+
xfail(
338+
"sinh",
339+
dtypes=BOOL_TYPES + INT_TYPES,
340+
reason="Sinh is not defined on bool or int tensors",
341+
),
267342
xfail("sub", dtypes=BOOL_TYPES, reason="Sub is not defined on bool tensors"),
343+
xfail(
344+
"tan",
345+
dtypes=BOOL_TYPES + INT_TYPES,
346+
reason="Tan is not defined on bool or int tensors",
347+
),
348+
xfail(
349+
"tanh",
350+
dtypes=BOOL_TYPES + INT_TYPES,
351+
reason="Tanh is not defined on bool or int tensors",
352+
),
268353
xfail("transpose", reason="Enable when onnxscript errors are fixed"),
269354
)
270355

0 commit comments

Comments (0)