Closed
Description
When I use the following script to convert a trained Flax model to PyTorch, the converted model performs extremely poorly.
from transformers import RobertaForMaskedLM
model = RobertaForMaskedLM.from_pretrained("./", from_flax=True)
model.save_pretrained("./")
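The following script loads the same checkpoint in Flax and in PyTorch and prints the logits of both models for an identical input:
from transformers import RobertaForMaskedLM, FlaxRobertaForMaskedLM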
import numpy as np
import torch
model_fx = FlaxRobertaForMaskedLM.from_pretrained("birgermoell/roberta-swedish")
model_pt = RobertaForMaskedLM.from_pretrained("birgermoell/roberta-swedish", from_flax=True)
input_ids = np.asarray(2 * [128 * [0]], dtype=np.int32)
input_ids_pt = torch.tensor(input_ids)
logits_pt = model_pt(input_ids_pt).logits
print(logits_pt)
logits_fx = model_fx(input_ids).logits
print(logits_fx)
Comparing the two gives the following output (PyTorch logits first, then Flax logits).
tensor([[[ 1.7789, -13.5291, -11.2138, ..., -5.2875, -9.3274, -4.7912],
[ 2.3076, -13.4161, -11.1511, ..., -5.3181, -9.0602, -4.6083],
[ 2.6451, -13.4425, -11.0671, ..., -5.2838, -8.8323, -4.2280],
...,
[ 1.9009, -13.6516, -11.2348, ..., -4.9726, -9.3278, -4.6060],
[ 2.0522, -13.5394, -11.2804, ..., -4.9960, -9.1956, -4.5691],
[ 2.2570, -13.5093, -11.2640, ..., -4.9986, -9.1292, -4.3310]],
[[ 1.7789, -13.5291, -11.2138, ..., -5.2875, -9.3274, -4.7912],
[ 2.3076, -13.4161, -11.1511, ..., -5.3181, -9.0602, -4.6083],
[ 2.6451, -13.4425, -11.0671, ..., -5.2838, -8.8323, -4.2280],
...,
[ 1.9009, -13.6516, -11.2348, ..., -4.9726, -9.3278, -4.6060],
[ 2.0522, -13.5394, -11.2804, ..., -4.9960, -9.1956, -4.5691],
[ 2.2570, -13.5093, -11.2640, ..., -4.9986, -9.1292, -4.3310]]],
grad_fn=<AddBackward0>)
[[[ 0.1418128 -14.170926 -11.12649 ... -7.542998 -10.79537
-9.382975 ]
[ 1.7505689 -13.178099 -10.356588 ... -6.794136 -10.567211
-8.6670065 ]
[ 2.0270724 -13.522658 -10.372475 ... -7.0110755 -10.396935
-8.419178 ]
...
[ 0.19080782 -14.390833 -11.399942 ... -7.469897 -10.715849
-9.234054 ]
[ 1.3052869 -13.332332 -10.702984 ... -6.9498534 -10.813769
-8.608736 ]
[ 1.6442876 -13.226774 -10.59941 ... -7.0290956 -10.693554
-8.457008 ]]
[[ 0.1418128 -14.170926 -11.12649 ... -7.542998 -10.79537
-9.382975 ]
[ 1.7505689 -13.178099 -10.356588 ... -6.794136 -10.567211
-8.6670065 ]
[ 2.0270724 -13.522658 -10.372475 ... -7.0110755 -10.396935
-8.419178 ]
...
[ 0.19080782 -14.390833 -11.399942 ... -7.469897 -10.715849
-9.234054 ]
[ 1.3052869 -13.332332 -10.702984 ... -6.9498534 -10.813769
-8.608736 ]
[ 1.6442876 -13.226774 -10.59941 ... -7.0290956 -10.693554
-8.457008 ]]]
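To put a number on the mismatch instead of comparing the printouts by eye, a small helper along these lines could be used. This is only a sketch: `compare_logits` and its `atol` threshold are hypothetical names chosen here, not part of `transformers`, and the sample values are simply the first three columns of the first row of each tensor printed above.

```python
import numpy as np


def compare_logits(logits_a, logits_b, atol=1e-3):
    """Return (max absolute difference, True if all entries agree within atol).

    Hypothetical helper for illustration; the tolerance is an assumption.
    """
    a = np.asarray(logits_a, dtype=np.float32)
    b = np.asarray(logits_b, dtype=np.float32)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    max_diff = float(np.max(np.abs(a - b)))
    # rtol=0.0 so that only the absolute tolerance decides the result
    return max_diff, bool(np.allclose(a, b, atol=atol, rtol=0.0))


# First three columns of the first row of each printed tensor above.
pt_row = [1.7789, -13.5291, -11.2138]
fx_row = [0.1418128, -14.170926, -11.12649]

max_diff, close = compare_logits(pt_row, fx_row)
print(f"max abs diff: {max_diff:.4f}, within tolerance: {close}")
```

With the variables from the script above, the call would presumably be `compare_logits(logits_pt.detach().numpy(), logits_fx)`. Even on the three sample values the difference is well above 1, far more than floating-point noise would explain.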