
feat(atenlib): index_select; trace_only ops #274

Merged
merged 9 commits into from Jan 7, 2023
Changes from all commits
32 changes: 29 additions & 3 deletions onnxscript/function_libs/torch_aten/ops/core.py
@@ -2165,10 +2165,21 @@ def aten_index_reduce(
     raise NotImplementedError()
 
 
-def aten_index_select(self: TensorType, dim: int, index: TensorType) -> TensorType:
+# FIXME(#277): Script when attributes can come before inputs
+@torch_op("aten::index_select", trace_only=True)
+def aten_index_select(self: TTensor, dim: int, index: TInt) -> TTensor:
     # index_select(Tensor self, int dim, Tensor index) -> Tensor
 
-    raise NotImplementedError()
+    if op.Size(op.Shape(self)) == 0:
+        result = self
+    else:
+        # Index can be a scalar. Reshape it to a rank 1 tensor.
+        index = op.Reshape(index, op.Constant(value_ints=[-1]))
+        index = op.Cast(index, to=INT64.dtype)
+
+        result = op.Gather(self, index, axis=dim)
+
+    return result
 
 
 def aten_index_select_backward(
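
For intuition, the decomposition above maps aten::index_select onto ONNX Gather, flattening a possible scalar index to rank 1 first. A minimal numpy sketch of the same logic (illustrative only; index_select_reference is a made-up name, not part of the diff):

import numpy as np

def index_select_reference(data: np.ndarray, dim: int, index) -> np.ndarray:
    """Mirrors aten_index_select above: rank-0 data passes through; a scalar
    index is reshaped to rank 1 and cast to int64; Gather selects along dim."""
    if data.ndim == 0:  # op.Size(op.Shape(self)) == 0
        return data
    index = np.asarray(index, dtype=np.int64).reshape(-1)  # Reshape + Cast
    return np.take(data, index, axis=dim)  # op.Gather(self, index, axis=dim)

x = np.arange(12, dtype=np.float32).reshape(3, 4)
assert np.array_equal(index_select_reference(x, 0, [0, 2]), x[[0, 2]])
assert np.array_equal(index_select_reference(x, 1, 1), x[:, [1]])  # scalar index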
@@ -4670,11 +4681,26 @@ def aten_trace_backward(grad: TensorType, sizes: INT64) -> TensorType:
     raise NotImplementedError()
 
 
+@torch_op("aten::transpose", trace_only=True)
 def aten_transpose(self, dim0: int, dim1: int):
     # transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)
 
     # FIXME(justinchuby): onnxscript raises Unsupported expression type
-    return op.Transpose(self, [dim0, dim1])
+    # Script the function when this is fixed
+    self_rank = op.Size(op.Shape(self))
+
+    if self_rank == 0:
+        result = self
+    else:
+        # Python code, change when onnxscript supports this
+        self_rank_val = self_rank.value  # type: ignore[attr-defined]
+        dims = list(range(self_rank_val))
+        dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
+        # Python code ends
+
+        result = op.Transpose(self, perm=dims)
+
+    return result
 
 
 def aten_triangular_solve(
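
The trace-time trick here is that the permutation list is ordinary Python: because the function is trace_only, dims can be built with list operations that onnxscript cannot yet script. A small numpy sketch of the equivalent computation (illustrative; transpose_reference is a made-up name):

import numpy as np

def transpose_reference(data: np.ndarray, dim0: int, dim1: int) -> np.ndarray:
    """Mirrors aten_transpose above: start from the identity permutation,
    swap dim0 and dim1, then apply it (op.Transpose(self, perm=dims))."""
    if data.ndim == 0:
        return data
    dims = list(range(data.ndim))
    dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
    return np.transpose(data, axes=dims)

x = np.zeros((2, 3, 4))
assert transpose_reference(x, 0, 2).shape == (4, 3, 2)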
30 changes: 22 additions & 8 deletions onnxscript/function_libs/torch_aten/registration.py
@@ -48,19 +48,33 @@ def __repr__(self):
 
 
 def torch_op(
-    name, overload: bool = False, registry: Optional[Registry] = None
-) -> Callable[[Callable[..., Any]], onnxscript.OnnxFunction]:
-    """Register a torch op."""
+    name,
+    *,
+    overload: bool = False,
+    registry: Optional[Registry] = None,
+    trace_only: bool = False,
+) -> Callable[[Callable[..., Any]], onnxscript.OnnxFunction | Callable[..., Any]]:
+    """Register a torch op.
+
+    Args:
+        name: ATen name of the function. E.g. "aten::add".
+        overload: Whether the function is an overload (not the default).
+        registry: Registry to register the function to. If None, the default registry is used.
+        trace_only: Whether the function should only be traced and not compiled.
+    """
     if registry is None:
         registry = default_registry
 
-    def wrapper(func: Callable[..., Any]) -> onnxscript.OnnxFunction:
+    def wrapper(func: Callable[..., Any]) -> onnxscript.OnnxFunction | Callable[..., Any]:
 
-        # Compile the function
-        compiled = onnxscript.script()(func)
+        if trace_only:
+            processed_func = func
+        else:
+            # Compile the function
+            processed_func = onnxscript.script()(func)
 
         assert registry is not None
-        registry.register(compiled, name, overload=overload)
-        return compiled
+        registry.register(processed_func, name, overload=overload)
+        return processed_func
 
     return wrapper
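
To make the two registration modes concrete, a usage sketch (the import path follows the file header above; the decorated body is elided):

from onnxscript.function_libs.torch_aten.registration import torch_op

# Default mode: the body is compiled by onnxscript.script() into an
# onnxscript.OnnxFunction before being registered.
# trace_only=True: the plain Python function is registered unchanged, so its
# body may mix Python control flow (like building `dims` above) with op.* calls.
@torch_op("aten::transpose", trace_only=True)
def aten_transpose(self, dim0: int, dim1: int):
    ...  # body elided; see the core.py diff above

assert callable(aten_transpose)  # returned as a plain function, not compiled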
@@ -182,6 +182,7 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
 OPINFO_FUNCTION_MAPPING: dict[
     str,
     onnxscript.OnnxFunction
+    | Callable[..., Any]
     | tuple[
         onnxscript.OnnxFunction | Callable[..., Any],
         Callable[[dict[str, Any]], dict[str, Any]],
@@ -213,6 +214,7 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
     # TODO(justinchuby): Test aten::full
     "full_like": core_ops.aten_full_like,
     "gt": core_ops.aten_gt,
+    "index_select": core_ops.aten_index_select,
     "isinf": core_ops.aten_isinf,
     "lt": core_ops.aten_lt,
     "matmul": core_ops.aten_matmul,
@@ -252,7 +254,7 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
     "t": core_ops.aten_t,
     "tan": core_ops.aten_tan,
     "tanh": core_ops.aten_tanh,
-    # "transpose": core_ops.aten_transpose,  # TODO(justinchuby): Enable when onnxscript errors are fixed,
+    "transpose": core_ops.aten_transpose,
     "unsqueeze": core_ops.aten_unsqueeze,
     "where": core_ops.aten_where,
     "zeros": core_ops.aten_zeros,
@@ -276,7 +278,6 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
     xfail("round", variant_name="decimals_0", reason="The op does not support decimals"),
     xfail("round", variant_name="decimals_3", reason="The op does not support decimals"),
     xfail("round", variant_name="decimals_neg_3", reason="The op does not support decimals"),
-    xfail("transpose", reason="Enable when onnxscript errors are fixed"),
 )


@@ -467,6 +468,10 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
         rtol = None
         atol = None
 
+        if not isinstance(function_output, np.ndarray):
+            # An onnxscript tensor
+            function_output = function_output.value
+
         # Use torch.testing as opposed to np.testing to ensure dtypes and shapes match
         torch.testing.assert_close(
             torch.tensor(function_output),
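
The new isinstance branch exists because scripted functions evaluate to numpy arrays in eager mode, while trace-only functions can return onnxscript tensor wrappers whose data sits in .value. A minimal sketch of the normalization (illustrative; to_numpy is a made-up helper, not part of the diff):

import numpy as np

def to_numpy(function_output):
    """Mirrors the test change above: unwrap onnxscript tensors for comparison."""
    if not isinstance(function_output, np.ndarray):
        # An onnxscript tensor
        function_output = function_output.value
    return function_output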