
Issue converting Flax model to Pytorch #12554

Closed

Description

@BirgerMoell

When using the following script to convert a trained Flax model to PyTorch, the converted model performs extremely poorly.

# Convert the locally trained Flax checkpoint to PyTorch and save it
from transformers import RobertaForMaskedLM
model = RobertaForMaskedLM.from_pretrained("./", from_flax=True)
model.save_pretrained("./")

# Compare the original Flax model against the converted PyTorch model
from transformers import RobertaForMaskedLM, FlaxRobertaForMaskedLM
import numpy as np
import torch

model_fx = FlaxRobertaForMaskedLM.from_pretrained("birgermoell/roberta-swedish")
model_pt = RobertaForMaskedLM.from_pretrained("birgermoell/roberta-swedish", from_flax=True)

# Dummy batch: 2 sequences of length 128, all filled with token id 0
input_ids = np.asarray(2 * [128 * [0]], dtype=np.int32)
input_ids_pt = torch.tensor(input_ids)

logits_pt = model_pt(input_ids_pt).logits
print(logits_pt)
logits_fx = model_fx(input_ids).logits
print(logits_fx)

Comparing the two models gives the following output.

tensor([[[  1.7789, -13.5291, -11.2138,  ...,  -5.2875,  -9.3274,  -4.7912],
         [  2.3076, -13.4161, -11.1511,  ...,  -5.3181,  -9.0602,  -4.6083],
         [  2.6451, -13.4425, -11.0671,  ...,  -5.2838,  -8.8323,  -4.2280],
         ...,
         [  1.9009, -13.6516, -11.2348,  ...,  -4.9726,  -9.3278,  -4.6060],
         [  2.0522, -13.5394, -11.2804,  ...,  -4.9960,  -9.1956,  -4.5691],
         [  2.2570, -13.5093, -11.2640,  ...,  -4.9986,  -9.1292,  -4.3310]],

        [[  1.7789, -13.5291, -11.2138,  ...,  -5.2875,  -9.3274,  -4.7912],
         [  2.3076, -13.4161, -11.1511,  ...,  -5.3181,  -9.0602,  -4.6083],
         [  2.6451, -13.4425, -11.0671,  ...,  -5.2838,  -8.8323,  -4.2280],
         ...,
         [  1.9009, -13.6516, -11.2348,  ...,  -4.9726,  -9.3278,  -4.6060],
         [  2.0522, -13.5394, -11.2804,  ...,  -4.9960,  -9.1956,  -4.5691],
         [  2.2570, -13.5093, -11.2640,  ...,  -4.9986,  -9.1292,  -4.3310]]],
       grad_fn=<AddBackward0>)
[[[  0.1418128  -14.170926   -11.12649    ...  -7.542998   -10.79537
    -9.382975  ]
  [  1.7505689  -13.178099   -10.356588   ...  -6.794136   -10.567211
    -8.6670065 ]
  [  2.0270724  -13.522658   -10.372475   ...  -7.0110755  -10.396935
    -8.419178  ]
  ...
  [  0.19080782 -14.390833   -11.399942   ...  -7.469897   -10.715849
    -9.234054  ]
  [  1.3052869  -13.332332   -10.702984   ...  -6.9498534  -10.813769
    -8.608736  ]
  [  1.6442876  -13.226774   -10.59941    ...  -7.0290956  -10.693554
    -8.457008  ]]

 [[  0.1418128  -14.170926   -11.12649    ...  -7.542998   -10.79537
    -9.382975  ]
  [  1.7505689  -13.178099   -10.356588   ...  -6.794136   -10.567211
    -8.6670065 ]
  [  2.0270724  -13.522658   -10.372475   ...  -7.0110755  -10.396935
    -8.419178  ]
  ...
  [  0.19080782 -14.390833   -11.399942   ...  -7.469897   -10.715849
    -9.234054  ]
  [  1.3052869  -13.332332   -10.702984   ...  -6.9498534  -10.813769
    -8.608736  ]
  [  1.6442876  -13.226774   -10.59941    ...  -7.0290956  -10.693554
    -8.457008  ]]]
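
A quick way to make the mismatch concrete is to compare the two logit arrays numerically instead of by eye. Below is a minimal sketch, assuming the same model_pt, model_fx and input_ids as in the script above and a guessed tolerance; for a faithful conversion the maximum absolute difference should be close to zero, while here the first values already differ by more than 1.

# Minimal sketch: quantify the divergence between the PyTorch and Flax logits.
# model_pt, model_fx and input_ids are taken from the script above.
import numpy as np
import torch

with torch.no_grad():
    logits_pt = model_pt(torch.tensor(input_ids)).logits.numpy()
logits_fx = np.asarray(model_fx(input_ids).logits)

print("max absolute difference:", np.max(np.abs(logits_pt - logits_fx)))

# Guessed tolerance; a correct conversion would be expected to pass a check like this.
np.testing.assert_allclose(logits_pt, logits_fx, atol=1e-3)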
