diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index 53b6fb62fc..90b1ae453e 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -2165,10 +2165,21 @@ def aten_index_reduce(
     raise NotImplementedError()
 
 
-def aten_index_select(self: TensorType, dim: int, index: TensorType) -> TensorType:
+# FIXME(#277): Script when attributes can come before inputs
+@torch_op("aten::index_select", trace_only=True)
+def aten_index_select(self: TTensor, dim: int, index: TInt) -> TTensor:
     # index_select(Tensor self, int dim, Tensor index) -> Tensor
 
-    raise NotImplementedError()
+    if op.Size(op.Shape(self)) == 0:
+        result = self
+    else:
+        # Index can be a scalar. Reshape it to a rank 1 tensor.
+        index = op.Reshape(index, op.Constant(value_floats=[-1]))
+        index = op.Cast(index, to=INT64.dtype)
+
+        result = op.Gather(self, index, axis=dim)
+
+    return result
 
 
 def aten_index_select_backward(
@@ -4670,11 +4681,26 @@ def aten_trace_backward(grad: TensorType, sizes: INT64) -> TensorType:
     raise NotImplementedError()
 
 
+@torch_op("aten::transpose", trace_only=True)
 def aten_transpose(self, dim0: int, dim1: int):
     # transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)
 
     # FIXME(justinchuby): onnxscript raises Unsupported expression type
-    return op.Transpose(self, [dim0, dim1])
+    # Script the function when this is fixed
+    self_rank = op.Size(op.Shape(self))
+
+    if self_rank == 0:
+        result = self
+    else:
+        # Python code, change when onnxscript supports this
+        self_rank_val = self_rank.value  # type: ignore[attr-defined]
+        dims = list(range(self_rank_val))
+        dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
+        # Python code ends
+
+        result = op.Transpose(self, perm=dims)
+
+    return result
 
 
 def aten_triangular_solve(
diff --git a/onnxscript/function_libs/torch_aten/registration.py b/onnxscript/function_libs/torch_aten/registration.py
index 01c75b9f4e..86ef7b4835 100644
--- a/onnxscript/function_libs/torch_aten/registration.py
+++ b/onnxscript/function_libs/torch_aten/registration.py
@@ -48,19 +48,33 @@ def __repr__(self):
 
 
 def torch_op(
-    name, overload: bool = False, registry: Optional[Registry] = None
-) -> Callable[[Callable[..., Any]], onnxscript.OnnxFunction]:
-    """Register a torch op."""
+    name,
+    *,
+    overload: bool = False,
+    registry: Optional[Registry] = None,
+    trace_only: bool = False,
+) -> Callable[[Callable[..., Any]], onnxscript.OnnxFunction | Callable[..., Any]]:
+    """Register a torch op.
+
+    Args:
+        name: ATen name of the function. E.g. "aten::add".
+        overload: Whether the function is an overload (not default).
+        registry: Registry to register the function to. If None, the default registry is used.
+        trace_only: Whether the function should only be traced and not compiled.
+    """
     if registry is None:
         registry = default_registry
 
-    def wrapper(func: Callable[..., Any]) -> onnxscript.OnnxFunction:
+    def wrapper(func: Callable[..., Any]) -> onnxscript.OnnxFunction | Callable[..., Any]:
 
-        # Compile the function
-        compiled = onnxscript.script()(func)
+        if trace_only:
+            processed_func = func
+        else:
+            # Compile the function
+            processed_func = onnxscript.script()(func)
         assert registry is not None
-        registry.register(compiled, name, overload=overload)
-        return compiled
+        registry.register(processed_func, name, overload=overload)
+        return processed_func
 
     return wrapper
 
diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py
index afeb51a05f..d5d167ab83 100644
--- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py
+++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py
@@ -182,6 +182,7 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
 OPINFO_FUNCTION_MAPPING: dict[
     str,
     onnxscript.OnnxFunction
+    | Callable[..., Any]
     | tuple[
         onnxscript.OnnxFunction | Callable[..., Any],
         Callable[[dict[str, Any]], dict[str, Any]],
@@ -213,6 +214,7 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
     # TODO(justinchuby): Test aten::full
     "full_like": core_ops.aten_full_like,
     "gt": core_ops.aten_gt,
+    "index_select": core_ops.aten_index_select,
     "isinf": core_ops.aten_isinf,
     "lt": core_ops.aten_lt,
     "matmul": core_ops.aten_matmul,
@@ -252,7 +254,7 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
     "t": core_ops.aten_t,
     "tan": core_ops.aten_tan,
     "tanh": core_ops.aten_tanh,
-    # "transpose": core_ops.aten_transpose,  # TODO(justinchuby): Enable when onnxscript errors are fixed,
+    "transpose": core_ops.aten_transpose,
     "unsqueeze": core_ops.aten_unsqueeze,
     "where": core_ops.aten_where,
     "zeros": core_ops.aten_zeros,
@@ -276,7 +278,6 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
     xfail("round", variant_name="decimals_0", reason="The op does not support decimals"),
     xfail("round", variant_name="decimals_3", reason="The op does not support decimals"),
     xfail("round", variant_name="decimals_neg_3", reason="The op does not support decimals"),
-    xfail("transpose", reason="Enable when onnxscript errors are fixed"),
 )
 
 
@@ -467,6 +468,10 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
                 rtol = None
                 atol = None
 
+            if not isinstance(function_output, np.ndarray):
+                # An onnxscript tensor
+                function_output = function_output.value
+
             # Use torch.testing as opposed to np.testing to ensure dtypes and shapes match
             torch.testing.assert_close(
                 torch.tensor(function_output),