|
31 | 31 | from accelerate import Accelerator |
32 | 32 | from accelerate.logging import get_logger |
33 | 33 | from accelerate.utils import ProjectConfiguration, set_seed |
34 | | -from huggingface_hub import create_repo, model_info, upload_folder |
| 34 | +from huggingface_hub import create_repo, upload_folder |
35 | 35 | from packaging import version |
36 | 36 | from PIL import Image |
37 | 37 | from torch.utils.data import Dataset |
@@ -589,16 +589,6 @@ def __getitem__(self, index): |
589 | 589 | return example |
590 | 590 |
|
591 | 591 |
|
592 | | -def model_has_vae(args): |
593 | | - config_file_name = os.path.join("vae", AutoencoderKL.config_name) |
594 | | - if os.path.isdir(args.pretrained_model_name_or_path): |
595 | | - config_file_name = os.path.join(args.pretrained_model_name_or_path, config_file_name) |
596 | | - return os.path.isfile(config_file_name) |
597 | | - else: |
598 | | - files_in_repo = model_info(args.pretrained_model_name_or_path, revision=args.revision).siblings |
599 | | - return any(file.rfilename == config_file_name for file in files_in_repo) |
600 | | - |
601 | | - |
602 | 592 | def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None): |
603 | 593 | if tokenizer_max_length is not None: |
604 | 594 | max_length = tokenizer_max_length |
@@ -753,11 +743,13 @@ def main(args): |
753 | 743 | text_encoder = text_encoder_cls.from_pretrained( |
754 | 744 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision |
755 | 745 | ) |
756 | | - if model_has_vae(args): |
| 746 | + try: |
757 | 747 | vae = AutoencoderKL.from_pretrained( |
758 | 748 | args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision |
759 | 749 | ) |
760 | | - else: |
| 750 | + except OSError: |
| 751 | + # DeepFloyd IF models do not ship a VAE, so fall back to None |
| 752 | + # instead of treating the missing subfolder as an error |
761 | 753 | vae = None |
762 | 754 |
|
763 | 755 | unet = UNet2DConditionModel.from_pretrained( |
|
0 commit comments