
Domain of Gelu in bias_gelu_rules set #2362

Closed
@KarelZe

Description


Thanks for developing onnxscript 💯

I'm currently comparing the fusion options integrated in onnxruntime and onnxscript. In doing so, I came across the fusion of Bias and Gelu (BiasGeluFusion in onnxscript and FusionBiasGelu in onnxruntime).

I noticed that the rule only matches Gelu from the contributor ops domain (com.microsoft), so it does not fuse the Gelu op from the standard ONNX domain (introduced in opset 20). Onnxruntime, on the other hand, would fuse such operations into BiasGelu.
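One detail that matters when matching the ONNX-domain op: ONNX Gelu (opset 20) has an approximate attribute ("none" or "tanh"), while the contributor Gelu/BiasGelu are, as far as I understand, the exact erf-based form. A minimal pure-Python sketch of both definitions (the function names are illustrative, not part of either library) shows that the two variants are close but not interchangeable:

```python
import math


def gelu_exact(x: float) -> float:
    """ONNX Gelu with approximate="none": 0.5 * x * (1 + erf(x / sqrt(2)))."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x: float) -> float:
    """ONNX Gelu with approximate="tanh"."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)
    return 0.5 * x * (1.0 + math.tanh(inner))


def bias_gelu_reference(values: list[float], bias: list[float]) -> list[float]:
    """Illustrative semantics of the fused op: exact Gelu applied to values + bias."""
    if len(values) != len(bias):
        raise ValueError("bias must match the last dimension")
    return [gelu_exact(v + b) for v, b in zip(values, bias)]


for x in (-1.0, 0.5, 1.0, 2.0):
    # The variants differ slightly, so only approximate="none" maps to the exact form.
    print(f"x={x:+.1f} exact={gelu_exact(x):.6f} tanh={gelu_tanh(x):.6f}")
```

So a rule matching the ONNX-domain Gelu would presumably need to check approximate before rewriting.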

Is there a reason why the implementation is constrained to match contributor ops? Would you be open to also matching alternative patterns that use Gelu from the ONNX domain?
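The matching logic in question can be sketched independently of onnxscript. The toy graph representation and the names Node, is_fusable_gelu, find_bias_gelu_candidates and mwe_graph below are hypothetical and are not the onnxscript rewriter API. The sketch assumes the contributor Gelu and BiasGelu are erf-based, so an ONNX-domain Gelu is accepted only with approximate="none".

```python
from dataclasses import dataclass, field

ONNX_DOMAINS = {"", "ai.onnx"}
MS_DOMAIN = "com.microsoft"


@dataclass
class Node:
    """Hypothetical, simplified node; not the onnxscript IR."""
    op_type: str
    domain: str
    inputs: list[str]
    outputs: list[str]
    attrs: dict[str, str] = field(default_factory=dict)


def is_fusable_gelu(node: Node) -> bool:
    if node.op_type != "Gelu":
        return False
    if node.domain == MS_DOMAIN:
        return True  # assumed: the contrib Gelu is the exact (erf) variant
    if node.domain in ONNX_DOMAINS:
        # ONNX Gelu defaults to approximate="none"; skip the tanh variant.
        return node.attrs.get("approximate", "none") == "none"
    return False


def find_bias_gelu_candidates(nodes: list[Node], graph_outputs: set[str]) -> list[tuple[Node, Node]]:
    """Return (Add, Gelu) pairs that could be rewritten into a single BiasGelu."""
    producers = {out: n for n in nodes for out in n.outputs}
    use_counts: dict[str, int] = {}
    for n in nodes:
        for name in n.inputs:
            use_counts[name] = use_counts.get(name, 0) + 1

    matches = []
    for gelu in nodes:
        if not is_fusable_gelu(gelu):
            continue
        add = producers.get(gelu.inputs[0])
        if add is None or add.op_type != "Add" or add.domain not in ONNX_DOMAINS:
            continue
        added = add.outputs[0]
        # Fusing removes the Add result, so nothing else may still need it.
        if use_counts.get(added, 0) != 1 or added in graph_outputs:
            continue
        matches.append((add, gelu))
    return matches


def mwe_graph(gelu_domain: str, approximate: str = "none") -> list[Node]:
    """Mirror of the MWE below: MatMul -> Add(bias) -> Gelu."""
    attrs = {"approximate": approximate} if gelu_domain in ONNX_DOMAINS else {}
    return [
        Node("MatMul", "", ["Input", "Weight"], ["T1"]),
        Node("Add", "", ["T1", "Bias"], ["T2"]),
        Node("Gelu", gelu_domain, ["T2"], ["Output"], attrs),
    ]


for domain, approx in [("com.microsoft", "none"), ("", "none"), ("", "tanh")]:
    found = find_bias_gelu_candidates(mwe_graph(domain, approx), {"Output"})
    print(repr(domain), approx, "->", len(found), "match(es)")
```

The use-count and graph-output checks guard against fusing away an Add result that another consumer still depends on.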

@shubhambhokare1 @justinchuby

Minimal working example (MWE):

```python
"""Testing bias + Gelu fusion."""

import numpy as np
import onnxscript
from onnxscript import opset20 as op
from onnxscript import script
from onnxscript.onnx_types import FLOAT
from onnxscript.rewriter.ort_fusions.bias_gelu import bias_gelu_rules

import onnxruntime as ort

# Opset handle for the contributor ops (com.microsoft domain), used by GemmGeluCustom
# to compare against the standard ONNX Gelu used in GemmGelu.
msft_op = onnxscript.values.Opset("com.microsoft", 1)


@script()
def GemmGelu(Input: FLOAT["N", "K"], Weight: FLOAT["K", "M"], Bias: FLOAT["M"]) -> FLOAT["N", "M"]:
    """
    Gemm + Gelu from onnx domain.
    """
    T1 = op.MatMul(Input, Weight)
    T2 = op.Add(T1, Bias)
    return op.Gelu(T2, approximate="none")


@script()
def GemmGeluCustom(Input: FLOAT["N", "K"], Weight: FLOAT["K", "M"], Bias: FLOAT["M"]) -> FLOAT["N", "M"]:
    """
    Gemm with Gelu from Microsoft domain.
    """
    T1 = op.MatMul(Input, Weight)
    T2 = op.Add(T1, Bias)
    return msft_op.Gelu(T2)


onnx_model = GemmGelu.to_model_proto()
onnx_model = onnxscript.rewriter.rewrite(onnx_model, pattern_rewrite_rules=bias_gelu_rules)

# The ONNX-domain Gelu is not fused into BiasGelu here.
print(onnx_model)

# test model
sess = ort.InferenceSession(onnx_model.SerializeToString(), providers=["CPUExecutionProvider"])
inputs = {
    "Input": np.random.randint(10, size=(3, 8)).astype(np.float32),
    "Weight": np.random.randint(10, size=(8, 4)).astype(np.float32),
    "Bias": np.random.randint(10, size=(4,)).astype(np.float32),
}
sess.run(None, inputs)
```
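To check the numbers returned by the session, a slow pure-Python reference of the MWE's computation (MatMul, then Add with the bias broadcast over the last axis, then exact Gelu) may help. matmul_bias_gelu is a hypothetical helper; its result can be compared with sess.run(None, inputs)[0], for example via np.allclose after passing in the arrays' .tolist() values.

```python
import math


def gelu_exact(x: float) -> float:
    # Exact Gelu (approximate="none"), as in the MWE's op.Gelu call.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def matmul_bias_gelu(
    inp: list[list[float]], weight: list[list[float]], bias: list[float]
) -> list[list[float]]:
    """Gelu(inp @ weight + bias) for inp of shape (N, K), weight (K, M), bias (M,)."""
    k, m = len(weight), len(bias)
    if any(len(row) != k for row in inp) or any(len(row) != m for row in weight):
        raise ValueError("shape mismatch between Input, Weight and Bias")
    result = []
    for row in inp:
        out_row = []
        for j in range(m):
            acc = sum(row[p] * weight[p][j] for p in range(k))  # MatMul
            out_row.append(gelu_exact(acc + bias[j]))  # Add (bias) + Gelu
        result.append(out_row)
    return result


print(matmul_bias_gelu([[1.0, 0.0], [0.0, 1.0]], [[1.0, 2.0], [3.0, 4.0]], [0.0, -2.0]))
```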
