Description
Hi all,
Here is a bit of a weird one that has me stumped!
Problem you have encountered:
I've been trying to convert ResNet models trained in PyTorch to Flax and I've run into a strange issue. The PyTorch and Flax models tend to make the same class predictions but the class probabilities are slightly different. I've realised that each layer is producing slightly different outputs and that the errors slowly add up. In the case of a ResNet18, the difference in the final prediction is minimal but for a ResNet50 it can mean that the incorrect class is predicted!
This discrepancy also occurs for a single dense layer (see my minimal reproducible example below).
What you expected to happen:
I would expect that the same model architecture with the same weights and the same input would produce exactly the same output.
Steps to reproduce:
Here is a short Python script that demonstrates the problem:
from jax import numpy as jnp
from flax import linen
from flax.core import freeze
import torch
import numpy as np
torch.manual_seed(0)
torch_dense = torch.nn.Linear(2048, 10)
jax_dense = linen.Dense(10)
x = jnp.ones((2048,), dtype=np.float32)
jax_params = freeze({"params": {k if k == "bias" else "kernel": v.numpy().T for k, v in torch_dense.state_dict().items()}})
y_jax = jax_dense.apply(jax_params, x)
y_pytorch = torch_dense(torch.Tensor(np.ones((1, 2048), dtype=np.float32))).detach().numpy()[0]
print(f"2-norm of jax pred:\t {np.sqrt(np.mean((y_jax)**2)):.10f}")
print(f"2-norm of pytorch pred:\t {np.sqrt(np.mean((y_pytorch)**2)):.10f}")
print(f"2-norm of diff in preds: {np.sqrt(np.mean((y_jax-y_pytorch)**2)):.10f}")
print("Comparing the weight matrix...")
state = torch_dense.state_dict()["weight"]
numpy_state = state.numpy()
print(f"torch vs np:\t{torch.sqrt(torch.mean((state - numpy_state)**2)):.30f}")
print(f"np vs jax.np:\t{jnp.sqrt(jnp.mean((numpy_state - jnp.array(numpy_state))**2)):.30f}")
which outputs:
2-norm of jax pred: 0.5016138554
2-norm of pytorch pred: 0.5016137958
2-norm of diff in preds: 0.0000001995
Comparing the weight matrix...
torch vs np: 0.000000000000000000000000000000
np vs jax.np: 0.000000000000000000000000000000
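For scale, a difference of about 2e-7 on outputs of magnitude 0.5 is only a few units of float32 precision, so one possible explanation is that the two libraries add up the 2048 products in a different order. The sketch below is a NumPy-only illustration of that effect, not a diagnosis of what PyTorch or JAX actually do internally. The weights are random stand-ins (an assumption, not the real `torch_dense` weights), and `accumulate` is a hypothetical helper written for this example.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in weights, bias and input (assumed values, NOT the real torch weights).
w = rng.uniform(-0.02, 0.02, size=(2048, 10)).astype(np.float32)
b = rng.uniform(-0.02, 0.02, size=(10,)).astype(np.float32)
x = np.ones((2048,), dtype=np.float32)


def accumulate(rows, bias):
    """Add the rows to the bias one at a time, staying in float32."""
    acc = bias.copy()
    for row in rows:
        acc += row
    return acc


# Each row holds the products x[i] * w[i, :] for one input feature.
products = w * x[:, None]

# Same float32 numbers, summed in two different orders.
y_forward = accumulate(products, b)
y_reverse = accumulate(products[::-1], b)

# Higher-precision reference computed in float64.
y_ref = w.astype(np.float64).T @ x.astype(np.float64) + b.astype(np.float64)

eps = np.finfo(np.float32).eps
rms = lambda v: float(np.sqrt(np.mean(np.square(v))))

print(f"float32 machine epsilon:      {float(eps):.3e}")
print(f"RMS(forward - reverse):       {rms(y_forward - y_reverse):.3e}")
print(f"RMS(forward - float64 ref):   {rms(y_forward - y_ref):.3e}")
print(f"RMS(reverse - float64 ref):   {rms(y_reverse - y_ref):.3e}")
```

If the discrepancies printed here are of a similar size to the one reported above, then comparing the two frameworks with a tolerance (for example `np.allclose`) rather than exact equality may be the more meaningful check.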
My Python/Jax/Flax/etc. versions are:
- Python: 3.7.9 (default, Aug 31 2020, 12:42:55) [GCC 7.3.0] :: Anaconda, Inc. on linux
- jax: 0.2.9
- jaxlib: 0.1.59 (+cuda101)
- flax: 0.3.0
- torch: 1.6.0
- numpy: 1.17.2
EDIT: Here is a colab notebook for good measure :)
Strangely, when run on Colab the difference in predictions is slightly different:
2-norm of jax pred: 0.5016137958
2-norm of pytorch pred: 0.5016138554
2-norm of diff in preds: 0.0000004255
Comparing the weight matrix...
torch vs np: 0.000000000000000000000000000000
np vs jax.np: 0.000000000000000000000000000000
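One way to put the two printed norms in perspective is to measure their gap in float32 ULPs (units in the last place). The sketch below does this with NumPy, assuming the norms were computed in float32 as in the script above. It only looks at the two summary values copied from the output, not at the per-element predictions.

```python
import numpy as np

# The two norms copied from the printed output, rounded back to float32.
a = np.float32(0.5016137958)
b = np.float32(0.5016138554)

# Distance from `a` to the next representable float32 value.
ulp = np.spacing(a)

print(f"float32 spacing near {float(a):.10f}: {float(ulp):.3e}")
print(f"gap between the two norms in ULPs: {float((b - a) / ulp):.1f}")
print(f"adjacent float32 values: {bool(np.nextafter(a, np.float32(1.0)) == b)}")
```

By this measure the two norms appear to be neighbouring float32 values, which would be consistent with a rounding-level difference rather than a difference in the weights.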