diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index 377e25eb86..5dbe60d931 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -2680,10 +2680,10 @@ def aten_masked_select_backward(
     raise NotImplementedError()
 
 
-def aten_matmul(self: TensorType, other: TensorType) -> TensorType:
+def aten_matmul(self, other):
     # matmul(Tensor self, Tensor other) -> Tensor
 
-    raise NotImplementedError()
+    return op.MatMul(self, other)
 
 
 def aten_matmul_backward(
@@ -3080,10 +3080,11 @@ def aten_mkldnn_max_pool3d_backward(
     raise NotImplementedError()
 
 
-def aten_mm(self: TensorType, mat2: TensorType) -> TensorType:
+def aten_mm(self, mat2):
     # mm(Tensor self, Tensor mat2) -> Tensor
 
-    raise NotImplementedError()
+    # TODO(justinchuby): Specify type conversion for uint8/int8/int16
+    return op.MatMul(self, mat2)
 
 
 def aten_mode(
@@ -3463,16 +3464,13 @@ def aten_nuclear_norm(self: TensorType, keepdim: bool = False) -> TensorType:
     raise NotImplementedError()
 
 
-def aten_numpy_T(self: TensorType) -> TensorType:
-    # numpy_T(Tensor(a) self) -> Tensor(a)
-
-    raise NotImplementedError()
-
-
-def aten_ones(size: INT64) -> TensorType:
+def aten_ones(size: INT64, dtype: int = -1) -> TensorType:
     # ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
-    raise NotImplementedError()
+    one = op.Constant(value_float=1)
+    if dtype != -1:
+        one = op.Cast(one, to=dtype)  # type: ignore[arg-type]
+    return op.Expand(one, size)  # type: ignore[arg-type]
 
 
 def aten_ones_like(self, dtype: int = -1):
@@ -4461,7 +4459,13 @@ def aten_symeig(
 def aten_t(self: TensorType) -> TensorType:
     # t(Tensor(a) self) -> Tensor(a)
 
-    raise NotImplementedError()
+    # TODO(justinchuby): Make rank a function
+    rank = op.Size(op.Shape(self))  # type: ignore[arg-type]
+    if rank == 0 or rank == 1:  # pylint: disable=consider-using-in
+        result = self
+    else:
+        result = op.Transpose(self, perm=[1, 0])  # type: ignore[arg-type]
+    return result
 
 
 def aten_t_copy(self: TensorType) -> TensorType:
@@ -4606,6 +4610,13 @@ def aten_trace_backward(grad: TensorType, sizes: INT64) -> TensorType:
     raise NotImplementedError()
 
 
+def aten_transpose(self, dim0: int, dim1: int):
+    # transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)
+
+    # FIXME(justinchuby): onnxscript raises "Unsupported expression type" here
+    return op.Transpose(self, [dim0, dim1])
+
+
 def aten_triangular_solve(
     self: TensorType,
     A: TensorType,
diff --git a/onnxscript/function_libs/torch_aten/ops/nn.py b/onnxscript/function_libs/torch_aten/ops/nn.py
index a88874183d..d76250ea8a 100644
--- a/onnxscript/function_libs/torch_aten/ops/nn.py
+++ b/onnxscript/function_libs/torch_aten/ops/nn.py
@@ -407,12 +407,21 @@ def aten_leaky_relu_backward(
     raise NotImplementedError()
 
 
-def aten_linear(
-    input: TensorType, weight: TensorType, bias: Optional[TensorType] = None
-) -> TensorType:
+def aten_linear(input, weight, bias=None) -> TensorType:
     # linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
 
-    raise NotImplementedError()
+    # FIXME(justinchuby): Enable the test
+    # INVALID_GRAPH : This is an invalid model.
+    # In Node, ("", OptionalHasElement, "", -1) : () -> ("output0",) ,
+    # Error Node () has input size 0 not in range [min=1, max=1]
+
+    # NOTE: The symbolic function in torch.onnx also uses Gemm in certain cases.
+    # Graph optimizers may match this MatMul + Add pattern and rewrite it to Gemm.
+    result = op.MatMul(input, weight)
+    if op.OptionalHasElement(bias):
+        bias = op.OptionalGetElement(bias)
+        result = op.Add(result, bias)  # type: ignore[arg-type]
+    return result
 
 
 def aten_log_sigmoid(self: TensorType) -> TensorType:
diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py
index a2a9dbac7c..7c9a67496d 100644
--- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py
+++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py
@@ -162,35 +162,56 @@ def wrapped(fn):
 
 # Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py
 OPINFO_FUNCTION_MAPPING: dict[str, Callable[..., Any]] = {
     "add": core_ops.aten_add,
-    # "clamp": core_ops.aten_clamp,  # TODO(justinchuby): Enable
     "clamp_max": core_ops.aten_clamp_max_tensor,
     "clamp_min": core_ops.aten_clamp_min_tensor,
+    "clamp": core_ops.aten_clamp,
     "gt": core_ops.aten_gt,
     "lt": core_ops.aten_lt,
+    "matmul": core_ops.aten_matmul,
+    "mm": core_ops.aten_mm,
     "mul": core_ops.aten_mul,
     "nn.functional.elu": nn_ops.aten_elu,
+    "nn.functional.linear": nn_ops.aten_linear,
     "nn.functional.relu6": nn_ops.aten_relu6,
     "nn.functional.selu": core_ops.aten_selu,
     "ones_like": core_ops.aten_ones_like,
+    "ones": core_ops.aten_ones,
     "repeat": core_ops.aten_repeat,
     "round": core_ops.aten_round,
     "sub": core_ops.aten_sub,
+    "t": core_ops.aten_t,
+    # "transpose": core_ops.aten_transpose,  # TODO(justinchuby): Enable when onnxscript errors are fixed
 }
 
 TESTED_OPS = frozenset(OPINFO_FUNCTION_MAPPING)
 
 EXPECTED_SKIPS_OR_FAILS = (
     xfail("add", dtypes=BOOL_TYPES, reason="Add is not defined on bool tensors"),
+    skip("clamp", reason="Enable when onnxscript errors are fixed"),
     xfail("clamp_max", dtypes=BOOL_TYPES, reason="Min is not defined on bool tensors"),
     xfail("clamp_min", dtypes=BOOL_TYPES, reason="Max is not defined on bool tensors"),
     xfail("gt", dtypes=BOOL_TYPES, reason="Greater is not defined on bool tensors"),
     xfail("lt", dtypes=BOOL_TYPES, reason="Less is not defined on bool tensors"),
+    xfail(
+        "matmul",
+        dtypes=[torch.uint8, torch.int8, torch.int16],
+        reason="MatMul is not defined on int16/int8/uint8 tensors",
+    ),
+    xfail(
+        "mm",
+        dtypes=[torch.uint8, torch.int8, torch.int16],
+        reason="MatMul is not defined on int16/int8/uint8 tensors",
+    ),
     xfail("mul", dtypes=BOOL_TYPES, reason="Mul is not defined on bool tensors"),
     xfail(
         "nn.functional.elu",
         dtypes=dtypes_except(torch.float16, torch.float32),
         reason="ONNX Runtime doesn't support float64 for Elu",
     ),
+    xfail(
+        "nn.functional.linear",
+        reason="ONNX Runtime thinks the graph is invalid",
+    ),
     xfail(
         "nn.functional.relu6",
         dtypes=dtypes_except(torch.float16, torch.float32),
@@ -213,6 +234,7 @@ def wrapped(fn):
         "round", variant_name="decimals_neg_3", reason="The ATen op does not support decimals"
     ),
     xfail("sub", dtypes=BOOL_TYPES, reason="Sub is not defined on bool tensors"),
+    xfail("transpose", reason="Enable when onnxscript errors are fixed"),
 )
 
 
@@ -240,6 +262,10 @@ def wrapped(fn):
 
 OPS_DB = copy.deepcopy(common_methods_invocations.op_db)
 
+ALL_OPS_IN_DB = frozenset(op_info.name for op_info in OPS_DB)
+# Assert all ops in OPINFO_FUNCTION_MAPPING are in the OPS_DB
+assert TESTED_OPS.issubset(ALL_OPS_IN_DB), f"{TESTED_OPS - ALL_OPS_IN_DB} not in OPS_DB"
+
 
 TORCH_TYPE_TO_ONNX = {
     torch.bool: onnx.TensorProto.BOOL,
@@ -369,10 +395,21 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
         )
         # pylint: enable=c-extension-no-member
 
+        if dtype == torch.float32:
+            # Relax atol and rtol for float32 based on empirical results
+            # The most relaxed values are currently needed for aten::matmul
+            rtol = 3.7e-6
+            atol = 1.8e-5
+        else:
+            rtol = None
+            atol = None
+
         # Use torch testing to ensure dtypes and shapes match
         torch.testing.assert_close(
             torch.tensor(function_output),
             output_torch,
+            rtol=rtol,
+            atol=atol,
         )
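
As a sanity check on the new aten_ones, here is a rough numpy sketch of the Constant -> Cast -> Expand decomposition it traces. The numpy calls stand in for the ONNX ops, and ones_decomposition is an illustrative name, not part of this diff.

    import numpy as np

    def ones_decomposition(size, dtype=None):
        # Stands in for op.Constant(value_float=1): a float32 scalar one.
        one = np.array(1.0, dtype=np.float32)
        # Stands in for the conditional op.Cast when a dtype is requested.
        if dtype is not None:
            one = one.astype(dtype)
        # Stands in for op.Expand: broadcast the scalar to the target shape.
        return np.broadcast_to(one, size)

    # Matches torch.ones((2, 3), dtype=torch.int64) element for element.
    assert ones_decomposition((2, 3), np.int64).tolist() == [[1, 1, 1], [1, 1, 1]]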