Linear from PyTorch must map to Gemm in ONNX #1089

@baijumeswani

Description

PyTorch Model:

```python
class NeuralNet(torch.nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()

        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, input1):
        out = self.fc1(input1)
        out = self.relu(out)
        out = self.fc2(out)
        return out
```
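
For reference, a minimal sketch of how the two exports might be invoked (the sizes, dummy input, and file names here are hypothetical; torch.onnx.dynamo_export assumes PyTorch >= 2.1):

```python
import torch

model = NeuralNet(input_size=784, hidden_size=500, num_classes=10)
dummy_input = torch.randn(1, 784)  # hypothetical input shape

# TorchScript-based exporter (the long-standing torch.onnx.export path).
torch.onnx.export(model, dummy_input, "model_torchscript.onnx")

# Dynamo-based exporter (torch.onnx.dynamo_export, PyTorch >= 2.1).
onnx_program = torch.onnx.dynamo_export(model, dummy_input)
onnx_program.save("model_dynamo.onnx")
```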

Exporting with the TorchScript-based exporter yields:

[image: exported ONNX graph: Gemm → Relu → Gemm]

which makes sense: it is, after all, a Linear layer followed by a ReLU followed by another Linear layer.

Exporting the same model with the dynamo-based exporter yields:

[image: exported ONNX graph from the dynamo exporter]

Two levels beneath the linear layer, I find:

[image: aten_addmm expanded into MatMul, Mul, Add, and CastLike nodes]

It seems the Gemm is manifested as a subgraph of MatMul, Mul, Add, and CastLike nodes. Digging deeper, I find that this definition comes from:

```python
@torch_op("aten::addmm")
def aten_addmm(
    self: TReal, mat1: TReal, mat2: TReal, beta: float = 1.0, alpha: float = 1.0
) -> TReal:
    """addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"""

    mat1_mat2 = op.MatMul(mat1, mat2)
    scaled_mat1_mat2 = op.Mul(mat1_mat2, alpha)
    scaled_self = op.Mul(self, beta)
    return op.Add(scaled_self, scaled_mat1_mat2)
```
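
For comparison, here is a minimal sketch of what a Gemm-backed variant could look like, assuming the same @torch_op/TReal machinery (hypothetical code; it ignores the rank and type-promotion constraints that presumably motivated the general MatMul form):

```python
@torch_op("aten::addmm")
def aten_addmm_gemm(
    self: TReal, mat1: TReal, mat2: TReal, beta: float = 1.0, alpha: float = 1.0
) -> TReal:
    """Hypothetical variant mapping addmm onto a single ONNX Gemm node."""
    # Gemm computes alpha * (mat1 @ mat2) + beta * self in one node, but it
    # only accepts rank-2 mat1/mat2, unlike the MatMul-based form above.
    return op.Gemm(mat1, mat2, self, alpha=alpha, beta=beta)
```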

It seems wasteful that an op as simple as a Gemm needs to be represented as this subgraph. Judging from this document, it is a deliberate design choice:

> We favor general ops like MatMul over specialized ops like Gemm in the function lib.

But imagine a model with thousands of Gemms. Each Gemm is now this subgraph, so an optimization/fusion pass must run thousands of times to achieve something that could easily be done at the source.
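
To make that cost concrete, here is a deliberately simplified sketch of such a fusion pass. It only handles the case where alpha and beta have already been folded away, so a MatMul feeds an Add directly; a real pass over dynamo exporter output would also have to match the Mul and CastLike nodes, check ranks and fan-out, and so on:

```python
import onnx
from onnx import helper

def fuse_matmul_add_to_gemm(graph: onnx.GraphProto) -> int:
    """Rewrite each MatMul whose sole consumer is an Add into one Gemm node.

    Naive sketch: assumes rank-2 inputs, alpha = beta = 1, and that the
    MatMul output is not also a graph output. Returns the fusion count.
    """
    consumers = {}  # tensor name -> nodes that consume it
    for node in graph.node:
        for name in node.input:
            consumers.setdefault(name, []).append(node)

    replacements = {}  # Add output name -> replacement Gemm node
    dropped = set()    # MatMul output names folded into a Gemm
    for node in graph.node:
        if node.op_type != "MatMul":
            continue
        users = consumers.get(node.output[0], [])
        if (
            len(users) == 1
            and users[0].op_type == "Add"
            and users[0].output[0] not in replacements
        ):
            add = users[0]
            # Whichever Add input is not the MatMul result is the bias (C).
            bias = add.input[1] if add.input[0] == node.output[0] else add.input[0]
            replacements[add.output[0]] = helper.make_node(
                "Gemm", [node.input[0], node.input[1], bias], [add.output[0]]
            )
            dropped.add(node.output[0])

    # Rebuild the node list in place, preserving topological order: each Add
    # is swapped for its Gemm and each fused MatMul is removed.
    new_nodes = [
        replacements.get(n.output[0], n)
        for n in graph.node
        if n.output[0] not in dropped
    ]
    del graph.node[:]
    graph.node.extend(new_nodes)
    return len(dropped)
```

Running a pass like this once per call site in every downstream tool is exactly the duplicated work that emitting Gemm at export time would avoid.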

It would benefit ONNX Runtime (inference and training) and the larger ONNX community if this subgraph were represented as a Gemm node after export.
