Skip to content

Commit 9c75044

Browse files
authored
feat(atenlib): arange with overloads (#285)
arange ops
1 parent b635d24 commit 9c75044

File tree

3 files changed

+96
-5
lines changed

3 files changed

+96
-5
lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,14 @@
1313

1414
from typing import Any, Optional, Sequence, Union
1515

16-
from onnxscript import BOOL, DOUBLE, FLOAT, INT64
16+
from onnxscript import BOOL, DOUBLE, FLOAT, INT16, INT32, INT64
1717
from onnxscript.function_libs.torch_aten.registration import torch_op
1818
from onnxscript.function_libs.torch_aten.typing import (
1919
TFloat,
2020
TFloatOrBFloat16,
2121
TInt,
2222
TReal,
23+
TRealUnlessFloat16OrInt8,
2324
TRealUnlessInt16OrInt8,
2425
TTensor,
2526
)
@@ -226,10 +227,65 @@ def aten_any(self: TensorType) -> TensorType:
226227
raise NotImplementedError()
227228

228229

229-
def aten_arange(end: float) -> TensorType:
230+
@torch_op("aten::arange")
231+
def aten_arange(end: Union[DOUBLE, FLOAT, INT16, INT32, INT64], dtype: int = -1) -> TensorType:
230232
# arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
231233

232-
raise NotImplementedError()
234+
# Cast input to double if dtype is specified, because the input dtype may be e.g. bool
235+
# which Range does not support. The output type is ensured because the output
236+
# is casted to the specified dtype.
237+
if dtype != -1:
238+
end = op.Cast(end, to=DOUBLE.dtype)
239+
240+
result = op.Range(0, end, 1)
241+
if dtype != -1:
242+
result = op.Cast(result, to=dtype)
243+
244+
return result
245+
246+
247+
@torch_op("aten::arange", overload=True)
248+
def aten_arange_start(
249+
start: TRealUnlessFloat16OrInt8, end: TRealUnlessFloat16OrInt8, dtype: int = -1
250+
) -> TensorType:
251+
# arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
252+
253+
# Cast input to double if dtype is specified, because the input dtype may be e.g. bool
254+
# which Range does not support. The output type is ensured because the output
255+
# is casted to the specified dtype.
256+
if dtype != -1:
257+
start = op.Cast(start, to=DOUBLE.dtype)
258+
end = op.Cast(end, to=DOUBLE.dtype)
259+
260+
result = op.Range(start, end, 1)
261+
if dtype != -1:
262+
result = op.Cast(result, to=dtype)
263+
264+
return result
265+
266+
267+
@torch_op("aten::arange", overload=True)
268+
def aten_arange_start_step(
269+
start: TRealUnlessFloat16OrInt8,
270+
end: TRealUnlessFloat16OrInt8,
271+
step: TRealUnlessFloat16OrInt8,
272+
dtype: int = -1,
273+
) -> TensorType:
274+
# arange.start_step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
275+
276+
# Cast input to double if dtype is specified, because the input dtype may be e.g. bool
277+
# which Range does not support. The output type is ensured because the output
278+
# is casted to the specified dtype.
279+
if dtype != -1:
280+
start = op.Cast(start, to=DOUBLE.dtype)
281+
end = op.Cast(end, to=DOUBLE.dtype)
282+
step = op.Cast(step, to=DOUBLE.dtype)
283+
284+
result = op.Range(start, end, step)
285+
if dtype != -1:
286+
result = op.Cast(result, to=dtype)
287+
288+
return result
233289

234290

235291
def aten_arccos(self: TensorType) -> TensorType:

onnxscript/function_libs/torch_aten/typing.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
]
4343
_FloatType = Union[FLOAT16, FLOAT, DOUBLE]
4444
_IntType = Union[INT8, INT16, INT32, INT64]
45-
_RealType = Union[
45+
RealType = Union[
4646
BFLOAT16,
4747
FLOAT16,
4848
FLOAT,
@@ -57,7 +57,10 @@
5757
TFloat = TypeVar("TFloat", bound=_FloatType)
5858
TFloatOrBFloat16 = TypeVar("TFloatOrBFloat16", bound=Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16])
5959
TInt = TypeVar("TInt", bound=_IntType)
60-
TReal = TypeVar("TReal", bound=_RealType)
60+
TReal = TypeVar("TReal", bound=RealType)
6161
TRealUnlessInt16OrInt8 = TypeVar(
6262
"TRealUnlessInt16OrInt8", bound=Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16, INT32, INT64]
6363
)
64+
TRealUnlessFloat16OrInt8 = TypeVar(
65+
"TRealUnlessFloat16OrInt8", bound=Union[DOUBLE, FLOAT, INT16, INT32, INT64]
66+
)

onnxscript/test/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,9 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
195195
"addmm": core_ops.aten_addmm,
196196
"amax": (core_ops.aten_amax, _amax_amin_kwargs_wrangler),
197197
"amin": (core_ops.aten_amin, _amax_amin_kwargs_wrangler),
198+
"arange_start_step": core_ops.aten_arange_start_step,
199+
"arange_start": core_ops.aten_arange_start,
200+
"arange": core_ops.aten_arange,
198201
"asin": core_ops.aten_asin,
199202
"asinh": core_ops.aten_asinh,
200203
"atan": core_ops.aten_atan,
@@ -289,6 +292,26 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
289292

290293

291294
SKIP_SUBTESTS: tuple[DecorateMeta, ...] = (
295+
skip(
296+
"arange",
297+
matcher=lambda sample: len(sample.args) != 0,
298+
reason="arange overload takes single argument",
299+
),
300+
skip(
301+
"arange",
302+
matcher=lambda sample: sample.kwargs.get("end") is not None,
303+
reason="arange overload does not support positional 'end' argument",
304+
),
305+
skip(
306+
"arange_start",
307+
matcher=lambda sample: len(sample.args) != 1,
308+
reason="arange_start overload takes two arguments (input, start)",
309+
),
310+
skip(
311+
"arange_start_step",
312+
matcher=lambda sample: len(sample.args) != 2,
313+
reason="arange_start_step overload takes three arguments (input, start, step)",
314+
),
292315
skip(
293316
"div",
294317
matcher=lambda sample: sample.kwargs.get("rounding_mode") is not None,
@@ -343,6 +366,15 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
343366
),
344367
)
345368

369+
duplicate_opinfo(
370+
OPS_DB,
371+
"arange",
372+
(
373+
"arange_start",
374+
"arange_start_step",
375+
),
376+
)
377+
346378

347379
# END OF SECTION TO MODIFY #####################################################
348380

0 commit comments

Comments
 (0)