Closed
Description
Thanks for developing onnxscript 💯
I'm currently comparing the fusion options integrated in onnxruntime and onnxscript. While doing so, I came across the fusion of bias and Gelu (`BiasGeluFusion` in onnxscript and `FusionBiasGelu` in onnxruntime).
I noticed that the domain for `Gelu` in the rules implementation is restricted to the contributor ops implementation, so the rule does not fuse `Gelu` from the standard onnx ops (introduced with opset 20). Onnxruntime, on the other hand, would fuse such operations into `BiasGelu`.
Is there a reason why the implementation is constrained to match only contributor ops? Would you be open to matching against alternative patterns that use Gelu from the onnx ops?
MWE:
"""Testing bias + Gelu fusion.
"""
import numpy as np
import onnxscript
from onnxscript import opset20 as op
from onnxscript import script
from onnxscript.onnx_types import FLOAT
from onnxscript.rewriter.ort_fusions.bias_gelu import bias_gelu_rules
import onnxruntime as ort
# Compare the standard ONNX Gelu with the contributor-op ("com.microsoft") Gelu.
# `msft_op` is an opset handle for domain "com.microsoft", version 1; it is used
# by GemmGeluCustom to emit Gelu from that domain instead of from opset 20.
msft_op = onnxscript.values.Opset("com.microsoft", 1)
@script()
def GemmGelu(
    Input: FLOAT["N", "K"],
    Weight: FLOAT["K", "M"],
    Bias: FLOAT["M"],
) -> FLOAT["N", "M"]:
    """MatMul + bias Add followed by Gelu from the standard ONNX domain.

    Args:
        Input: Activations of shape (N, K).
        Weight: Weight matrix of shape (K, M).
        Bias: Bias vector of shape (M), added to the MatMul result.

    Returns:
        Gelu applied to ``MatMul(Input, Weight) + Bias``, shape (N, M).
    """
    T1 = op.MatMul(Input, Weight)
    T2 = op.Add(T1, Bias)
    # `op` is opset 20, so this is the standard-domain Gelu; "none" selects
    # the exact (non-approximated) formulation.
    return op.Gelu(T2, approximate="none")
@script()
def GemmGeluCustom(
    Input: FLOAT["N", "K"],
    Weight: FLOAT["K", "M"],
    Bias: FLOAT["M"],
) -> FLOAT["N", "M"]:
    """MatMul + bias Add followed by Gelu from the "com.microsoft" domain.

    Args:
        Input: Activations of shape (N, K).
        Weight: Weight matrix of shape (K, M).
        Bias: Bias vector of shape (M), added to the MatMul result.

    Returns:
        Gelu applied to ``MatMul(Input, Weight) + Bias``, shape (N, M).
    """
    T1 = op.MatMul(Input, Weight)
    T2 = op.Add(T1, Bias)
    # Gelu is taken from the contributor-op opset handle (`msft_op`) rather
    # than from the standard ONNX opset.
    return msft_op.Gelu(T2)
# Export the standard-domain variant and apply the bias + Gelu rewrite rules.
onnx_model = onnxscript.rewriter.rewrite(
    GemmGelu.to_model_proto(),
    pattern_rewrite_rules=bias_gelu_rules,
)
# Print the rewritten model so the resulting nodes can be inspected.
print(onnx_model)

# test model: load the rewritten graph on the CPU provider and run it once.
sess = ort.InferenceSession(
    onnx_model.SerializeToString(),
    providers=["CPUExecutionProvider"],
)
# Shapes per graph input; values are random integers in [0, 10) cast to float32.
input_shapes = {"Input": (3, 8), "Weight": (8, 4), "Bias": (4)}
inputs = {
    name: np.random.randint(10, size=shape).astype(np.float32)
    for name, shape in input_shapes.items()
}
sess.run(None, inputs)