
NaN Outputs in ONNX Runtime When Weights Are Initialized with Large Constants #20429

@daniyalaliev


Describe the issue

When every weight of the model is initialized to the same constant (any value >= 0.1), running the exported model with ONNX Runtime produces an output tensor filled with NaN values. The corresponding PyTorch model handles the same setup without any issues.

Conversely, with a smaller constant or with the standard random weight initialization, the model runs correctly and the ONNX Runtime output matches the PyTorch output.
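One possible mechanism, offered only as an unconfirmed hypothesis: when every weight is equal, all hidden features in a row become (nearly) identical, so the variance that LayerNorm divides by is essentially zero. A two-pass estimate, mean((x - mean)^2), can never be negative, but a one-pass estimate, mean(x^2) - mean(x)^2, suffers from cancellation in float32 and can round below zero; once it drops below -eps, the square root returns NaN. The rounding error grows with the magnitude of the activations, which would be consistent with NaNs appearing only above a certain constant. The NumPy sketch below illustrates both effects. `layer_norm_two_pass` and `layer_norm_one_pass` are hypothetical illustrations, not ONNX Runtime's kernels, and whether ONNX Runtime's CPU LayerNorm actually takes this path for this model has not been verified here.

```python
import numpy as np

EPS = 1e-5  # LayerNorm epsilon (assumed value for this sketch)


def layer_norm_two_pass(x, gamma, beta, eps=EPS):
    """Hypothetical LayerNorm over the last axis using a two-pass variance (never negative)."""
    x = np.asarray(x, dtype=np.float32)
    mean = x.mean(axis=-1, keepdims=True)
    var = ((x - mean) ** 2).mean(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * gamma + beta, var


def layer_norm_one_pass(x, gamma, beta, eps=EPS):
    """Hypothetical LayerNorm using the one-pass variance E[x^2] - E[x]^2 (can round below zero)."""
    x = np.asarray(x, dtype=np.float32)
    mean = x.mean(axis=-1, keepdims=True)
    var = (x * x).mean(axis=-1, keepdims=True) - mean * mean
    with np.errstate(invalid="ignore"):  # sqrt of a negative number yields NaN
        out = (x - mean) / np.sqrt(var + eps) * gamma + beta
    return out, var


gamma = beta = 0.1  # mirrors torch.nn.init.constant_(param, 0.1) in the repro

# 1) Exactly identical features: the variance is 0, so the output collapses to beta.
identical = np.full(768, 0.5, dtype=np.float32)
out_identical, var_identical = layer_norm_two_pass(identical, gamma, beta)
print("identical features, variance:", var_identical, "output:", out_identical[:3])

# 2) Two nearly identical float32 values: the one-pass variance rounds below zero.
near = np.array([4096.0, 4096.0 + 3 / 2048], dtype=np.float32)
out_two, var_two = layer_norm_two_pass(near, gamma, beta)
out_one, var_one = layer_norm_one_pass(near, gamma, beta)
print("two-pass variance:", var_two, "finite output:", np.isfinite(out_two).all())
print("one-pass variance:", var_one, "NaN output:", np.isnan(out_one).all())
```

In this sketch the two-pass variant stays finite on the same input where the one-pass variant turns every element into NaN, which is why the choice of variance formula matters when the features are almost constant.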

To reproduce

```python
import onnx
import onnxruntime
import torch
import transformers


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = transformers.AutoModel.from_pretrained('roberta-base')

    def forward(self, input_ids, attention_mask):
        embeddings = self.encoder(input_ids=input_ids,
                                  attention_mask=attention_mask).last_hidden_state
        return embeddings


torch_model = Model()
input_ids = torch.ones((1, 512), dtype=torch.int64)
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)

torch_model.eval()
# Note: this forward pass runs with the pretrained weights, i.e. before the
# constant initialization below, so `outputs` is not computed from the same
# weights that are exported to ONNX.
outputs = torch_model(input_ids, attention_mask)

# Overwrite every parameter with the same constant.
for param in torch_model.parameters():
    torch.nn.init.constant_(param, 0.1)

torch.onnx.export(
    torch_model,
    (input_ids, attention_mask),
    f='./bert_segm1.onnx',
    input_names=['input_ids', 'attention_mask'],
    output_names=['outputs'],
    do_constant_folding=True,
    export_params=True,
    opset_version=12,
    dynamic_axes={
        'input_ids':      {0: 'batch_size'},
        'attention_mask': {0: 'batch_size'},
        'outputs':        {0: 'batch_size'},
    },
)

print(outputs)  # PyTorch output is OK

# tensor([[[-0.0671,  0.0620, -0.0265,  ..., -0.1117, -0.0480, -0.0200],
#          [-0.0671,  0.0620, -0.0265,  ..., -0.1117, -0.0480, -0.0200],
#          [-0.0671,  0.0620, -0.0265,  ..., -0.1117, -0.0480, -0.0200],
#          ...,
#          [-0.0671,  0.0620, -0.0265,  ..., -0.1117, -0.0480, -0.0200],
#          [-0.0671,  0.0620, -0.0265,  ..., -0.1117, -0.0480, -0.0200],
#          [-0.0671,  0.0620, -0.0265,  ..., -0.1117, -0.0480, -0.0200]]],
#        grad_fn=<NativeLayerNormBackward0>)

model = onnx.load('./bert_segm1.onnx')
onnx.checker.check_model(model, full_check=True)


def to_numpy(tensor):
    # detach() is a no-op for tensors that do not require grad
    return tensor.detach().cpu().numpy()


ort_session = onnxruntime.InferenceSession('./bert_segm1.onnx', providers=['CPUExecutionProvider'])
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(input_ids),
              ort_session.get_inputs()[1].name: to_numpy(attention_mask)}

onnxruntime_outputs = ort_session.run(None, ort_inputs)

print(onnxruntime_outputs)  # prints a tensor of NaNs

# [array([[[nan, nan, nan, ..., nan, nan, nan],
#         [nan, nan, nan, ..., nan, nan, nan],
#         [nan, nan, nan, ..., nan, nan, nan],
#         ...,
#         [nan, nan, nan, ..., nan, nan, nan],
#         [nan, nan, nan, ..., nan, nan, nan],
#         [nan, nan, nan, ..., nan, nan, nan]]], dtype=float32)]
```
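In the reproduction, the printed PyTorch `outputs` are computed before the `constant_` loop, so they come from the pretrained weights rather than the weights that were exported. A like-for-like check needs a PyTorch forward pass after the constant initialization. A small helper along these lines (the name `summarize_mismatch` is hypothetical, not part of PyTorch or ONNX Runtime) reports how many NaNs each side produced and the largest absolute difference over the entries where both outputs are finite:

```python
import numpy as np


def summarize_mismatch(reference, candidate):
    """Hypothetical helper: compare a reference output (e.g. PyTorch) with a
    candidate output (e.g. ONNX Runtime) and summarize NaNs and differences."""
    reference = np.asarray(reference, dtype=np.float64)
    candidate = np.asarray(candidate, dtype=np.float64)
    if reference.shape != candidate.shape:
        raise ValueError(f"shape mismatch: {reference.shape} vs {candidate.shape}")

    # Only compare positions where both outputs hold finite numbers.
    both_finite = np.isfinite(reference) & np.isfinite(candidate)
    diff = np.abs(reference[both_finite] - candidate[both_finite])

    return {
        "reference_nan": int(np.isnan(reference).sum()),
        "candidate_nan": int(np.isnan(candidate).sum()),
        "compared": int(both_finite.sum()),
        "max_abs_diff": float(diff.max()) if diff.size else float("nan"),
    }


# Synthetic example: one NaN on the candidate side, one small numeric difference.
summary = summarize_mismatch(np.array([[0.1, 0.2, 0.3]]),
                             np.array([[0.1, np.nan, 0.35]]))
print(summary)
```

Run after the `constant_` loop, a call such as `summarize_mismatch(to_numpy(torch_model(input_ids, attention_mask)), onnxruntime_outputs[0])` would show whether the PyTorch model with the same constant weights also produces NaNs, or only ONNX Runtime does.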

Urgency

No response

Platform

Windows

OS Version

10

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.17.3

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

No response
