Hi, I followed the causal language modeling in Flax tutorial notebook provided here in Colab. At the end of training, I'd like to get a working PyTorch model from the JAX/Flax weights, so I did this:
from transformers import GPT2LMHeadModel

mdl_path = "w11wo/sundanese-gpt2-base"
# Load the trained Flax checkpoint into a PyTorch model and save it back out.
pt_model = GPT2LMHeadModel.from_pretrained(mdl_path, from_flax=True)
pt_model.save_pretrained(mdl_path)
But during the conversion, it raised this error:
/usr/local/lib/python3.7/dist-packages/transformers/modeling_flax_pytorch_utils.py:201: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:180.)
pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-44-08a8ff6e575c> in <module>()
7 tokenizer.save_pretrained(mdl_path)
8
----> 9 pt_model = GPT2LMHeadModel.from_pretrained(mdl_path, from_flax=True)
10 pt_model.save_pretrained(mdl_path)
11
2 frames
/usr/local/lib/python3.7/dist-packages/transformers/modeling_flax_pytorch_utils.py in load_flax_weights_in_pytorch_model(pt_model, flax_state)
199 # add weight to pytorch dict
200 flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor
--> 201 pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)
202 # remove from missing keys
203 missing_keys.remove(flax_key)
TypeError: can't convert np.ndarray of type bfloat16. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.
I think this issue occurred because the model I instantiated used bfloat16, just as the tutorial showed. Specifically, this block:
import jax.numpy as jnp
from transformers import FlaxAutoModelForCausalLM

model = FlaxAutoModelForCausalLM.from_config(config, seed=training_seed, dtype=jnp.dtype("bfloat16"))
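I assume that instantiating in float32 from the start would have sidestepped the conversion error (untested on my end), roughly like this, with config and training_seed as defined in the tutorial:

import jax.numpy as jnp
from transformers import FlaxAutoModelForCausalLM

# Hypothetical alternative: keeping the parameters in float32 means the later
# Flax -> PyTorch conversion never has to deal with bfloat16 arrays.
model = FlaxAutoModelForCausalLM.from_config(
    config, seed=training_seed, dtype=jnp.dtype("float32")
)

But I've already finished training in bfloat16, so retraining just for this isn't ideal.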
I'd like to know if there's a workaround to this problem. Thanks!
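For reference, the rough workaround I've been considering (untested) is to reload the Flax checkpoint, upcast any bfloat16 parameters to float32, save it again, and only then do the PyTorch conversion. A sketch:

import jax
import jax.numpy as jnp
from transformers import FlaxGPT2LMHeadModel, GPT2LMHeadModel

mdl_path = "w11wo/sundanese-gpt2-base"

# Reload the Flax checkpoint and cast every bfloat16 parameter to float32,
# since torch.from_numpy cannot handle bfloat16 NumPy arrays.
flax_model = FlaxGPT2LMHeadModel.from_pretrained(mdl_path)
flax_model.params = jax.tree_util.tree_map(
    lambda p: p.astype(jnp.float32) if p.dtype == jnp.bfloat16 else p,
    flax_model.params,
)
flax_model.save_pretrained(mdl_path)

# With the checkpoint in float32, the PyTorch conversion should go through.
pt_model = GPT2LMHeadModel.from_pretrained(mdl_path, from_flax=True)
pt_model.save_pretrained(mdl_path)

If there's a more idiomatic way to do this in transformers, I'd be happy to use that instead.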