Skip to content

Optimizer constant folding turns bfloat16 initializers into UINT16 #2187

Closed
@justinchuby

Description

@justinchuby
import torch


class BfloatModel(torch.nn.Module):
    """Tiny model that scales its input by two bfloat16 scalars.

    One scalar is the learnable parameter ``param`` (initialized to 2.0). The
    other is a bfloat16 constant (1.0) built inside ``forward``. When exported
    with ``torch.onnx.export(..., dynamo=True)``, that inline constant is lifted
    into the graph as a constant tensor. That constant is the value whose
    initializer dtype the optimizer gets wrong in this issue.
    """

    def __init__(self):
        super(BfloatModel, self).__init__()
        initial_value = torch.tensor(2.0, dtype=torch.bfloat16)
        self.param = torch.nn.Parameter(initial_value)

    def forward(self, x):
        """Return ``(x * 1.0) * param``; a bfloat16 ``x`` gives a bfloat16 result."""
        # Created per call on purpose: the exporter lifts it into a constant tensor.
        one = torch.tensor(1.0, dtype=torch.bfloat16)
        scaled = x * one
        return scaled * self.param

# Example bfloat16 input of shape (1, 10). It is used both to export the model
# and to run the exported program at the end of the script.
# NOTE(review): `input` shadows the Python builtin. It is kept as-is because the
# traceback quoted in this issue shows `print(onnx_program(input))` verbatim.
input = torch.randn(1, 10, dtype=torch.bfloat16)
# Export with the dynamo-based exporter and the ONNX optimizer enabled. Per this
# issue, the optimizer's constant folding turns the bfloat16 constant into a
# UINT16 initializer.
onnx_program = torch.onnx.export(
    BfloatModel(),
    (input,),
    dynamo=True,
    optimize=True
)

# Print the ONNX program: the ONNX graph (including initializer dtypes) and the
# source ExportedProgram.
print(onnx_program)

# Run the exported model. With the bug present, this fails while creating the
# ONNX Runtime inference session because Mul receives bfloat16 and uint16 inputs.
print(onnx_program(input))

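Running this script prints the exported program below and then fails when the program is executed with ONNX Runtime: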

ONNXProgram(
    model=
        <
            ir_version=10,
            opset_imports={'': 18},
            producer_name='pytorch',
            producer_version='2.8.0a0+gitcc185c3',
            domain=None,
            model_version=None,
        >
        graph(
            name=main_graph,
            inputs=(
                %"x"<BFLOAT16,[1,10]>
            ),
            outputs=(
                %"mul_1"<BFLOAT16,[1,10]>
            ),
            initializers=(
                %"param"<BFLOAT16,[]>,
                %"clone"<UINT16,[]>
            ),
        ) {
            0 |  # node_Mul_1
                 %"mul"<BFLOAT16,[1,10]> ⬅️ ::Mul(%"x", %"clone")
            1 |  # node_Mul_2
                 %"mul_1"<BFLOAT16,[1,10]> ⬅️ ::Mul(%"mul", %"param")
            return %"mul_1"<BFLOAT16,[1,10]>
        }


    ,
    exported_program=
        ExportedProgram:
            class GraphModule(torch.nn.Module):
                def forward(self, p_param: "bf16[]", c_lifted_tensor_0: "bf16[]", x: "bf16[1, 10]"):
                     # File: /home/justinchu/dev/pytorch/test_bfloat.py:10 in forward, code: return x * torch.tensor(1.0, dtype=torch.bfloat16) * self.param
                    clone: "bf16[]" = torch.ops.aten.clone.default(c_lifted_tensor_0);  c_lifted_tensor_0 = None
                    mul: "bf16[1, 10]" = torch.ops.aten.mul.Tensor(x, clone);  x = clone = None
                    mul_1: "bf16[1, 10]" = torch.ops.aten.mul.Tensor(mul, p_param);  mul = p_param = None
                    return (mul_1,)
            
        Graph signature: 
            # inputs
            p_param: PARAMETER target='param'
            c_lifted_tensor_0: CONSTANT_TENSOR target='lifted_tensor_0'
            x: USER_INPUT
    
            # outputs
            mul_1: USER_OUTPUT
    
        Range constraints: {}

)

Traceback (most recent call last):
  File "/home/justinchu/dev/pytorch/test_bfloat.py", line 22, in <module>
    print(onnx_program(input))
          ^^^^^^^^^^^^^^^^^^^
  File "/home/justinchu/dev/pytorch/torch/onnx/_internal/exporter/_onnx_program.py", line 148, in __call__
    self.initialize_inference_session()
  File "/home/justinchu/dev/pytorch/torch/onnx/_internal/exporter/_onnx_program.py", line 324, in initialize_inference_session
    self._inference_session = initializer(model)
                              ^^^^^^^^^^^^^^^^^^
  File "/home/justinchu/dev/pytorch/torch/onnx/_internal/exporter/_onnx_program.py", line 51, in _ort_session_initializer
    return ort.InferenceSession(
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/justinchu/anaconda3/envs/pytorch/lib/python3.12/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 465, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/home/justinchu/anaconda3/envs/pytorch/lib/python3.12/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 528, in _create_inference_session
    sess = C.InferenceSession(session_options, self._model_bytes, False, self._read_config_from_model)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Type Error: Type parameter (T) of Optype (Mul) bound to different types (tensor(bfloat16) and tensor(uint16) in node (node_Mul_1).

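Note `%"clone"<UINT16,[]>`: the bfloat16 constant was emitted as a UINT16 initializer. As a result, `node_Mul_1` receives mismatched input types (bfloat16 and uint16), which ONNX Runtime rejects.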

Metadata

Metadata

Assignees

Labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions