
LayerNorm node has data type mismatches for input and scale #2099

Closed
@yuanyao-nv

Description


When exporting the Flux transformer model, the LayerNormalization op in the exported graph has mismatched dtypes for its input and scale. See the onnxruntime error:

onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from /root/transformer.onnx failed:Type Error: Type parameter (T) of Optype (LayerNormalization) bound to different types (tensor(float16) and tensor(float) in node (n0_16).

The proto of the problematic LayerNorm node is below. Note that LayerNormalization binds its X and Scale inputs to the same type parameter T, so a float16 input paired with a float32 scale makes the model fail to load:

input: "linear_7"
input: "val_344"
input: "val_347"
output: "layer_norm_1"
name: "n0_16"
op_type: "LayerNormalization"
attribute {
  name: "axis"
  i: -1
  type: INT
}
attribute {
  name: "epsilon"
  f: 1e-06
  type: FLOAT
}
metadata_props {
  key: "namespace"
  value: ": diffusers.models.transformers.transformer_flux.FluxTransformer2DModel/transformer_blocks.0: diffusers.models.transformers.transformer_flux.FluxTransformerBlock/transformer_blocks.0.norm1_context: diffusers.models.normalization.AdaLayerNormZero/transformer_blocks.0.norm1_context.norm: torch.nn.modules.normalization.LayerNorm/layer_norm_1: aten.layer_norm.default"
}
metadata_props {
  key: "pkg.torch.onnx.class_hierarchy"
  value: "[\'diffusers.models.transformers.transformer_flux.FluxTransformer2DModel\', \'diffusers.models.transformers.transformer_flux.FluxTransformerBlock\', \'diffusers.models.normalization.AdaLayerNormZero\', \'torch.nn.modules.normalization.LayerNorm\', \'aten.layer_norm.default\']"
}
metadata_props {
  key: "pkg.torch.onnx.fx_node"
  value: "%layer_norm_1 : [num_users=1] = call_function[target=torch.ops.aten.layer_norm.default](args = (%linear_7, [3072], None, None, 1e-06), kwargs = {})"
}
metadata_props {
  key: "pkg.torch.onnx.name_scopes"
  value: "[\'\', \'transformer_blocks.0\', \'transformer_blocks.0.norm1_context\', \'transformer_blocks.0.norm1_context.norm\', \'layer_norm_1\']"
}
metadata_props {
  key: "pkg.torch.onnx.stack_trace"
  value: "File \"/usr/local/lib/python3.12/dist-packages/diffusers/models/transformers/transformer_flux.py\", line 508, in forward\n    encoder_hidden_states, hidden_states = block(\n  File \"/usr/local/lib/python3.12/dist-packages/diffusers/models/transformers/transformer_flux.py\", line 175, in forward\n    norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(\n  File \"/usr/local/lib/python3.12/dist-packages/diffusers/models/normalization.py\", line 170, in forward\n    x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]\n  File \"/usr/local/lib/python3.12/dist-packages/torch/nn/modules/normalization.py\", line 217, in forward\n    return F.layer_norm("
}
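
The fx_node metadata above shows the layer_norm call with weight=None and bias=None, so val_344 and val_347 are presumably scale/bias constants the exporter synthesized (ones and zeros), emitted as float32 while linear_7 is float16. A minimal repro sketch of that pattern (untested; the normalized shape 3072 is taken from the fx_node):

import torch
import torch.nn.functional as F

class Repro(torch.nn.Module):
    def forward(self, x):
        # Same call as the fx_node above: weight=None, bias=None, eps=1e-6
        return F.layer_norm(x, [3072], None, None, 1e-6)

x = torch.zeros(1, 512, 3072, dtype=torch.float16)
torch.onnx.export(Repro(), (x,), "ln_repro.onnx", dynamo=True)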

Here's the export script:

import torch
from diffusers.models import FluxTransformer2DModel

model_name = "black-forest-labs/FLUX.1-dev"
hf_safetensor = True
model_opts = {'torch_dtype': torch.float16}
model = FluxTransformer2DModel.from_pretrained(
    model_name, subfolder="transformer", use_safetensors=hf_safetensor,
    force_download=True, **model_opts,
).to("cuda")

B, latent_dim = 1, 4096
# Dummy inputs, in the order of input_names below: hidden_states,
# encoder_hidden_states, pooled_projections, timestep, img_ids, txt_ids, guidance
inputs = (
    torch.zeros(B, latent_dim, 64, dtype=torch.float16, device="cuda"),
    torch.zeros(B, 512, 4096, dtype=torch.float16, device="cuda"),
    torch.zeros(B, 768, dtype=torch.float16, device="cuda"),
    torch.zeros(B, dtype=torch.float16, device="cuda"),
    torch.zeros(latent_dim, 3, dtype=torch.float32, device="cuda"),
    torch.zeros(512, 3, dtype=torch.float32, device="cuda"),
    torch.zeros(B, dtype=torch.float32, device="cuda"),
)

# Dynamo exporter: dynamic axes, one entry per input above
dynamic_shapes = (
    {0:"B", 1:"latent_dim"},
    {0:"B"},
    {0:"B"},
    {0:"B"},
    {0:"latent_dim"},
    None,
    {0:"B"},
)
torch.onnx.export(
    model,
    inputs,
    "transformer.onnx",
    export_params=True,
    do_constant_folding=True,
    input_names=["hidden_states", "encoder_hidden_states", "pooled_projections", "timestep", "img_ids", "txt_ids", "guidance"],
    output_names=["latent"],
    dynamic_shapes=dynamic_shapes,
    verbose=False,
    dynamo=True,
)
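
Loading the exported model is what triggers the failure, e.g.:

import onnxruntime

# Raises the Type Error quoted above
onnxruntime.InferenceSession("transformer.onnx", providers=["CPUExecutionProvider"])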

The model exported with the TorchScript exporter (dynamo=False) doesn't have this problem.
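
A possible post-export workaround (just a sketch, not verified on this model): downcast any float32 scale/bias initializers feeding LayerNormalization nodes to float16 so that T binds consistently. If val_344/val_347 are produced by Constant nodes rather than initializers, those would need the same treatment.

import numpy as np
import onnx
from onnx import numpy_helper

m = onnx.load("transformer.onnx")
inits = {t.name: t for t in m.graph.initializer}
for node in m.graph.node:
    if node.op_type != "LayerNormalization":
        continue
    for name in node.input[1:]:  # inputs 1 and 2 are Scale and B
        t = inits.get(name)
        if t is not None and t.data_type == onnx.TensorProto.FLOAT:
            arr = numpy_helper.to_array(t).astype(np.float16)
            t.CopyFrom(numpy_helper.from_array(arr, name))
# For a model this size, onnx.save may need save_as_external_data=True
onnx.save(m, "transformer_fixed.onnx")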

Labels

module: torchlib (Related to the torch/aten function lib in development)
