|
13 | 13 |
|
14 | 14 | from typing import Any, Optional, Sequence, Union |
15 | 15 |
|
16 | | -from onnxscript import BOOL, DOUBLE, FLOAT, INT64 |
| 16 | +from onnxscript import BOOL, DOUBLE, FLOAT, INT16, INT32, INT64 |
17 | 17 | from onnxscript.function_libs.torch_aten.registration import torch_op |
18 | 18 | from onnxscript.function_libs.torch_aten.typing import ( |
19 | 19 | TFloat, |
20 | 20 | TFloatOrBFloat16, |
21 | 21 | TInt, |
22 | 22 | TReal, |
| 23 | + TRealUnlessFloat16OrInt8, |
23 | 24 | TRealUnlessInt16OrInt8, |
24 | 25 | TTensor, |
25 | 26 | ) |
@@ -226,10 +227,65 @@ def aten_any(self: TensorType) -> TensorType: |
226 | 227 | raise NotImplementedError() |
227 | 228 |
|
228 | 229 |
|
229 | | -def aten_arange(end: float) -> TensorType: |
| 230 | +@torch_op("aten::arange") |
| 231 | +def aten_arange(end: Union[DOUBLE, FLOAT, INT16, INT32, INT64], dtype: int = -1) -> TensorType: |
230 | 232 | # arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor |
231 | 233 |
|
232 | | - raise NotImplementedError() |
| 234 | + # Cast input to double if dtype is specified, because the input dtype may be e.g. bool |
| 235 | + # which Range does not support. The output type is ensured because the output |
| 236 | + # is casted to the specified dtype. |
| 237 | + if dtype != -1: |
| 238 | + end = op.Cast(end, to=DOUBLE.dtype) |
| 239 | + |
| 240 | + result = op.Range(0, end, 1) |
| 241 | + if dtype != -1: |
| 242 | + result = op.Cast(result, to=dtype) |
| 243 | + |
| 244 | + return result |
| 245 | + |
| 246 | + |
| 247 | +@torch_op("aten::arange", overload=True) |
| 248 | +def aten_arange_start( |
| 249 | + start: TRealUnlessFloat16OrInt8, end: TRealUnlessFloat16OrInt8, dtype: int = -1 |
| 250 | +) -> TensorType: |
| 251 | + # arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor |
| 252 | + |
| 253 | + # Cast input to double if dtype is specified, because the input dtype may be e.g. bool |
| 254 | + # which Range does not support. The output type is ensured because the output |
| 255 | + # is casted to the specified dtype. |
| 256 | + if dtype != -1: |
| 257 | + start = op.Cast(start, to=DOUBLE.dtype) |
| 258 | + end = op.Cast(end, to=DOUBLE.dtype) |
| 259 | + |
| 260 | + result = op.Range(start, end, 1) |
| 261 | + if dtype != -1: |
| 262 | + result = op.Cast(result, to=dtype) |
| 263 | + |
| 264 | + return result |
| 265 | + |
| 266 | + |
| 267 | +@torch_op("aten::arange", overload=True) |
| 268 | +def aten_arange_start_step( |
| 269 | + start: TRealUnlessFloat16OrInt8, |
| 270 | + end: TRealUnlessFloat16OrInt8, |
| 271 | + step: TRealUnlessFloat16OrInt8, |
| 272 | + dtype: int = -1, |
| 273 | +) -> TensorType: |
| 274 | + # arange.start_step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor |
| 275 | + |
| 276 | + # Cast input to double if dtype is specified, because the input dtype may be e.g. bool |
| 277 | + # which Range does not support. The output type is ensured because the output |
| 278 | + # is casted to the specified dtype. |
| 279 | + if dtype != -1: |
| 280 | + start = op.Cast(start, to=DOUBLE.dtype) |
| 281 | + end = op.Cast(end, to=DOUBLE.dtype) |
| 282 | + step = op.Cast(step, to=DOUBLE.dtype) |
| 283 | + |
| 284 | + result = op.Range(start, end, step) |
| 285 | + if dtype != -1: |
| 286 | + result = op.Cast(result, to=dtype) |
| 287 | + |
| 288 | + return result |
233 | 289 |
|
234 | 290 |
|
235 | 291 | def aten_arccos(self: TensorType) -> TensorType: |
|
0 commit comments