Skip to content

Commit 47fe75f

Browse files
authored
feat(atenlib): ops 7/n (#279)
Adds upsample_nearest2d, amax, amin, and the adaptive average pools. upsample_nearest2d was tested locally on opset 17; opset 18 cannot be tested yet because ONNX Runtime does not support it.
1 parent 677ca1f commit 47fe75f

File tree

3 files changed

+200
-51
lines changed

3 files changed

+200
-51
lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 11 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -48,20 +48,6 @@ def aten_acosh(self: TFloat) -> TFloat:
4848
return op.Acosh(self)
4949

5050

51-
def aten_adaptive_avg_pool1d(self: TensorType, output_size: Sequence[int]) -> TensorType:
52-
# adaptive_avg_pool1d(Tensor self, int[1] output_size) -> Tensor
53-
54-
raise NotImplementedError()
55-
56-
57-
def aten_adaptive_max_pool1d(
58-
self: TensorType, output_size: Sequence[int]
59-
) -> tuple[TensorType, TensorType]:
60-
# adaptive_max_pool1d(Tensor self, int[1] output_size) -> (Tensor, Tensor)
61-
62-
raise NotImplementedError()
63-
64-
6551
@torch_op("aten::add")
6652
def aten_add(self: TReal, other: TReal, alpha: float = 1) -> TReal:
6753
# add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
@@ -198,20 +184,20 @@ def aten_alpha_dropout(input: TensorType, p: float, train: bool) -> TensorType:
198184
raise NotImplementedError()
199185

200186

201-
def aten_amax(
202-
self: TensorType, dim: Optional[Sequence[int]] = None, keepdim: bool = False
203-
) -> TensorType:
187+
# @torch_op("aten::amax")  # FIXME: Uncomment when CI uses onnx 1.13
def aten_amax(self: TReal, dim: INT64, keepdim: int = 0) -> TReal:
    # amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor

    # TODO(justinchuby): Make dim optional, keepdim bool
    # Reduce-max over the requested dims; keepdims mirrors ATen's keepdim flag.
    result = op.ReduceMax(self, dim, keepdims=keepdim)
    return result
207193

208194

209-
def aten_amin(
210-
self: TensorType, dim: Optional[Sequence[int]] = None, keepdim: bool = False
211-
) -> TensorType:
195+
# @torch_op("aten::amin")  # FIXME: Uncomment when CI uses onnx 1.13
def aten_amin(self: TReal, dim: INT64, keepdim: int = 0) -> TReal:
    # amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor

    # TODO(justinchuby): Make dim optional, keepdim bool
    # Reduce-min over the requested dims; keepdims mirrors ATen's keepdim flag.
    result = op.ReduceMin(self, dim, keepdims=keepdim)
    return result
215201

216202

217203
def aten_aminmax(
@@ -4181,10 +4167,11 @@ def aten_rsqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
41814167
return op.Reciprocal(op.Sqrt(self))
41824168

41834169

4184-
def aten_rsub(self: TensorType, other: TensorType, alpha: float = 1) -> TensorType:
4170+
@torch_op("aten::rsub")
def aten_rsub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
    # rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor

    # rsub is subtraction with the operands reversed: other - alpha * self.
    scaled_self = op.Mul(self, alpha)
    return op.Sub(other, scaled_self)
41884175

41894176

41904177
def aten_scalar_tensor(s: float) -> TensorType:

onnxscript/function_libs/torch_aten/ops/nn.py

Lines changed: 71 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,65 @@
2323
from onnxscript.onnx_types import TensorType
2424

2525

26-
def aten_adaptive_avg_pool2d(self: TensorType, output_size: INT64) -> TensorType:
26+
# Fixed op name: was "aten::aten_adaptive_avg_pool1d", which duplicated the
# "aten_" prefix; the ATen op is adaptive_avg_pool1d (cf. "aten::add", "aten::rsub").
@torch_op("aten::adaptive_avg_pool1d")
def aten_adaptive_avg_pool1d(self: TFloat, output_size: INT64[1]) -> TFloat:
    # adaptive_avg_pool1d(Tensor self, int[1] output_size) -> Tensor

    # NOTE(review): only the output_size == [1] case is correct here, since
    # GlobalAveragePool always reduces the spatial dim to 1.
    # assert output_size == [1]
    # TODO(justinchuby): Specify input constraints

    if op.Size(op.Shape(self)) == 2:
        # Unbatched (C, L) input: add a batch dim so GlobalAveragePool sees
        # NCL, then strip it off again.
        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
        pooled = op.GlobalAveragePool(self)
        result = op.Squeeze(pooled, op.Constant(value_ints=[0]))
    else:
        result = op.GlobalAveragePool(self)

    return result
42+
43+
44+
# Fixed op name: was "aten::aten_adaptive_avg_pool2d", which duplicated the
# "aten_" prefix; the ATen op is adaptive_avg_pool2d (cf. "aten::add", "aten::rsub").
@torch_op("aten::adaptive_avg_pool2d")
def aten_adaptive_avg_pool2d(self: TFloat, output_size: INT64[2]) -> TFloat:
    # adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor

    # NOTE(review): only the output_size == [1, 1] case is correct here, since
    # GlobalAveragePool always reduces the spatial dims to 1.
    # assert output_size == [1, 1]
    # TODO(justinchuby): Specify input constraints

    if op.Size(op.Shape(self)) == 3:
        # Unbatched (C, H, W) input: add a batch dim so GlobalAveragePool sees
        # NCHW, then strip it off again.
        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
        pooled = op.GlobalAveragePool(self)
        result = op.Squeeze(pooled, op.Constant(value_ints=[0]))
    else:
        result = op.GlobalAveragePool(self)

    return result
3060

3161

32-
def aten_adaptive_avg_pool3d(self: TensorType, output_size: INT64) -> TensorType:
62+
# Fixed op name: was "aten::aten_adaptive_avg_pool3d", which duplicated the
# "aten_" prefix; the ATen op is adaptive_avg_pool3d (cf. "aten::add", "aten::rsub").
@torch_op("aten::adaptive_avg_pool3d")
def aten_adaptive_avg_pool3d(self: TFloat, output_size: INT64[3]) -> TFloat:
    # adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor

    # NOTE(review): only the output_size == [1, 1, 1] case is correct here,
    # since GlobalAveragePool always reduces the spatial dims to 1.
    # assert output_size == [1, 1, 1]
    # TODO(justinchuby): Specify input constraints

    if op.Size(op.Shape(self)) == 4:
        # Unbatched (C, D, H, W) input: add a batch dim so GlobalAveragePool
        # sees NCDHW, then strip it off again.
        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
        pooled = op.GlobalAveragePool(self)
        result = op.Squeeze(pooled, op.Constant(value_ints=[0]))
    else:
        result = op.GlobalAveragePool(self)

    return result
78+
79+
80+
def aten_adaptive_max_pool1d(
    self: TensorType, output_size: Sequence[int]
) -> tuple[TensorType, TensorType]:
    # adaptive_max_pool1d(Tensor self, int[1] output_size) -> (Tensor, Tensor)

    # Stub: no ONNX lowering has been written for this op yet.
    raise NotImplementedError()
3686

3787

@@ -1162,15 +1212,29 @@ def aten_upsample_nearest1d_backward(
11621212
raise NotImplementedError()
11631213

11641214

1215+
@torch_op("aten::upsample_nearest2d")
def aten_upsample_nearest2d(
    self: TReal,
    size: INT64,
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
) -> TReal:
    # upsample_nearest2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor

    # Resize needs the full output shape: keep the two leading dims
    # (presumably batch and channel — confirm against callers) from the
    # input and append the requested spatial size.
    input_shape = op.Shape(self)
    leading_dims = input_shape[:2]  # type: ignore[index]
    full_output_size = op.Concat(leading_dims, size, axis=0)

    # TODO(justinchuby): Conditionally use scales

    # The second and third Resize inputs (roi, scales) are unused because an
    # explicit sizes input is supplied.
    resized = op.Resize(
        self,
        None,
        None,
        full_output_size,
        mode="nearest",
        coordinate_transformation_mode="asymmetric",
    )
    return resized
11741238

11751239

11761240
def aten_upsample_nearest2d_backward(

0 commit comments

Comments
 (0)