Description

Exporting a model that constructs a bfloat16 constant inside forward produces an ONNX graph whose lifted constant initializer is typed UINT16 instead of BFLOAT16, so ONNX Runtime rejects the resulting Mul node with a type mismatch. Repro:
import torch


class BfloatModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.tensor(2.0, dtype=torch.bfloat16))

    def forward(self, x):
        return x * torch.tensor(1.0, dtype=torch.bfloat16) * self.param


input = torch.randn(1, 10, dtype=torch.bfloat16)
onnx_program = torch.onnx.export(
    BfloatModel(),
    (input,),
    dynamo=True,
    optimize=True,
)
print(onnx_program)
print(onnx_program(input))
yields:
ONNXProgram(
    model=
    <
        ir_version=10,
        opset_imports={'': 18},
        producer_name='pytorch',
        producer_version='2.8.0a0+gitcc185c3',
        domain=None,
        model_version=None,
    >
    graph(
        name=main_graph,
        inputs=(
            %"x"<BFLOAT16,[1,10]>
        ),
        outputs=(
            %"mul_1"<BFLOAT16,[1,10]>
        ),
        initializers=(
            %"param"<BFLOAT16,[]>,
            %"clone"<UINT16,[]>
        ),
    ) {
        0 |  # node_Mul_1
             %"mul"<BFLOAT16,[1,10]> ⬅️ ::Mul(%"x", %"clone")
        1 |  # node_Mul_2
             %"mul_1"<BFLOAT16,[1,10]> ⬅️ ::Mul(%"mul", %"param")
        return %"mul_1"<BFLOAT16,[1,10]>
    }
    ,
    exported_program=
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, p_param: "bf16[]", c_lifted_tensor_0: "bf16[]", x: "bf16[1, 10]"):
                # File: /home/justinchu/dev/pytorch/test_bfloat.py:10 in forward, code: return x * torch.tensor(1.0, dtype=torch.bfloat16) * self.param
                clone: "bf16[]" = torch.ops.aten.clone.default(c_lifted_tensor_0);  c_lifted_tensor_0 = None
                mul: "bf16[1, 10]" = torch.ops.aten.mul.Tensor(x, clone);  x = clone = None
                mul_1: "bf16[1, 10]" = torch.ops.aten.mul.Tensor(mul, p_param);  mul = p_param = None
                return (mul_1,)

        Graph signature:
            # inputs
            p_param: PARAMETER target='param'
            c_lifted_tensor_0: CONSTANT_TENSOR target='lifted_tensor_0'
            x: USER_INPUT

            # outputs
            mul_1: USER_OUTPUT

        Range constraints: {}
)
Traceback (most recent call last):
  File "/home/justinchu/dev/pytorch/test_bfloat.py", line 22, in <module>
    print(onnx_program(input))
          ^^^^^^^^^^^^^^^^^^^
  File "/home/justinchu/dev/pytorch/torch/onnx/_internal/exporter/_onnx_program.py", line 148, in __call__
    self.initialize_inference_session()
  File "/home/justinchu/dev/pytorch/torch/onnx/_internal/exporter/_onnx_program.py", line 324, in initialize_inference_session
    self._inference_session = initializer(model)
                              ^^^^^^^^^^^^^^^^^^
  File "/home/justinchu/dev/pytorch/torch/onnx/_internal/exporter/_onnx_program.py", line 51, in _ort_session_initializer
    return ort.InferenceSession(
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/justinchu/anaconda3/envs/pytorch/lib/python3.12/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 465, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/home/justinchu/anaconda3/envs/pytorch/lib/python3.12/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 528, in _create_inference_session
    sess = C.InferenceSession(session_options, self._model_bytes, False, self._read_config_from_model)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Type Error: Type parameter (T) of Optype (Mul) bound to different types (tensor(bfloat16) and tensor(uint16) in node (node_Mul_1).
note %"clone"<UINT16,[]>