
Commit 5a3713e

Merge branch 'main' into xiaowu/trySome2n

2 parents 72d6381 + dd3d747

4 files changed: +258 −66 lines

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 40 additions & 27 deletions
@@ -48,20 +48,6 @@ def aten_acosh(self: TFloat) -> TFloat:
     return op.Acosh(self)
 
 
-def aten_adaptive_avg_pool1d(self: TensorType, output_size: Sequence[int]) -> TensorType:
-    # adaptive_avg_pool1d(Tensor self, int[1] output_size) -> Tensor
-
-    raise NotImplementedError()
-
-
-def aten_adaptive_max_pool1d(
-    self: TensorType, output_size: Sequence[int]
-) -> tuple[TensorType, TensorType]:
-    # adaptive_max_pool1d(Tensor self, int[1] output_size) -> (Tensor, Tensor)
-
-    raise NotImplementedError()
-
-
 @torch_op("aten::add")
 def aten_add(self: TReal, other: TReal, alpha: float = 1) -> TReal:
     # add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
@@ -198,20 +184,20 @@ def aten_alpha_dropout(input: TensorType, p: float, train: bool) -> TensorType:
     raise NotImplementedError()
 
 
-def aten_amax(
-    self: TensorType, dim: Optional[Sequence[int]] = None, keepdim: bool = False
-) -> TensorType:
+# @torch_op("aten::amax")  # FIXME: Uncomment when CI uses onnx 1.13
+def aten_amax(self: TReal, dim: INT64, keepdim: int = 0) -> TReal:
     # amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
 
-    raise NotImplementedError()
+    # TODO(justinchuby): Make dim optional, keepdim bool
+    return op.ReduceMax(self, dim, keepdims=keepdim)
 
 
-def aten_amin(
-    self: TensorType, dim: Optional[Sequence[int]] = None, keepdim: bool = False
-) -> TensorType:
+# @torch_op("aten::amin")  # FIXME: Uncomment when CI uses onnx 1.13
+def aten_amin(self: TReal, dim: INT64, keepdim: int = 0) -> TReal:
     # amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
 
-    raise NotImplementedError()
+    # TODO(justinchuby): Make dim optional, keepdim bool
+    return op.ReduceMin(self, dim, keepdims=keepdim)
 
 
 def aten_aminmax(
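
For reference, the new aten_amax/aten_amin bodies delegate to ONNX ReduceMax/ReduceMin with the reduction axes supplied as a tensor input, which is why the decorators stay commented out until CI runs onnx 1.13 (opset 18). A rough NumPy sketch of the intended reduction semantics, with made-up example values:

import numpy as np

# Hypothetical input; mirrors ReduceMax/ReduceMin(self, dim, keepdims=0).
x = np.arange(12, dtype=np.float32).reshape(3, 4)
dim = np.array([1], dtype=np.int64)  # reduction axes passed as a tensor

amax = np.max(x, axis=tuple(dim), keepdims=False)
amin = np.min(x, axis=tuple(dim), keepdims=False)

print(amax)  # [ 3.  7. 11.]
print(amin)  # [0. 4. 8.]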
@@ -2190,10 +2176,21 @@ def aten_index_reduce(
     raise NotImplementedError()
 
 
-def aten_index_select(self: TensorType, dim: int, index: TensorType) -> TensorType:
+# FIXME(#277): Script when attributes can come before inputs
+@torch_op("aten::index_select", trace_only=True)
+def aten_index_select(self: TTensor, dim: int, index: TInt) -> TTensor:
     # index_select(Tensor self, int dim, Tensor index) -> Tensor
 
-    raise NotImplementedError()
+    if op.Size(op.Shape(self)) == 0:
+        result = self
+    else:
+        # Index can be a scalar. Reshape it to a rank 1 tensor.
+        index = op.Reshape(index, op.Constant(value_floats=[-1]))
+        index = op.Cast(index, to=INT64.dtype)
+
+        result = op.Gather(self, index, axis=dim)
+
+    return result
 
 
 def aten_index_select_backward(
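
The traced body is essentially ONNX Gather along the requested axis, with the index reshaped to rank 1 so scalar indices still work. A small PyTorch/NumPy sketch of that equivalence (example values are illustrative, not from the commit):

import numpy as np
import torch

x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
index = torch.tensor([2, 0])

expected = torch.index_select(x, dim=1, index=index)

# Gather(self, index, axis=dim) expressed in NumPy terms.
gathered = np.take(x.numpy(), index.numpy().reshape(-1).astype(np.int64), axis=1)

assert np.array_equal(expected.numpy(), gathered)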
@@ -4194,10 +4191,11 @@ def aten_rsqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
     return op.Reciprocal(op.Sqrt(self))
 
 
-def aten_rsub(self: TensorType, other: TensorType, alpha: float = 1) -> TensorType:
+@torch_op("aten::rsub")
+def aten_rsub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
     # rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
 
-    raise NotImplementedError()
+    return op.Sub(other, op.Mul(self, alpha))
 
 
 def aten_scalar_tensor(s: float) -> TensorType:
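
rsub is subtraction with the operands reversed: other - self * alpha, which is exactly what Sub(other, Mul(self, alpha)) encodes. A quick sanity check against PyTorch's out-of-place aten::rsub (illustrative values):

import torch

self_t = torch.tensor([1.0, 2.0, 3.0])
other = torch.tensor([10.0, 10.0, 10.0])
alpha = 2.0

# Sub(other, Mul(self, alpha)) should match torch.rsub(self, other, alpha=alpha).
assert torch.equal(torch.rsub(self_t, other, alpha=alpha), other - self_t * alpha)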
@@ -4698,11 +4696,26 @@ def aten_trace_backward(grad: TensorType, sizes: INT64) -> TensorType:
     raise NotImplementedError()
 
 
+@torch_op("aten::transpose", trace_only=True)
 def aten_transpose(self, dim0: int, dim1: int):
     # transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)
 
     # FIXME(justinchuby): onnxscript raises Unsupported expression type
-    return op.Transpose(self, [dim0, dim1])
+    # Script the function when this is fixed
+    self_rank = op.Size(op.Shape(self))
+
+    if self_rank == 0:
+        result = self
+    else:
+        # Python code, change when onnxscript supports this
+        self_rank_val = self_rank.value  # type: ignore[attr-defined]
+        dims = list(range(self_rank_val))
+        dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
+        # Python code ends
+
+        result = op.Transpose(self, perm=dims)
+
+    return result
 
 
 def aten_triangular_solve(
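
ONNX Transpose takes a full permutation rather than a pair of dims, so the trace-only version builds the permutation list in Python at trace time (which is also why the function cannot be scripted yet). The permutation step in isolation, with illustrative shapes:

import numpy as np

x = np.zeros((2, 3, 4))
dim0, dim1 = 0, 2

# Start from the identity permutation and swap the two requested dims,
# mirroring what the traced function does before calling op.Transpose.
dims = list(range(x.ndim))
dims[dim0], dims[dim1] = dims[dim1], dims[dim0]

assert np.transpose(x, axes=dims).shape == (4, 3, 2)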

onnxscript/function_libs/torch_aten/ops/nn.py

Lines changed: 71 additions & 7 deletions
@@ -23,15 +23,65 @@
 from onnxscript.onnx_types import TensorType
 
 
-def aten_adaptive_avg_pool2d(self: TensorType, output_size: INT64) -> TensorType:
+@torch_op("aten::aten_adaptive_avg_pool1d")
+def aten_adaptive_avg_pool1d(self: TFloat, output_size: INT64[1]) -> TFloat:
+    # adaptive_avg_pool1d(Tensor self, int[1] output_size) -> Tensor
+
+    # assert output_size == [1]
+    # TODO(justinchuby): Specify input constraints
+
+    if op.Size(op.Shape(self)) == 2:
+        # Unbatched case
+        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
+        pooled = op.GlobalAveragePool(self)
+        result = op.Squeeze(pooled, op.Constant(value_ints=[0]))
+    else:
+        result = op.GlobalAveragePool(self)
+
+    return result
+
+
+@torch_op("aten::aten_adaptive_avg_pool2d")
+def aten_adaptive_avg_pool2d(self: TFloat, output_size: INT64[2]) -> TFloat:
     # adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor
 
-    raise NotImplementedError()
+    # assert output_size == [1, 1]
+    # TODO(justinchuby): Specify input constraints
+
+    if op.Size(op.Shape(self)) == 3:
+        # Unbatched case
+        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
+        pooled = op.GlobalAveragePool(self)
+        result = op.Squeeze(pooled, op.Constant(value_ints=[0]))
+    else:
+        result = op.GlobalAveragePool(self)
+
+    return result
 
 
-def aten_adaptive_avg_pool3d(self: TensorType, output_size: INT64) -> TensorType:
+@torch_op("aten::aten_adaptive_avg_pool3d")
+def aten_adaptive_avg_pool3d(self: TFloat, output_size: INT64[3]) -> TFloat:
     # adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor
 
+    # assert output_size == [1, 1, 1]
+    # TODO(justinchuby): Specify input constraints
+
+    if op.Size(op.Shape(self)) == 4:
+        # Unbatched case
+        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
+        pooled = op.GlobalAveragePool(self)
+        result = op.Squeeze(pooled, op.Constant(value_ints=[0]))
+    else:
+        result = op.GlobalAveragePool(self)
+
+    return result
+
+
+def aten_adaptive_max_pool1d(
+    self: TensorType, output_size: Sequence[int]
+) -> tuple[TensorType, TensorType]:
+    # adaptive_max_pool1d(Tensor self, int[1] output_size) -> (Tensor, Tensor)
+
     raise NotImplementedError()
 
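
These implementations only cover the output_size == 1 case (see the assert comments), where adaptive average pooling degenerates to a global average pool; the unbatched branch just adds and removes a leading batch dim around GlobalAveragePool. A PyTorch sketch of the identity being relied on (illustrative shapes):

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8)  # NCHW

# With output_size (1, 1), adaptive average pooling is a global average pool.
adaptive = F.adaptive_avg_pool2d(x, output_size=(1, 1))
global_mean = x.mean(dim=(2, 3), keepdim=True)

assert torch.allclose(adaptive, global_mean)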

@@ -1162,15 +1212,29 @@ def aten_upsample_nearest1d_backward(
     raise NotImplementedError()
 
 
+@torch_op("aten::upsample_nearest2d")
 def aten_upsample_nearest2d(
-    self: TensorType,
-    output_size: INT64,
+    self: TReal,
+    size: INT64,
     scales_h: Optional[float] = None,
     scales_w: Optional[float] = None,
-) -> TensorType:
+) -> TReal:
     # upsample_nearest2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor
 
-    raise NotImplementedError()
+    self_shape = op.Shape(self)
+    batch_channel = self_shape[:2]  # type: ignore[index]
+    output_size = op.Concat(batch_channel, size, axis=0)
+
+    # TODO(justinchuby): Conditionally use scales
+
+    return op.Resize(
+        self,
+        None,
+        None,
+        output_size,
+        mode="nearest",
+        coordinate_transformation_mode="asymmetric",
+    )
 
 
 def aten_upsample_nearest2d_backward(
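
ONNX Resize expects a sizes tensor covering every axis, while the ATen op only passes the spatial output size, so the new body prepends the batch and channel dims taken from Shape(self). The shape bookkeeping on its own, with made-up values:

import numpy as np

self_shape = np.array([2, 3, 16, 16], dtype=np.int64)  # N, C, H, W of the input
size = np.array([32, 32], dtype=np.int64)              # requested spatial output size

batch_channel = self_shape[:2]
output_size = np.concatenate([batch_channel, size], axis=0)

assert output_size.tolist() == [2, 3, 32, 32]  # full sizes tensor handed to Resize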

onnxscript/function_libs/torch_aten/registration.py

Lines changed: 22 additions & 8 deletions
@@ -48,19 +48,33 @@ def __repr__(self):
 
 
 def torch_op(
-    name, overload: bool = False, registry: Optional[Registry] = None
-) -> Callable[[Callable[..., Any]], onnxscript.OnnxFunction]:
-    """Register a torch op."""
+    name,
+    *,
+    overload: bool = False,
+    registry: Optional[Registry] = None,
+    trace_only: bool = False,
+) -> Callable[[Callable[..., Any]], onnxscript.OnnxFunction | Callable[..., Any]]:
+    """Register a torch op.
+
+    Args:
+        name: ATen name of the function. E.g. "aten::add".
+        overload: Whether the function is an overload (not default).
+        registry: Registry to register the function to. If None, the default registry is used.
+        trace_only: Whether the function should only be traced and not compiled.
+    """
     if registry is None:
         registry = default_registry
 
-    def wrapper(func: Callable[..., Any]) -> onnxscript.OnnxFunction:
+    def wrapper(func: Callable[..., Any]) -> onnxscript.OnnxFunction | Callable[..., Any]:
 
-        # Compile the function
-        compiled = onnxscript.script()(func)
+        if trace_only:
+            processed_func = func
+        else:
+            # Compile the function
+            processed_func = onnxscript.script()(func)
 
         assert registry is not None
-        registry.register(compiled, name, overload=overload)
-        return compiled
+        registry.register(processed_func, name, overload=overload)
+        return processed_func
 
     return wrapper
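
The decorator now has two paths: the default one compiles the function with onnxscript.script() into an OnnxFunction, while trace_only=True registers the plain Python callable so it can run ordinary Python (like the transpose permutation above) at trace time. A stripped-down, self-contained model of that control flow, using a dict in place of the real Registry (all names below are stand-ins, not the library's API):

from typing import Any, Callable, Dict

_registry: Dict[str, Callable[..., Any]] = {}

def _fake_compile(func: Callable[..., Any]) -> Callable[..., Any]:
    # Stand-in for onnxscript.script()(func), which would build an OnnxFunction.
    func.compiled = True  # type: ignore[attr-defined]
    return func

def toy_torch_op(name: str, *, trace_only: bool = False):
    # Mirrors the branching added to torch_op above.
    def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
        processed_func = func if trace_only else _fake_compile(func)
        _registry[name] = processed_func
        return processed_func
    return wrapper

@toy_torch_op("aten::rsub")
def rsub(self, other, alpha=1.0):
    return other - self * alpha

@toy_torch_op("aten::transpose", trace_only=True)
def transpose(self, dim0, dim1):
    ...

assert getattr(_registry["aten::rsub"], "compiled", False)
assert not hasattr(_registry["aten::transpose"], "compiled")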
