optimum/habana/diffusers/pipelines/pipeline_utils.py (15 changes: 14 additions & 1 deletion)
@@ -381,10 +381,23 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
         # Import htcore here to support model quantization
         import habana_frameworks.torch.core as htcore  # noqa: F401

-        return super().from_pretrained(
+        # Normally we would just return super().from_pretrained(). However, this is a
+        # workaround for a Transformers 4.49.0 issue where the torch_dtype option is
+        # ignored for sub-models. Note: this is already fixed in the 4.50.0.dev working branch.
+        model = super().from_pretrained(
             pretrained_model_name_or_path,
             **kwargs,
         )
+        if bf16_full_eval:
+            # Get the names of the pipeline's public components
+            component_names = [name for name in model.__dict__ if not name.startswith("_")]
+            # Iterate through the components and cast them to bfloat16
+            for name in component_names:
+                component = getattr(model, name, None)
+                if component is not None and hasattr(component, "dtype"):
+                    component.to(torch.bfloat16)
+
+        return model

     @classmethod
     def save_lora_weights(
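For context, the cast above is only applied when the pipeline is loaded with torch_dtype=torch.bfloat16, which is what sets bf16_full_eval. A minimal usage sketch follows; it is not part of this diff, and the checkpoint name and the Gaudi-specific arguments (use_habana, use_hpu_graphs, gaudi_config) are illustrative values taken from typical optimum-habana examples:

import torch
from optimum.habana.diffusers import GaudiStableDiffusionPipeline

# Passing torch_dtype=torch.bfloat16 triggers the bf16_full_eval path above,
# so every public component that exposes a dtype is cast to bfloat16.
pipeline = GaudiStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # example checkpoint (assumption)
    torch_dtype=torch.bfloat16,
    use_habana=True,
    use_hpu_graphs=True,
    gaudi_config="Habana/stable-diffusion",
)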
@@ -260,6 +260,27 @@ def run_unet(

         return latents

+    # Normally we would not wrap from_pretrained. However, this is a
+    # workaround for a Transformers 4.49.0 issue where the torch_dtype option is
+    # ignored for sub-models. Note: this is already fixed in the 4.50.0.dev working branch.
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+        bf16_full_eval = kwargs.get("torch_dtype", None) == torch.bfloat16
+        model = super().from_pretrained(
+            pretrained_model_name_or_path,
+            **kwargs,
+        )
+        if bf16_full_eval:
+            # Get the names of the pipeline's public components
+            component_names = [name for name in model.__dict__ if not name.startswith("_")]
+            # Iterate through the components and cast them to bfloat16
+            for name in component_names:
+                component = getattr(model, name, None)
+                if component is not None and hasattr(component, "dtype"):
+                    component.to(torch.bfloat16)
+
+        return model
+
     @classmethod
     def _split_inputs_into_batches(
         cls,
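As a quick check that the workaround took effect, the dtypes of the loaded components can be inspected after from_pretrained returns. The snippet below is only a sketch mirroring the component discovery used in the patch (iterating over the pipeline's public attributes) and assumes a pipeline object named pipeline loaded with torch_dtype=torch.bfloat16:

# Components such as the UNet, VAE, and text encoder should now report torch.bfloat16;
# attributes without a dtype (e.g. the tokenizer or scheduler) are skipped.
for name, component in vars(pipeline).items():
    if not name.startswith("_") and component is not None and hasattr(component, "dtype"):
        print(f"{name}: {component.dtype}")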