Issue converting Flax model to Pytorch #12554
Running the following commands should give more or less identical results:

```python
from transformers import RobertaForMaskedLM, FlaxRobertaForMaskedLM
import numpy as np
import torch

model_fx = FlaxRobertaForMaskedLM.from_pretrained("birgermoell/roberta-swedish")
model_pt = RobertaForMaskedLM.from_pretrained("birgermoell/roberta-swedish", from_flax=True)

input_ids = np.asarray(2 * [128 * [0]], dtype=np.int32)
input_ids_pt = torch.tensor(input_ids)

logits_pt = model_pt(input_ids_pt).logits
print(logits_pt)

logits_fx = model_fx(input_ids).logits
print(logits_fx)
```
Just corrected the PT weights. If you run:

```python
from transformers import RobertaForMaskedLM, FlaxRobertaForMaskedLM
import numpy as np
import torch

model_fx = FlaxRobertaForMaskedLM.from_pretrained("birgermoell/roberta-swedish")
model_pt = RobertaForMaskedLM.from_pretrained("birgermoell/roberta-swedish")

input_ids = np.asarray(2 * [128 * [0]], dtype=np.int32)
input_ids_pt = torch.tensor(input_ids)

logits_pt = model_pt(input_ids_pt).logits
print(logits_pt)

logits_fx = model_fx(input_ids).logits
print(logits_fx)
```

you should see equal results. The checkpoint was somehow incorrectly converted.
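Rather than comparing the printed tensors by eye, a small sketch building on the snippet above (reusing the `logits_pt` and `logits_fx` variables computed there) checks the agreement numerically:

```python
# Sanity check: the PyTorch and Flax logits should agree within numerical tolerance.
import numpy as np

diff = np.max(np.abs(logits_pt.detach().numpy() - np.asarray(logits_fx)))
print(f"max abs diff: {diff}")
assert np.allclose(logits_pt.detach().numpy(), np.asarray(logits_fx), atol=1e-3)
```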
Note that one should convert checkpoints with:

```python
from transformers import RobertaForMaskedLM

model = RobertaForMaskedLM.from_pretrained("...", from_flax=True)
model.save_pretrained("./")
```

and not the … Also, it's important to realize that the LM head layer is tied to the input word embedding layer, which is why Flax simply doesn't save those weights. When converting those weights to PyTorch, PyTorch reports them as missing, but since the weights are tied, PyTorch would have overwritten them with the input embeddings anyway, which is why the warning doesn't matter.
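To see that weight tying directly, here is a small sketch (the checkpoint name is just a placeholder; any RoBERTa MLM checkpoint loadable as `RobertaForMaskedLM` works):

```python
# The MLM decoder weight is tied to the input word embeddings, so Flax can skip
# saving it and the PyTorch conversion can safely recreate it from the embeddings.
import torch
from transformers import RobertaForMaskedLM

model = RobertaForMaskedLM.from_pretrained("birgermoell/roberta-swedish")  # placeholder checkpoint
emb = model.roberta.embeddings.word_embeddings.weight
dec = model.lm_head.decoder.weight
print(torch.equal(emb, dec))             # True: identical values
print(emb.data_ptr() == dec.data_ptr())  # True: the very same tensor in memory
```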
@BirgerMoell Also note that your local …
Awesome. Just to clarify: once I'm done with training, this script should help me convert the model to PyTorch?
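For reference, a minimal sketch of that conversion step after training finishes (the local path is a placeholder; it assumes the Flax run saved `flax_model.msgpack` and the config there):

```python
# Convert the finished Flax checkpoint to PyTorch and keep both formats side by side.
from transformers import RobertaForMaskedLM

pt_model = RobertaForMaskedLM.from_pretrained("./roberta-swedish-flax", from_flax=True)  # placeholder path
pt_model.save_pretrained("./roberta-swedish-flax")  # writes pytorch_model.bin next to flax_model.msgpack
```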
@patrickvonplaten The uploaded model is still performing poorly, so I'm not 100% sure the issue is fully resolved.
Hi @BirgerMoell, I'm also training a RoBERTa model with JAX during this community week -- model here. I got an evaluation loss of about 2.188, yet the outputs are still somewhat gibberish. I think our models are somehow being trained incorrectly? Or they possibly require more data cleaning of some sort.
@w11wo Yeah, something is definitely up. I think a good idea would be for people working with similar models to figure out a good way to clean the data and look at other things that might be wrong.
Facing the same issue here: I trained a model with Flax/JAX, then saved it. When loading it in PyTorch via `from_flax=True`, I get nonsense output even though training showed an OK loss. Did you manage to find a solution or understand the issue?
Hi @jppaolim! In my case, I had loaded earlier weights of the model (from the first few epochs) instead of the fully trained weights from the last training epoch. Loading the right model weights fixed it for me. Another way to fix it might be training for longer. Hope this helps! :)
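One way to make sure you convert the final weights rather than an early checkpoint is to pin exactly which snapshot you load; a hedged sketch (repo id and revision are placeholders):

```python
# Load a specific snapshot of the trained Flax model before converting it to PyTorch.
from transformers import FlaxRobertaForMaskedLM

model_fx = FlaxRobertaForMaskedLM.from_pretrained(
    "your-username/your-roberta",  # placeholder repo id
    revision="main",               # or the commit hash of the final training checkpoint
)
```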
When using the script above to convert a trained Flax model to PyTorch, the model seems to perform extremely poorly.
Comparing the two gives the following output.