Differences between PyTorch and Flax model predictions #963

@JamesAllingham

Description

Hi all,

Here is a bit of a weird one that has me stumped!

Problem you have encountered:

I've been trying to convert ResNet models trained in PyTorch to Flax and I've run into a strange issue. The PyTorch and Flax models tend to make the same class predictions but the class probabilities are slightly different. I've realised that each layer is producing slightly different outputs and that the errors slowly add up. In the case of a ResNet18, the difference in the final prediction is minimal but for a ResNet50 it can mean that the incorrect class is predicted!
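
To give a feel for what I mean by the errors adding up, here is a rough standalone sketch (arbitrary sizes and depth, not my actual ResNet code) that applies the same float32 matrix repeatedly in PyTorch and in JAX and tracks how far the two results drift apart:

# A rough standalone sketch (arbitrary sizes, not the actual ResNet code):
# apply the same float32 matrix repeatedly in PyTorch and in JAX and watch
# the RMS difference between the two results grow with depth.
import numpy as np
import torch
from jax import numpy as jnp

rng = np.random.RandomState(0)
w = (rng.randn(512, 512) / np.sqrt(512)).astype(np.float32)
x = rng.randn(512).astype(np.float32)

w_torch, x_torch = torch.from_numpy(w), torch.from_numpy(x)
w_jax, x_jax = jnp.array(w), jnp.array(x)

for depth in range(1, 11):
    x_torch = x_torch @ w_torch
    x_jax = x_jax @ w_jax
    diff = np.sqrt(np.mean((np.asarray(x_jax) - x_torch.numpy()) ** 2))
    print(f"after {depth:2d} matmuls, RMS diff: {diff:.3e}")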

This discrepancy also occurs for a single dense layer (see my minimal reproducible example below).

What you expected to happen:

I would expect that the same model architecture with the same weights and the same input would produce exactly the same output.
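
Just for a sense of scale (this is a toy example of mine, not part of the repro below): float32 arithmetic is not associative, so two mathematically identical reductions over 2048 numbers can already disagree slightly depending purely on the order in which the terms are accumulated:

# A toy example of mine (not part of the repro): sum the same 2048 float32
# numbers in two different orders and compare the results.
import numpy as np

rng = np.random.RandomState(0)
v = rng.randn(2048).astype(np.float32)

naive = np.float32(0.0)
for t in v:
    naive += t                 # plain left-to-right accumulation
pairwise = np.sum(v)           # NumPy's sum uses pairwise summation internally

print(f"left-to-right: {naive:.10f}")
print(f"np.sum:        {pairwise:.10f}")
print(f"difference:    {naive - pairwise:.2e}")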

Steps to reproduce:

Here is a short python script that demonstrates the problem:

from jax import numpy as jnp

from flax import linen
from flax.core import freeze

import torch

import numpy as np

torch.manual_seed(0)

torch_dense = torch.nn.Linear(2048, 10)
jax_dense = linen.Dense(10)

x = jnp.ones((2048,), dtype=np.float32)

# Convert the PyTorch state dict into Flax's parameter format. PyTorch stores the
# Linear weight as (out_features, in_features), while a Flax Dense kernel is
# (in_features, out_features), hence the transpose.
jax_params = freeze({"params": {k if k == "bias" else "kernel": v.numpy().T for k, v in torch_dense.state_dict().items()}})
y_jax = jax_dense.apply(jax_params, x)

y_pytorch = torch_dense(torch.Tensor(np.ones((1, 2048), dtype=np.float32))).detach().numpy()[0]

print(f"2-norm of jax pred:\t {np.sqrt(np.mean((y_jax)**2)):.10f}")
print(f"2-norm of pytorch pred:\t {np.sqrt(np.mean((y_pytorch)**2)):.10f}")
print(f"2-norm of diff in preds: {np.sqrt(np.mean((y_jax-y_pytorch)**2)):.10f}")

print("Comparing the weight matrix...")
state = torch_dense.state_dict()["weight"]
numpy_state = state.numpy()
print(f"torch vs np:\t{torch.sqrt(torch.mean((state - numpy_state)**2)):.30f}")
print(f"np vs jax.np:\t{jnp.sqrt(jnp.mean((numpy_state - jnp.array(numpy_state))**2)):.30f}")

which outputs:

2-norm of jax pred:      0.5016138554
2-norm of pytorch pred:  0.5016137958
2-norm of diff in preds: 0.0000001995
Comparing the weight matrix...
torch vs np:    0.000000000000000000000000000000
np vs jax.np:   0.000000000000000000000000000000
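
For what it's worth, that difference looks like it is in the same ballpark as float32 machine epsilon. Appending something like the following to the end of the script (my own addition, not part of the repro above) compares the two outputs under loose float32 tolerances:

# My own addition to the end of the script above (not in the original repro):
# compare the observed difference against float32 machine epsilon and check
# agreement under loose float32 tolerances.
print(f"float32 eps: {np.finfo(np.float32).eps:.2e}")   # ~1.19e-07
print("allclose:", np.allclose(np.asarray(y_jax), y_pytorch, rtol=1e-5, atol=1e-7))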

My Python/Jax/Flax/etc. versions are:

  • Python: 3.7.9 (default, Aug 31 2020, 12:42:55) [GCC 7.3.0] :: Anaconda, Inc. on linux
  • jax: 0.2.9
  • jaxlib: 0.1.59 (+cuda101)
  • flax: 0.3.0
  • torch: 1.6.0
  • numpy: 1.17.2

EDIT: Here is a colab notebook for good measure :)

Strangely, the difference between the predictions is slightly different there:

2-norm of jax pred:	 0.5016137958
2-norm of pytorch pred:	 0.5016138554
2-norm of diff in preds: 0.0000004255
Comparing the weight matrix...
torch vs np:	0.000000000000000000000000000000
np vs jax.np:	0.000000000000000000000000000000

Labels

Priority: P2 - no schedule (Best effort response and resolution. We have no plan to work on this at the moment.)
