5 changes: 4 additions & 1 deletion src/transformers/modeling_gguf_pytorch_utils.py
@@ -369,6 +369,7 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False, model_to_lo
     architecture = read_field(reader, "general.architecture")[0]
     model_name = read_field(reader, "general.name")
 
+    updated_architecture = None
     # in llama.cpp mistral models use the same architecture as llama. We need
     # to add this patch to ensure things work correctly on our side.
     if "llama" in architecture and "mistral" in model_name:
@@ -377,6 +378,8 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False, model_to_lo
     # It needs to be developed for supporting legacy t5.
     elif "t5" in architecture or "t5encoder" in architecture:
         parsed_parameters["config"]["is_gated_act"] = True
+        if "t5encoder" in architecture:
+            parsed_parameters["config"]["architectures"] = ["T5EncoderModel"]
         updated_architecture = "t5"
     else:
         updated_architecture = architecture
@@ -395,7 +398,7 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False, model_to_lo
         parsed_parameters["config"]["use_qkv_bias"] = qkv_bias
         parsed_parameters["config"]["use_parallel_residual"] = not use_parallel_residual
 
-    if architecture not in GGUF_SUPPORTED_ARCHITECTURES:
+    if architecture not in GGUF_SUPPORTED_ARCHITECTURES and updated_architecture not in GGUF_SUPPORTED_ARCHITECTURES:
         raise ValueError(f"GGUF model with architecture {architecture} is not supported yet.")
 
     # Handle tie_word_embeddings, if lm_head.weight is not present in tensors,
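Taken together, these hunks let a GGUF checkpoint whose general.architecture is "t5encoder" fall back to the supported "t5" mapping while recording T5EncoderModel as the target class. A minimal usage sketch; the repository and file names below are placeholders for illustration, not part of this PR:

from transformers import T5EncoderModel

# Placeholder Hub repo and GGUF file name; any t5encoder GGUF export should behave the same.
# The loader reads general.architecture == "t5encoder", remaps it to "t5" via
# updated_architecture, and sets the config's architectures to ["T5EncoderModel"].
model = T5EncoderModel.from_pretrained(
    "some-org/t5-encoder-gguf",
    gguf_file="t5-encoder.Q8_0.gguf",
)
print(model.config.architectures)  # expected: ["T5EncoderModel"]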
3 changes: 3 additions & 0 deletions src/transformers/modeling_utils.py
@@ -4224,10 +4224,13 @@ def from_pretrained(
                 token=token,
                 revision=revision,
                 subfolder=subfolder,
+                gguf_file=gguf_file,
                 _from_auto=from_auto_class,
                 _from_pipeline=from_pipeline,
                 **kwargs,
             )
+            if "gguf_file" in model_kwargs:
+                model_kwargs.pop("gguf_file")
         else:
             # In case one passes a config to `from_pretrained` + "attn_implementation"
             # override the `_attn_implementation` attribute to `attn_implementation` of the kwargs
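The modeling_utils.py change threads gguf_file through the config-loading call; because that call hands back any kwargs it did not consume as model_kwargs, gguf_file has to be dropped before those kwargs reach the model class. A simplified, self-contained sketch of that round-trip, not the real from_pretrained code:

# Simplified stand-in for cls.config_class.from_pretrained(...): it consumes
# what it understands and returns the leftover kwargs to the caller.
def fake_config_from_pretrained(name_or_path, **kwargs):
    config = {"name_or_path": name_or_path, "gguf_file": kwargs.get("gguf_file")}
    return config, dict(kwargs)  # leftover kwargs may still echo gguf_file back

config, model_kwargs = fake_config_from_pretrained("some-model", gguf_file="model.Q4_0.gguf")

# Without this pop, gguf_file would be forwarded again with the remaining
# model kwargs and reach code paths that do not expect it.
if "gguf_file" in model_kwargs:
    model_kwargs.pop("gguf_file")

print(model_kwargs)  # {} once gguf_file has been removed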