feat(atenlib): ops 2/n #252


Merged · 37 commits · Dec 15, 2022
Changes from all commits (37 commits)
55f8dd9
fix: annotate script()
justinchuby Dec 9, 2022
8a3a587
feat(atenlib): clamp, lt, gt
justinchuby Dec 9, 2022
f821b6a
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 9, 2022
aecc148
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
6555a55
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
f8385b0
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
468f86f
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
47b8380
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
00f1760
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
060f9db
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
497cb16
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
9bb4038
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
d24110a
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
cbfb867
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
875f235
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
27008e1
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
3a8737d
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
c5871c8
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
012905c
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
49be5ec
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 13, 2022
3a9c5f6
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 13, 2022
d4f09e8
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 13, 2022
ee3143e
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 13, 2022
691772b
feat(atenlib): ops 2/n
justinchuby Dec 13, 2022
f160dfa
Update on "feat(atenlib): ops 2/n"
justinchuby Dec 13, 2022
c8e4a54
Update base for Update on "feat(atenlib): ops 2/n"
justinchuby Dec 13, 2022
7cd967d
Update on "feat(atenlib): ops 2/n"
justinchuby Dec 13, 2022
e8f07c9
Update base for Update on "feat(atenlib): ops 2/n"
justinchuby Dec 13, 2022
c4c80a1
Update on "feat(atenlib): ops 2/n"
justinchuby Dec 13, 2022
7c0e305
Update base for Update on "feat(atenlib): ops 2/n"
justinchuby Dec 14, 2022
2db3170
Update on "feat(atenlib): ops 2/n"
justinchuby Dec 14, 2022
b7b03ee
Update base for Update on "feat(atenlib): ops 2/n"
justinchuby Dec 14, 2022
5216435
Update on "feat(atenlib): ops 2/n"
justinchuby Dec 14, 2022
93ed77b
Update base for Update on "feat(atenlib): ops 2/n"
justinchuby Dec 14, 2022
a5c9629
Update on "feat(atenlib): ops 2/n"
justinchuby Dec 14, 2022
57676c8
Update base for Update on "feat(atenlib): ops 2/n"
justinchuby Dec 15, 2022
3967967
Update on "feat(atenlib): ops 2/n"
justinchuby Dec 15, 2022
37 changes: 24 additions & 13 deletions onnxscript/function_libs/torch_aten/ops/core.py
@@ -2680,10 +2680,10 @@ def aten_masked_select_backward(
raise NotImplementedError()


def aten_matmul(self: TensorType, other: TensorType) -> TensorType:
def aten_matmul(self, other):
# matmul(Tensor self, Tensor other) -> Tensor

raise NotImplementedError()
return op.MatMul(self, other)


def aten_matmul_backward(
@@ -3080,10 +3080,11 @@ def aten_mkldnn_max_pool3d_backward(
raise NotImplementedError()


def aten_mm(self: TensorType, mat2: TensorType) -> TensorType:
def aten_mm(self, mat2):
# mm(Tensor self, Tensor mat2) -> Tensor

raise NotImplementedError()
# TODO(justinchuby): Specify type conversion for uint8/int8/int16
return op.MatMul(self, mat2)
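
One possible shape for the TODO above, as a sketch only (not part of this PR; the helper name and promotion dtype are illustrative): ONNX MatMul has no uint8/int8/int16 variants, so such inputs could be cast up to int32 for the MatMul and the result cast back afterwards.

def aten_mm_with_promotion(self, mat2, original_dtype: int):
    # Hypothetical sketch: promote unsupported integer inputs to INT32
    # (TensorProto dtype 6), run MatMul, then cast back to the original dtype.
    self = op.Cast(self, to=6)
    mat2 = op.Cast(mat2, to=6)
    result = op.MatMul(self, mat2)
    return op.Cast(result, to=original_dtype)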


def aten_mode(
@@ -3463,16 +3464,13 @@ def aten_nuclear_norm(self: TensorType, keepdim: bool = False) -> TensorType:
raise NotImplementedError()


def aten_numpy_T(self: TensorType) -> TensorType:
# numpy_T(Tensor(a) self) -> Tensor(a)

raise NotImplementedError()


def aten_ones(size: INT64) -> TensorType:
def aten_ones(size: INT64, dtype: int = -1) -> TensorType:
# ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

raise NotImplementedError()
one = op.Constant(value_float=1)
if dtype != -1:
one = op.Cast(one, to=dtype) # type: ignore[arg-type]
return op.Expand(one, size) # type: ignore[arg-type]
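
As a rough numpy illustration of the Constant, Cast, Expand idiom above (not part of the PR; for intuition only):

import numpy as np

def ones_reference(size, np_dtype=np.float32):
    # Mirrors the graph: a scalar constant 1, cast to the requested dtype,
    # then broadcast (Expand) to the requested shape.
    one = np.asarray(1, dtype=np_dtype)
    return np.broadcast_to(one, tuple(size))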


def aten_ones_like(self, dtype: int = -1):
@@ -4461,7 +4459,13 @@ def aten_symeig(
def aten_t(self: TensorType) -> TensorType:
# t(Tensor(a) self) -> Tensor(a)

raise NotImplementedError()
# TODO(justinchuby): Make rank a function
rank = op.Size(op.Shape(self)) # type: ignore[arg-type]
if rank == 0 or rank == 1: # pylint: disable=consider-using-in
result = self
else:
result = op.Transpose(self, perm=[1, 0]) # type: ignore[arg-type]
return result


def aten_t_copy(self: TensorType) -> TensorType:
@@ -4606,6 +4610,13 @@ def aten_trace_backward(grad: TensorType, sizes: INT64) -> TensorType:
raise NotImplementedError()


def aten_transpose(self, dim0: int, dim1: int):
# transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)

# FIXME(justinchuby): onnxscript raises Unsupported expression type
return op.Transpose(self, [dim0, dim1])


def aten_triangular_solve(
self: TensorType,
A: TensorType,
17 changes: 13 additions & 4 deletions onnxscript/function_libs/torch_aten/ops/nn.py
@@ -407,12 +407,21 @@ def aten_leaky_relu_backward(
raise NotImplementedError()


def aten_linear(
input: TensorType, weight: TensorType, bias: Optional[TensorType] = None
) -> TensorType:
def aten_linear(input, weight, bias=None) -> TensorType:
# linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor

raise NotImplementedError()
# FIXME(justinchuby): Enable the test
[Review discussion on aten_linear]

Contributor: In the current PyTorch exporter, if the rank of input is 2 and bias is not None, it returns the result of addmm(). Code is here.

Why do we remove that logic here?

justinchuby (Collaborator, Author): I think we can leave this to downstream optimization to keep the graph as simple and expressive as possible. What do you think?

Also, addmm is an ATen op. I think we should avoid calling other ATen ops and instead extract logic into common functions only when needed (which for now is not well supported by onnxscript, so we don't call other functions yet).

Contributor: For optimization purposes, then, I think we should add a comment here to mark this point so that we don't forget it when we start the optimization work.

Even though these functions are named with an aten_ prefix, they are still ONNX Script functions, essentially the same as those op functions. I don't think we need to avoid calling each other within this ATen function library. Do you have any examples/ideas describing the cons?

justinchuby (Collaborator, Author), Dec 14, 2022: (1) Since they have ATen signatures, we have seen many cases where additional params need to be provided to make the call correct. (2) Conceptually, the ATen functions interface with ATen ops (external facing), so calling them internally overloads that responsibility; I want to maintain a clear conceptual model/boundary.

This also helps with the representation of the exported graph. When users see an ATen function, they will know there is a corresponding op in their model.

justinchuby (Collaborator, Author): Added a comment.

# INVALID_GRAPH : This is an invalid model.
# In Node, ("", OptionalHasElement, "", -1) : () -> ("output0",) ,
# Error Node () has input size 0 not in range [min=1, max=1]

# NOTE: The symbolic function in torch.onnx also uses Gemm in certain cases
# Optimizers may consider this path and replace it with Gemm
result = op.MatMul(input, weight)
if op.OptionalHasElement(bias):
bias = op.OptionalGetElement(bias)
result = op.Add(result, bias) # type: ignore[arg-type]
return result
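
For reference, a minimal sketch (not part of this PR; the helper name is hypothetical) of the Gemm form mentioned in the NOTE above: roughly what a downstream optimizer could rewrite the MatMul + Add pattern into when input is 2-D and bias is a plain tensor.

def linear_rank2_gemm_sketch(input, weight, bias):
    # Assumes 2-D input and a non-optional bias tensor. Gemm computes
    # alpha * A @ B' + beta * C; with transB=1 this is input @ weight.T + bias,
    # matching linear() for the rank-2 case discussed in the review thread.
    return op.Gemm(input, weight, bias, alpha=1.0, beta=1.0, transB=1)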


def aten_log_sigmoid(self: TensorType) -> TensorType:
39 changes: 38 additions & 1 deletion onnxscript/test/function_libs/torch_aten/ops_correctness_test.py
@@ -162,35 +162,56 @@ def wrapped(fn):
# Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py
OPINFO_FUNCTION_MAPPING: dict[str, Callable[..., Any]] = {
"add": core_ops.aten_add,
# "clamp": core_ops.aten_clamp, # TODO(justinchuby): Enable
"clamp_max": core_ops.aten_clamp_max_tensor,
"clamp_min": core_ops.aten_clamp_min_tensor,
"clamp": core_ops.aten_clamp,
"gt": core_ops.aten_gt,
"lt": core_ops.aten_lt,
"matmul": core_ops.aten_matmul,
"mm": core_ops.aten_mm,
"mul": core_ops.aten_mul,
"nn.functional.elu": nn_ops.aten_elu,
"nn.functional.linear": nn_ops.aten_linear,
"nn.functional.relu6": nn_ops.aten_relu6,
"nn.functional.selu": core_ops.aten_selu,
"ones_like": core_ops.aten_ones_like,
"ones": core_ops.aten_ones,
"repeat": core_ops.aten_repeat,
"round": core_ops.aten_round,
"sub": core_ops.aten_sub,
"t": core_ops.aten_t,
# "transpose": core_ops.aten_transpose, # TODO(justinchuby): Enable when onnxscript errors are fixed
}

TESTED_OPS = frozenset(OPINFO_FUNCTION_MAPPING)

EXPECTED_SKIPS_OR_FAILS = (
xfail("add", dtypes=BOOL_TYPES, reason="Add is not defined on bool tensors"),
skip("clamp", reason="Enable when onnxscript errors are fixed"),
xfail("clamp_max", dtypes=BOOL_TYPES, reason="Min is not defined on bool tensors"),
xfail("clamp_min", dtypes=BOOL_TYPES, reason="Max is not defined on bool tensors"),
xfail("gt", dtypes=BOOL_TYPES, reason="Greater is not defined on bool tensors"),
xfail("lt", dtypes=BOOL_TYPES, reason="Less is not defined on bool tensors"),
xfail(
"matmul",
dtypes=[torch.uint8, torch.int8, torch.int16],
reason="MatMul is not defined on int16/int8/uint8 tensors",
),
xfail(
"mm",
dtypes=[torch.uint8, torch.int8, torch.int16],
reason="MatMul is not defined on int16/int8/uint8 tensors",
),
xfail("mul", dtypes=BOOL_TYPES, reason="Mul is not defined on bool tensors"),
xfail(
"nn.functional.elu",
dtypes=dtypes_except(torch.float16, torch.float32),
reason="ONNX Runtime doesn't support float64 for Elu",
),
xfail(
"nn.functional.linear",
reason="ONNX Runtime thinks the graph is invalid",
),
xfail(
"nn.functional.relu6",
dtypes=dtypes_except(torch.float16, torch.float32),
@@ -213,6 +234,7 @@ def wrapped(fn):
"round", variant_name="decimals_neg_3", reason="The ATen op does not support decimals"
),
xfail("sub", dtypes=BOOL_TYPES, reason="Sub is not defined on bool tensors"),
xfail("transpose", reason="Enable when onnxscript errors are fixed"),
)


@@ -240,6 +262,10 @@ def wrapped(fn):

OPS_DB = copy.deepcopy(common_methods_invocations.op_db)

ALL_OPS_IN_DB = frozenset(op_info.name for op_info in OPS_DB)
# Assert all ops in OPINFO_FUNCTION_MAPPING are in the OPS_DB
assert TESTED_OPS.issubset(ALL_OPS_IN_DB), f"{TESTED_OPS - ALL_OPS_IN_DB} not in OPS_DB"


TORCH_TYPE_TO_ONNX = {
torch.bool: onnx.TensorProto.BOOL,
@@ -369,10 +395,21 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
)
# pylint: enable=c-extension-no-member

if dtype == torch.float32:
# Relax atol and rtol for float32 based on empirical results
# The current most relaxed values are for aten::matmul
rtol = 3.7e-6
atol = 1.8e-5
else:
rtol = None
atol = None

# Use torch testing to ensure dtypes and shapes match
torch.testing.assert_close(
torch.tensor(function_output),
output_torch,
rtol=rtol,
atol=atol,
)

