|
23 | 23 | from onnxscript.onnx_types import TensorType |
24 | 24 |
|
25 | 25 |
|
@torch_op("aten::adaptive_avg_pool1d")
def aten_adaptive_avg_pool1d(self: TFloat, output_size: INT64[1]) -> TFloat:
    # adaptive_avg_pool1d(Tensor self, int[1] output_size) -> Tensor

    # NOTE: The op must be registered under the real ATen name
    # "aten::adaptive_avg_pool1d" (not "aten::aten_adaptive_avg_pool1d"),
    # otherwise the exporter can never dispatch to this implementation.

    # This lowering only matches output_size == [1], where adaptive average
    # pooling degenerates to a global average pool.
    # assert output_size == [1]
    # TODO(justinchuby): Specify input constraints

    if op.Size(op.Shape(self)) == 2:
        # Unbatched (C, L) input: GlobalAveragePool requires a batch dim,
        # so insert one before pooling and remove it afterwards.
        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
        pooled = op.GlobalAveragePool(self)
        result = op.Squeeze(pooled, op.Constant(value_ints=[0]))
    else:
        result = op.GlobalAveragePool(self)

    return result
| 43 | + |
@torch_op("aten::adaptive_avg_pool2d")
def aten_adaptive_avg_pool2d(self: TFloat, output_size: INT64[2]) -> TFloat:
    # adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor

    # NOTE: The op must be registered under the real ATen name
    # "aten::adaptive_avg_pool2d" (not "aten::aten_adaptive_avg_pool2d"),
    # otherwise the exporter can never dispatch to this implementation.

    # This lowering only matches output_size == [1, 1], where adaptive average
    # pooling degenerates to a global average pool.
    # assert output_size == [1, 1]
    # TODO(justinchuby): Specify input constraints

    if op.Size(op.Shape(self)) == 3:
        # Unbatched (C, H, W) input: GlobalAveragePool requires a batch dim,
        # so insert one before pooling and remove it afterwards.
        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
        pooled = op.GlobalAveragePool(self)
        result = op.Squeeze(pooled, op.Constant(value_ints=[0]))
    else:
        result = op.GlobalAveragePool(self)

    return result
30 | 60 |
|
31 | 61 |
|
@torch_op("aten::adaptive_avg_pool3d")
def aten_adaptive_avg_pool3d(self: TFloat, output_size: INT64[3]) -> TFloat:
    # adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor

    # NOTE: The op must be registered under the real ATen name
    # "aten::adaptive_avg_pool3d" (not "aten::aten_adaptive_avg_pool3d"),
    # otherwise the exporter can never dispatch to this implementation.

    # This lowering only matches output_size == [1, 1, 1], where adaptive
    # average pooling degenerates to a global average pool.
    # assert output_size == [1, 1, 1]
    # TODO(justinchuby): Specify input constraints

    if op.Size(op.Shape(self)) == 4:
        # Unbatched (C, D, H, W) input: GlobalAveragePool requires a batch
        # dim, so insert one before pooling and remove it afterwards.
        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
        pooled = op.GlobalAveragePool(self)
        result = op.Squeeze(pooled, op.Constant(value_ints=[0]))
    else:
        result = op.GlobalAveragePool(self)

    return result
| 79 | + |
def aten_adaptive_max_pool1d(
    self: TensorType, output_size: Sequence[int]
) -> tuple[TensorType, TensorType]:
    """adaptive_max_pool1d(Tensor self, int[1] output_size) -> (Tensor, Tensor)

    ONNX lowering not yet implemented.
    """

    raise NotImplementedError()
36 | 86 |
|
37 | 87 |
|
@@ -1162,15 +1212,29 @@ def aten_upsample_nearest1d_backward( |
1162 | 1212 | raise NotImplementedError() |
1163 | 1213 |
|
1164 | 1214 |
|
@torch_op("aten::upsample_nearest2d")
def aten_upsample_nearest2d(
    self: TReal,
    size: INT64,
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
) -> TReal:
    # upsample_nearest2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor

    # Resize needs the full 4-D target shape, so prepend the input's
    # (batch, channel) dims to the requested spatial size.
    input_shape = op.Shape(self)
    batch_and_channel = input_shape[:2]  # type: ignore[index]
    full_output_size = op.Concat(batch_and_channel, size, axis=0)

    # TODO(justinchuby): Conditionally use scales

    resized = op.Resize(
        self,
        None,
        None,
        full_output_size,
        mode="nearest",
        coordinate_transformation_mode="asymmetric",
    )
    return resized
1174 | 1238 |
|
1175 | 1239 |
|
1176 | 1240 | def aten_upsample_nearest2d_backward( |
|
0 commit comments