
[Flax] Error converting model to PyTorch from Flax #12545

Closed
@w11wo

Description


Hi, I followed the causal language modeling in Flax tutorial notebook provided here, running it in Colab. At the end of training, I'd like to get a working PyTorch model from the JAX/Flax weights, so I did this:

from transformers import GPT2LMHeadModel

mdl_path = "w11wo/sundanese-gpt2-base"

# Load the Flax checkpoint into a PyTorch model, then save the PyTorch weights.
pt_model = GPT2LMHeadModel.from_pretrained(mdl_path, from_flax=True)
pt_model.save_pretrained(mdl_path)

But the conversion raised this error:

/usr/local/lib/python3.7/dist-packages/transformers/modeling_flax_pytorch_utils.py:201: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  /pytorch/torch/csrc/utils/tensor_numpy.cpp:180.)
  pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-44-08a8ff6e575c> in <module>()
      7 tokenizer.save_pretrained(mdl_path)
      8 
----> 9 pt_model = GPT2LMHeadModel.from_pretrained(mdl_path, from_flax=True)
     10 pt_model.save_pretrained(mdl_path)
     11 

2 frames
/usr/local/lib/python3.7/dist-packages/transformers/modeling_flax_pytorch_utils.py in load_flax_weights_in_pytorch_model(pt_model, flax_state)
    199                 # add weight to pytorch dict
    200                 flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor
--> 201                 pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)
    202                 # remove from missing keys
    203                 missing_keys.remove(flax_key)

TypeError: can't convert np.ndarray of type bfloat16. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.

I think this issue occurred because the model I instantiated uses bfloat16, just as the tutorial shows. Specifically, this block:

import jax.numpy as jnp

from transformers import FlaxAutoModelForCausalLM

model = FlaxAutoModelForCausalLM.from_config(config, seed=training_seed, dtype=jnp.dtype("bfloat16"))
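The failure also reproduces in isolation, which suggests it's a hard limitation of torch.from_numpy rather than anything GPT-2 specific. A minimal check (the array contents here are just for illustration):

import numpy as np
import torch
import jax.numpy as jnp

# np.asarray preserves JAX's bfloat16 extension dtype rather than upcasting it.
arr = np.asarray(jnp.ones(3, dtype=jnp.bfloat16))
print(arr.dtype)  # bfloat16

# torch.from_numpy only accepts native NumPy dtypes, so this raises the same
# TypeError: can't convert np.ndarray of type bfloat16.
torch.from_numpy(arr)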

I'd like to know if there's a recommended workaround for this problem. Thanks!
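In the meantime, the closest thing to a workaround I can think of is to cast the Flax parameters up to float32 before saving, so the conversion never sees bfloat16 arrays. A minimal sketch, assuming the checkpoint reloads cleanly; the jax.tree_util.tree_map cast is my own idea, not something from the tutorial:

import jax
import jax.numpy as jnp
from transformers import FlaxGPT2LMHeadModel, GPT2LMHeadModel

mdl_path = "w11wo/sundanese-gpt2-base"

# Reload the Flax checkpoint; its parameter leaves are still bfloat16.
flax_model = FlaxGPT2LMHeadModel.from_pretrained(mdl_path)

# Cast every parameter leaf to float32 so torch.from_numpy will accept it.
flax_model.params = jax.tree_util.tree_map(
    lambda p: p.astype(jnp.float32), flax_model.params
)
flax_model.save_pretrained(mdl_path)

# The conversion should now go through.
pt_model = GPT2LMHeadModel.from_pretrained(mdl_path, from_flax=True)
pt_model.save_pretrained(mdl_path)

The cast itself should be lossless, since every bfloat16 value is exactly representable in float32.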
