Skip to content

Commit dd3d747

Browse files
authored
feat(atenlib): index_select; trace_only ops (#274)
Implement `aten_index_select`; enable tests for transpose. This change also added the **`trace_only`** argument to mark functions as trace only to work around functions that cannot be compiled. The logic can then be tested, but the argument should be removed later when onnxscript extends support for the syntax.
1 parent 47fe75f commit dd3d747

File tree

3 files changed

+58
-13
lines changed

3 files changed

+58
-13
lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2165,10 +2165,21 @@ def aten_index_reduce(
21652165
raise NotImplementedError()
21662166

21672167

2168-
def aten_index_select(self: TensorType, dim: int, index: TensorType) -> TensorType:
2168+
# FIXME(#277): Script when attributes can come before inputs
2169+
@torch_op("aten::index_select", trace_only=True)
2170+
def aten_index_select(self: TTensor, dim: int, index: TInt) -> TTensor:
21692171
# index_select(Tensor self, int dim, Tensor index) -> Tensor
21702172

2171-
raise NotImplementedError()
2173+
if op.Size(op.Shape(self)) == 0:
2174+
result = self
2175+
else:
2176+
# Index can be a scalar. Reshape it to a rank 1 tensor.
2177+
index = op.Reshape(index, op.Constant(value_floats=[-1]))
2178+
index = op.Cast(index, to=INT64.dtype)
2179+
2180+
result = op.Gather(self, index, axis=dim)
2181+
2182+
return result
21722183

21732184

21742185
def aten_index_select_backward(
@@ -4670,11 +4681,26 @@ def aten_trace_backward(grad: TensorType, sizes: INT64) -> TensorType:
46704681
raise NotImplementedError()
46714682

46724683

4684+
@torch_op("aten::transpose", trace_only=True)
46734685
def aten_transpose(self, dim0: int, dim1: int):
46744686
# transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)
46754687

46764688
# FIXME(justinchuby): onnxscript raises Unsupported expression type
4677-
return op.Transpose(self, [dim0, dim1])
4689+
# Script the function when this is fixed
4690+
self_rank = op.Size(op.Shape(self))
4691+
4692+
if self_rank == 0:
4693+
result = self
4694+
else:
4695+
# Python code, change when onnxscript supports this
4696+
self_rank_val = self_rank.value # type: ignore[attr-defined]
4697+
dims = list(range(self_rank_val))
4698+
dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
4699+
# Python code ends
4700+
4701+
result = op.Transpose(self, perm=dims)
4702+
4703+
return result
46784704

46794705

46804706
def aten_triangular_solve(

onnxscript/function_libs/torch_aten/registration.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,19 +48,33 @@ def __repr__(self):
4848

4949

5050
def torch_op(
51-
name, overload: bool = False, registry: Optional[Registry] = None
52-
) -> Callable[[Callable[..., Any]], onnxscript.OnnxFunction]:
53-
"""Register a torch op."""
51+
name,
52+
*,
53+
overload: bool = False,
54+
registry: Optional[Registry] = None,
55+
trace_only: bool = False,
56+
) -> Callable[[Callable[..., Any]], onnxscript.OnnxFunction | Callable[..., Any]]:
57+
"""Register a torch op.
58+
59+
Args:
60+
name: ATen name of the function. E.g. "aten::add".
61+
overload: Whether the function is an overload (not default).
62+
registry: Registry to register the function to. If None, the default registry is used.
63+
trace_only: Whether the function should only be traced and not compiled.
64+
"""
5465
if registry is None:
5566
registry = default_registry
5667

57-
def wrapper(func: Callable[..., Any]) -> onnxscript.OnnxFunction:
68+
def wrapper(func: Callable[..., Any]) -> onnxscript.OnnxFunction | Callable[..., Any]:
5869

59-
# Compile the function
60-
compiled = onnxscript.script()(func)
70+
if trace_only:
71+
processed_func = func
72+
else:
73+
# Compile the function
74+
processed_func = onnxscript.script()(func)
6175

6276
assert registry is not None
63-
registry.register(compiled, name, overload=overload)
64-
return compiled
77+
registry.register(processed_func, name, overload=overload)
78+
return processed_func
6579

6680
return wrapper

onnxscript/test/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
182182
OPINFO_FUNCTION_MAPPING: dict[
183183
str,
184184
onnxscript.OnnxFunction
185+
| Callable[..., Any]
185186
| tuple[
186187
onnxscript.OnnxFunction | Callable[..., Any],
187188
Callable[[dict[str, Any]], dict[str, Any]],
@@ -213,6 +214,7 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
213214
# TODO(justinchuby): Test aten::full
214215
"full_like": core_ops.aten_full_like,
215216
"gt": core_ops.aten_gt,
217+
"index_select": core_ops.aten_index_select,
216218
"isinf": core_ops.aten_isinf,
217219
"lt": core_ops.aten_lt,
218220
"matmul": core_ops.aten_matmul,
@@ -252,7 +254,7 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
252254
"t": core_ops.aten_t,
253255
"tan": core_ops.aten_tan,
254256
"tanh": core_ops.aten_tanh,
255-
# "transpose": core_ops.aten_transpose, # TODO(justinchuby): Enable when onnxscript errors are fixed,
257+
"transpose": core_ops.aten_transpose,
256258
"unsqueeze": core_ops.aten_unsqueeze,
257259
"where": core_ops.aten_where,
258260
"zeros": core_ops.aten_zeros,
@@ -276,7 +278,6 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
276278
xfail("round", variant_name="decimals_0", reason="The op does not support decimals"),
277279
xfail("round", variant_name="decimals_3", reason="The op does not support decimals"),
278280
xfail("round", variant_name="decimals_neg_3", reason="The op does not support decimals"),
279-
xfail("transpose", reason="Enable when onnxscript errors are fixed"),
280281
)
281282

282283

@@ -467,6 +468,10 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
467468
rtol = None
468469
atol = None
469470

471+
if not isinstance(function_output, np.ndarray):
472+
# An onnxscript tensor
473+
function_output = function_output.value
474+
470475
# Use torch.testing as opposed to np.testing to ensure dtypes and shapes match
471476
torch.testing.assert_close(
472477
torch.tensor(function_output),

0 commit comments

Comments
 (0)