
Linear from PyTorch must map to Gemm in ONNX #1089


Closed
baijumeswani opened this issue Oct 11, 2023 · 13 comments · Fixed by #1111 or #1113
Labels
module: torchlib Related to the torch/aten function lib in development topic: discussion For discussion

Comments

@baijumeswani

baijumeswani commented Oct 11, 2023

PyTorch Model:

class NeuralNet(torch.nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()

        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, input1):
        out = self.fc1(input1)
        out = self.relu(out)
        out = self.fc2(out)
        return out
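
For context, a rough sketch of how the two exports can be invoked (the sizes below are made up for illustration; at the time, the dynamo path was torch.onnx.dynamo_export):

import torch

# Arbitrary sizes, just for illustration.
model = NeuralNet(input_size=784, hidden_size=500, num_classes=10)
example_input = torch.randn(1, 784)

# TorchScript-based exporter
torch.onnx.export(model, (example_input,), "model_torchscript.onnx")

# Dynamo-based exporter
torch.onnx.dynamo_export(model, example_input).save("model_dynamo.onnx")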

Exporting with the TorchScript-based exporter yields:

[image: graph from the TorchScript exporter — Gemm → Relu → Gemm]

which makes sense: it is, after all, a linear layer followed by a ReLU followed by another linear layer.

Exporting the same model with the torch dynamo-based exporter yields:

[image: graph from the dynamo exporter — the linear layers appear as nested ONNX functions]

Two levels beneath the linear layer, I find:

[image: the function body two levels down — MatMul, Mul, Add, and CastLike nodes]

It seems like the Gemm is somehow manifested as a subgraph of MatMuls, Muls, Adds, and CastLikes. Digging deeper, I find that this definition comes from:

@torch_op("aten::addmm")
def aten_addmm(
    self: TReal, mat1: TReal, mat2: TReal, beta: float = 1.0, alpha: float = 1.0
) -> TReal:
    """addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"""
    mat1_mat2 = op.MatMul(mat1, mat2)
    scaled_mat1_mat2 = op.Mul(mat1_mat2, alpha)
    scaled_self = op.Mul(self, beta)
    return op.Add(scaled_self, scaled_mat1_mat2)

It seems wasteful that an op as simple as a Gemm needs to be represented as this subgraph. Looking at this document, this seems to be a design choice.

We favor general ops like MatMul than specialized ops like Gemm in the function lib.

But imagine a model with thousands of Gemms. Each Gemm is now this subgraph, which means this optimization/fusion needs to run thousands of times to achieve something that could probably be achieved very easily at the source.

It would benefit ONNX Runtime (inference and training) and the larger ONNX community if this subgraph were represented as a Gemm node after export.
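
For reference, a single Gemm computes exactly what the decomposed subgraph computes: alpha * (mat1 @ mat2) + beta * self. A minimal check with the ONNX reference evaluator (the shapes and alpha/beta values below are arbitrary, chosen just for illustration):

import numpy as np
from onnx import TensorProto, helper
from onnx.reference import ReferenceEvaluator

# One-node model: Y = Gemm(A, B, C) = alpha * A @ B + beta * C
node = helper.make_node("Gemm", ["A", "B", "C"], ["Y"], alpha=2.0, beta=0.5)
graph = helper.make_graph(
    [node],
    "gemm_check",
    [
        helper.make_tensor_value_info("A", TensorProto.FLOAT, [3, 5]),
        helper.make_tensor_value_info("B", TensorProto.FLOAT, [5, 4]),
        helper.make_tensor_value_info("C", TensorProto.FLOAT, [3, 4]),
    ],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [3, 4])],
)
model = helper.make_model(graph)

rng = np.random.default_rng(0)
a = rng.standard_normal((3, 5)).astype(np.float32)
b = rng.standard_normal((5, 4)).astype(np.float32)
c = rng.standard_normal((3, 4)).astype(np.float32)

(y,) = ReferenceEvaluator(model).run(None, {"A": a, "B": b, "C": c})
decomposed = 0.5 * c + 2.0 * (a @ b)  # the MatMul/Mul/Mul/Add subgraph, in numpy
assert np.allclose(y, decomposed, atol=1e-5)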

@baijumeswani
Author

cc: @BowenBao @justinchuby

@justinchuby
Collaborator

justinchuby commented Oct 12, 2023

Thanks for raising this issue! When we created the decomposition, I realized Gemm is a special case of the op addmm (https://github.com/pytorch/pytorch/blob/a6b452dfdcb484d5dfdbb577b74cecbd7021df2e/torch/onnx/symbolic_opset9.py#L645-L652). In the design of torchlib, we wanted the ONNX functions to mirror the aten ops' behavior as closely as possible, so that we preserve the richest information for downstream optimization (doc). To be able to use Gemm for addmm, we need to know the type and rank of the inputs, which are not assumed to be available at export time.

This kind of fusion should actually be simple for downstream optimization passes by design. We can look at the aten_addmm function and its input types and ranks when available, then make the substitution. We do need the type and rank information for this, though, which is not available in nested functions, as @BowenBao pointed out in onnx/onnx#5487
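
As a rough illustration (this is not an existing onnxscript or ORT pass), such a substitution could be a pattern match over the inlined graph, roughly along these lines:

import onnx

def fuse_addmm_to_gemm(graph: onnx.GraphProto) -> None:
    """Sketch: find the Add(Mul(self, beta), Mul(MatMul(mat1, mat2), alpha)) pattern."""
    # Map each tensor name to the node that produces it.
    producers = {out: node for node in graph.node for out in node.output}
    for add in graph.node:
        if add.op_type != "Add" or len(add.input) != 2:
            continue
        scaled_self = producers.get(add.input[0])
        scaled_matmul = producers.get(add.input[1])
        if scaled_self is None or scaled_matmul is None:
            continue
        if scaled_self.op_type != "Mul" or scaled_matmul.op_type != "Mul":
            continue
        matmul = producers.get(scaled_matmul.input[0])
        if matmul is None or matmul.op_type != "MatMul":
            continue
        # At this point a real pass would confirm (via value_info / shape
        # inference) that the MatMul inputs are rank-2 and of a Gemm-supported
        # type, read the alpha/beta constants feeding the two Muls, and splice
        # in a single Gemm(mat1, mat2, self, alpha=..., beta=...) in place of
        # the four nodes. A robust version would also handle commuted inputs.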

@justinchuby
Collaborator

justinchuby commented Oct 12, 2023

However, for this special case, we may be able to create an overload for supported types that conditionally chooses Gemm based on rank. Optimization passes will still need to fold the If branches for this.

Edit:

I tried (1)

@torch_op("aten::addmm")
def aten_addmm(
    self: TReal, mat1: TReal, mat2: TReal, beta: float = 1.0, alpha: float = 1.0
) -> TReal:
    """addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"""

    use_gemm = op.And(op.Equal(op.Size(op.Shape(mat1)), 2), op.Equal(op.Size(op.Shape(mat2)), 2))
    if use_gemm:
        result = op.Gemm(mat1, mat2, self, alpha=alpha, beta=beta)
    else:
        mat1_mat2 = op.MatMul(mat1, mat2)
        scaled_mat1_mat2 = op.Mul(mat1_mat2, alpha)
        scaled_self = op.Mul(self, beta)
        result = op.Add(scaled_self, scaled_mat1_mat2)
    return result

But apparently ORT only implements Gemm for float32 and not for other types, so this needs to become (2):

@torch_op("aten::addmm")
def aten_addmm_gemm(
    self: FLOAT, mat1: FLOAT, mat2: FLOAT, beta: float = 1.0, alpha: float = 1.0
) -> FLOAT:
    """addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"""

    use_gemm = op.And(op.Equal(op.Size(op.Shape(mat1)), 2), op.Equal(op.Size(op.Shape(mat2)), 2))
    if use_gemm:
        result = op.Gemm(mat1, mat2, self, alpha=alpha, beta=beta)
    else:
        mat1_mat2 = op.MatMul(mat1, mat2)
        scaled_mat1_mat2 = op.Mul(mat1_mat2, alpha)
        scaled_self = op.Mul(self, beta)
        result = op.Add(scaled_self, scaled_mat1_mat2)
    return result

@torch_op("aten::addmm")
def aten_addmm(
    self: TNotFloat32, mat1: TNotFloat32, mat2: TNotFloat32, beta: float = 1.0, alpha: float = 1.0
) -> TNotFloat32:
    """addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"""

    mat1_mat2 = op.MatMul(mat1, mat2)
    scaled_mat1_mat2 = op.Mul(mat1_mat2, alpha)
    scaled_self = op.Mul(self, beta)
    result = op.Add(scaled_self, scaled_mat1_mat2)
    return result

But since Gemm is defined on {tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)}, it makes less sense for the exporter to make this specialization for FLOAT inputs only.

Let me know what you think or if I am missing anything. Thanks!

@justinchuby
Collaborator

justinchuby commented Oct 12, 2023

So far it looks like the best path forward is for ORT to implement Gemm on the spec'ed types and use (1). This way we strike a balance between correctness, complexity, and the effort needed for fusion.

@justinchuby
Collaborator

justinchuby commented Oct 12, 2023

Although: If-folding is not necessarily easier, and the model may run even slower with If branches when unoptimized. The assumption is that we don't want to specialize the function at conversion time, so we can't just use Gemm.

@baijumeswani
Author

baijumeswani commented Oct 12, 2023

None of the solutions offered here is very helpful, since they all require a subgraph computation/optimization (to be either folded away or fused into Gemm).

Ideally, the information about the rank/type of the input matrices, as well as the values of alpha and beta, is known at export time. This makes me feel that this should be dealt with at the source rather than pushed downstream to another optimization pass at a later time.

This becomes particularly important for scenarios where the export is an inline operation (such as in ORTModule) and the export time, along with the other optimization times, results in a performance penalty for the scenario.

@baijumeswani
Author

cc @pranavsharma for awareness, as I think this would impact inference as well.

@pranavsharma

pranavsharma commented Oct 12, 2023

Thanks @baijumeswani for adding me.

Exporter team: please try to fix this at export time, as ORT is not the only consumer of ONNX graphs. There is a whole ecosystem around ONNX, and such changes will break all of them.

ORT has not implemented Gemm for certain types because there was no production use case, and adding unnecessary types increases the binary size. Hence, it doesn't make sense for ORT to implement ops for all types. For the most frequently used types, can we emit Gemm? That way we're not penalizing the majority of use cases.

@justinchuby
Collaborator

Thanks for this perspective! Happy to explore options here. One thing that comes to mind is that, as we build out AOT optimization capabilities for ONNX graphs, these types of patterns can be optimized away (by the exporter) before the runtime sees the graph. This way, tools in the ecosystem can choose to operate on graphs with different levels of generality based on the assumptions they are built against.

@justinchuby justinchuby added module: torchlib Related to the torch/aten function lib in development topic: discussion For discussion labels Oct 13, 2023
@xadupre
Member

xadupre commented Oct 18, 2023

Some thoughts related to these issues.

Models can be very big nowadays; anything we don't handle at export time must be taken care of at optimization time. That is fine for small models, but is it still fine on bigger models with thousands of operators? Looking for patterns in such graphs adds significant time. Maybe we should start tracking converter performance (conversion time, optimization time with onnxruntime).
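
A minimal sketch of what such tracking could look like (the export and optimization calls here are placeholders, not an existing benchmark harness):

import time

import onnxruntime as ort
import torch

def time_export_and_optimize(model, example_input, path="model.onnx"):
    # Time the export itself.
    t0 = time.perf_counter()
    torch.onnx.dynamo_export(model, example_input).save(path)
    export_seconds = time.perf_counter() - t0

    # Time ORT's graph optimizations, which run when the session is created.
    t1 = time.perf_counter()
    options = ort.SessionOptions()
    options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    ort.InferenceSession(path, options)
    optimize_seconds = time.perf_counter() - t1
    return export_seconds, optimize_seconds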

One particular case with onnx-script; it is rare, but it can happen:

if beta == 0:
    B = op.MatMul(X, np.array(...))
else:
    B = op.MatMul(X, np.array(...))

onnx-script will convert this into 3 operators (If + 2x MatMul) and 2 initializers. Then an optimization pass will fold the constants and keep one operator and one initializer. But what if both initializers are very big? We would add unnecessary tensors to the model, making it unnecessarily big.

Another one, again, it is rare but it is possible:

B = op.MatMul(A, op.CastLike(np.array([...], dtype=np.float32), B))

The ONNX model will always keep float32 tensors, but if the model is float16, this could be reduced by half and the exported model could be smaller.

@justinchuby
Collaborator

justinchuby commented Oct 18, 2023

Thanks!

is it still fine on bigger models with thousands of operators?

Potentially. Since we have functions already, there should be a clear boundary for us to match things?

Maybe we should start tracking converter performance (conversion time, optimization time with onnxruntime)

A similar thing is tracked at https://github.com/microsoft/onnx-converters-private/issues/166#issuecomment-1764864419 (dort, compilation time). From profiling, we have seen that the main delay is torch dynamo at the moment.

onnx-script will convert this into 3 operators (If + 2x MatMul) and 2 initializers

I think a concrete usage will help the discussion here. Since aten operators take all large tensors as inputs, I don't see how we would duplicate large tensors in functions (they are more likely scalars).

The ONNX model will always keep float32 tensors, but if the model is float16, this could be reduced by half and the exported model could be smaller.

This, I think, presents a similar issue, where all CastLike'd constants tend to be single-element tensors (scalars) that don't take up space. The exported initializers will be float16 if the model is dealing with float16 inputs.

@justinchuby justinchuby mentioned this issue Oct 20, 2023
@justinchuby justinchuby self-assigned this Oct 20, 2023
justinchuby added a commit that referenced this issue Oct 25, 2023
Decompose addmm with Gemm by creating a special variant for `FLOAT` and
conditionally checking the ranks of the input tensors. The If branch is
expected to be folded away by constant-folding passes.

I have not found other instances where Gemm is used in the torch.onnx
exporter.

Fixes #1089
@github-project-automation github-project-automation bot moved this to Done in My List Oct 25, 2023
@gramalingam
Collaborator

A few comments:

  • Seems better for the exporter to emit Gemm for all legal types when possible (without worrying about whether ORT has Gemm kernels for those types ... that is more of an ORT issue). That is, option (1) seems preferable to option (2) in Justin's note.
  • It makes sense for the exporter to run some set of standard optimizations as post-processing to hide the complexity discussed above from downstream users to the extent possible.
  • For example, types can be assumed to be known, so CastLike can be eliminated (constant-folded away); ranks are typically known, so the if-conditions checking ranks can be constant-folded away as well; etc.

@justinchuby
Collaborator

justinchuby commented Oct 25, 2023

Thanks @gramalingam. I can change to (1) in the implementation if there are no objections. Fortunately, the rest of the complexity for this op is no longer a concern, because we realized Gemm can handle the op. But these points do help when we encounter new instances like this.

justinchuby added a commit that referenced this issue Oct 25, 2023
When I looked at the test coverage for `addmm` (below), I realized mat1
and mat2 are always 2d tensors. So the rank check is redundant. `addmm`
is now fully mapped to `Gemm`, which should completely resolve
#1089

Closes #1110


![image](https://github.com/microsoft/onnxscript/assets/11205048/073347ca-d677-4c87-94fa-e40a13642569)
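
Based on that description, the fully mapped function would presumably reduce to something like the following sketch (an illustration consistent with the commit message, not necessarily the exact code that landed):

@torch_op("aten::addmm")
def aten_addmm(
    self: TReal, mat1: TReal, mat2: TReal, beta: float = 1.0, alpha: float = 1.0
) -> TReal:
    """addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"""
    # addmm's mat1/mat2 are always rank-2, so Gemm applies directly.
    return op.Gemm(mat1, mat2, self, alpha=alpha, beta=beta)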