|
17 | 17 |
|
18 | 18 | from typing import Any, Optional, Sequence
|
19 | 19 |
|
20 |
| -from onnxscript import INT64 |
| 20 | +from onnxscript import BOOL, INT64 |
21 | 21 | from onnxscript.onnx_opset import default_opset as op
|
22 | 22 | from onnxscript.onnx_types import TensorType
|
23 | 23 |
|
@@ -60,10 +60,11 @@ def aten_adaptive_max_pool1d(
|
60 | 60 | raise NotImplementedError()
|
61 | 61 |
|
62 | 62 |
|
63 |
| -def aten_add(self: TensorType, other: TensorType, alpha: float = 1) -> TensorType: |
| 63 | +def aten_add(self, other, alpha: float = 1) -> TensorType: |
64 | 64 | # add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
|
65 |
| - |
66 |
| - raise NotImplementedError() |
| 65 | + if alpha != 1: |
| 66 | + other = op.Mul(other, alpha) # type: ignore[arg-type] |
| 67 | + return op.Add(self, other) |
67 | 68 |
|
68 | 69 |
|
69 | 70 | def aten_addbmm(
|
@@ -3109,10 +3110,19 @@ def aten_msort(self: TensorType) -> TensorType:
|
3109 | 3110 | raise NotImplementedError()
|
3110 | 3111 |
|
3111 | 3112 |
|
3112 |
| -def aten_mul(self: TensorType, other: TensorType) -> TensorType: |
| 3113 | +def aten_mul(self, other) -> TensorType: |
3113 | 3114 | # mul.Tensor(Tensor self, Tensor other) -> Tensor
|
3114 | 3115 |
|
3115 |
| - raise NotImplementedError() |
| 3116 | + return op.Mul(self, other) |
| 3117 | + |
| 3118 | + |
| 3119 | +def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL: |
| 3120 | + """ONNX Mul doesn't support Boolean, so use And as an equivalent operator.""" |
| 3121 | + |
| 3122 | + # TODO(justinchuby): Handle cases where type reconcilation is not enough, |
| 3123 | + # since different ONNX operators are used based on different data types. |
| 3124 | + |
| 3125 | + return op.And(self, other) |
3116 | 3126 |
|
3117 | 3127 |
|
3118 | 3128 | def aten_multinomial(
|
@@ -4339,10 +4349,13 @@ def aten_stft(
|
4339 | 4349 | raise NotImplementedError()
|
4340 | 4350 |
|
4341 | 4351 |
|
4342 |
| -def aten_sub(self: TensorType, other: TensorType, alpha: float = 1) -> TensorType: |
| 4352 | +def aten_sub(self, other, alpha: float = 1) -> TensorType: |
4343 | 4353 | # sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
|
4344 | 4354 |
|
4345 |
| - raise NotImplementedError() |
| 4355 | + if alpha != 1: |
| 4356 | + other = op.Mul(other, alpha) # type: ignore[arg-type] |
| 4357 | + |
| 4358 | + return op.Sub(self, other) |
4346 | 4359 |
|
4347 | 4360 |
|
4348 | 4361 | def aten_subtract(self: TensorType, other: TensorType, alpha: float = 1) -> TensorType:
|
|
0 commit comments