@@ -436,7 +436,10 @@ def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"):
         return prompt
 
     def load_textual_inversion(
-        self, pretrained_model_name_or_path: Union[str, Dict[str, torch.Tensor]], token: Optional[str] = None, **kwargs
+        self,
+        pretrained_model_name_or_path: Union[str, List[str]],
+        token: Optional[Union[str, List[str]]] = None,
+        **kwargs,
     ):
         r"""
         Load textual inversion embeddings into the text encoder of stable diffusion pipelines. Both `diffusers` and
@@ -449,14 +452,20 @@ def load_textual_inversion(
         </Tip>
 
         Parameters:
-            pretrained_model_name_or_path (`str` or `os.PathLike`):
+            pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]`):
                 Can be either:
 
                     - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                       Valid model ids should have an organization name, like
                       `"sd-concepts-library/low-poly-hd-logos-icons"`.
                     - A path to a *directory* containing textual inversion weights, e.g.
                       `./my_text_inversion_directory/`.
+                    - A path to a *file* containing textual inversion weights, e.g. `./my_text_inversions.pt`.
+
+                Or a list of those elements.
+            token (`str` or `List[str]`, *optional*):
+                Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a
+                list, then `token` must also be a list of equal length.
             weight_name (`str`, *optional*):
                 Name of a custom weight file. This should be used in two cases:
 
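A minimal usage sketch of the list form documented above (the pipeline checkpoint, the local `./my_text_inversions.pt` file, and the `<low-poly>` / `<my-style>` token names are illustrative assumptions, not part of this diff):

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    # Load several embeddings in one call; `token` must then be a list of equal length.
    pipe.load_textual_inversion(
        ["sd-concepts-library/low-poly-hd-logos-icons", "./my_text_inversions.pt"],
        token=["<low-poly>", "<my-style>"],
    )

    image = pipe("A <low-poly> icon in the style of <my-style>").images[0]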
@@ -576,16 +585,62 @@ def load_textual_inversion(
576585 "framework" : "pytorch" ,
577586 }
578587
579- # 1. Load textual inversion file
580- model_file = None
581- # Let's first try to load .safetensors weights
582- if (use_safetensors and weight_name is None ) or (
583- weight_name is not None and weight_name .endswith (".safetensors" )
584- ):
585- try :
588+ if isinstance (pretrained_model_name_or_path , str ):
589+ pretrained_model_name_or_paths = [pretrained_model_name_or_path ]
590+ else :
591+ pretrained_model_name_or_paths = pretrained_model_name_or_path
592+
593+ if isinstance (token , str ):
594+ tokens = [token ]
595+ elif token is None :
596+ tokens = [None ] * len (pretrained_model_name_or_paths )
597+ else :
598+ tokens = token
599+
600+ if len (pretrained_model_name_or_paths ) != len (tokens ):
601+ raise ValueError (
602+ f"You have passed a list of models of length { len (pretrained_model_name_or_paths )} , and list of tokens of length { len (tokens )} "
603+ f"Make sure both lists have the same length."
604+ )
605+
606+ valid_tokens = [t for t in tokens if t is not None ]
607+ if len (set (valid_tokens )) < len (valid_tokens ):
608+ raise ValueError (f"You have passed a list of tokens that contains duplicates: { tokens } " )
609+
610+ token_ids_and_embeddings = []
611+
612+ for pretrained_model_name_or_path , token in zip (pretrained_model_name_or_paths , tokens ):
613+ # 1. Load textual inversion file
614+ model_file = None
615+ # Let's first try to load .safetensors weights
616+ if (use_safetensors and weight_name is None ) or (
617+ weight_name is not None and weight_name .endswith (".safetensors" )
618+ ):
619+ try :
620+ model_file = _get_model_file (
621+ pretrained_model_name_or_path ,
622+ weights_name = weight_name or TEXT_INVERSION_NAME_SAFE ,
623+ cache_dir = cache_dir ,
624+ force_download = force_download ,
625+ resume_download = resume_download ,
626+ proxies = proxies ,
627+ local_files_only = local_files_only ,
628+ use_auth_token = use_auth_token ,
629+ revision = revision ,
630+ subfolder = subfolder ,
631+ user_agent = user_agent ,
632+ )
633+ state_dict = safetensors .torch .load_file (model_file , device = "cpu" )
634+ except Exception as e :
635+ if not allow_pickle :
636+ raise e
637+
638+ model_file = None
639+
640+ if model_file is None :
586641 model_file = _get_model_file (
587642 pretrained_model_name_or_path ,
588- weights_name = weight_name or TEXT_INVERSION_NAME_SAFE ,
643+ weights_name = weight_name or TEXT_INVERSION_NAME ,
589644 cache_dir = cache_dir ,
590645 force_download = force_download ,
591646 resume_download = resume_download ,
@@ -596,88 +651,68 @@ def load_textual_inversion(
596651 subfolder = subfolder ,
597652 user_agent = user_agent ,
598653 )
599- state_dict = safetensors .torch .load_file (model_file , device = "cpu" )
600- except Exception as e :
601- if not allow_pickle :
602- raise e
654+ state_dict = torch .load (model_file , map_location = "cpu" )
603655
604- model_file = None
+            # 2. Load token and embedding correctly from file
+            if isinstance(state_dict, torch.Tensor):
+                if token is None:
+                    raise ValueError(
+                        "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
+                    )
+                embedding = state_dict
+            elif len(state_dict) == 1:
+                # diffusers
+                loaded_token, embedding = next(iter(state_dict.items()))
+            elif "string_to_param" in state_dict:
+                # A1111
+                loaded_token = state_dict["name"]
+                embedding = state_dict["string_to_param"]["*"]
+
+            if token is not None and loaded_token != token:
+                logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
+            else:
+                token = loaded_token
 
-        if model_file is None:
-            model_file = _get_model_file(
-                pretrained_model_name_or_path,
-                weights_name=weight_name or TEXT_INVERSION_NAME,
-                cache_dir=cache_dir,
-                force_download=force_download,
-                resume_download=resume_download,
-                proxies=proxies,
-                local_files_only=local_files_only,
-                use_auth_token=use_auth_token,
-                revision=revision,
-                subfolder=subfolder,
-                user_agent=user_agent,
-            )
-            state_dict = torch.load(model_file, map_location="cpu")
+            embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device)
 
-        # 2. Load token and embedding correctly from file
-        if isinstance(state_dict, torch.Tensor):
-            if token is None:
+            # 3. Make sure we don't mess up the tokenizer or text encoder
+            vocab = self.tokenizer.get_vocab()
+            if token in vocab:
                 raise ValueError(
-                    "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
+                    f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
                 )
-            embedding = state_dict
-        elif len(state_dict) == 1:
-            # diffusers
-            loaded_token, embedding = next(iter(state_dict.items()))
-        elif "string_to_param" in state_dict:
-            # A1111
-            loaded_token = state_dict["name"]
-            embedding = state_dict["string_to_param"]["*"]
-
-        if token is not None and loaded_token != token:
-            logger.warn(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
-        else:
-            token = loaded_token
-
-        embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device)
+            elif f"{token}_1" in vocab:
+                multi_vector_tokens = [token]
+                i = 1
+                while f"{token}_{i}" in self.tokenizer.added_tokens_encoder:
+                    multi_vector_tokens.append(f"{token}_{i}")
+                    i += 1
 
-        # 3. Make sure we don't mess up the tokenizer or text encoder
-        vocab = self.tokenizer.get_vocab()
-        if token in vocab:
-            raise ValueError(
-                f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
-            )
-        elif f"{token}_1" in vocab:
-            multi_vector_tokens = [token]
-            i = 1
-            while f"{token}_{i}" in self.tokenizer.added_tokens_encoder:
-                multi_vector_tokens.append(f"{token}_{i}")
-                i += 1
+                raise ValueError(
+                    f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
+                )
 
-            raise ValueError(
-                f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
-            )
+            is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
 
-        is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
+            if is_multi_vector:
+                tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
+                embeddings = [e for e in embedding]  # noqa: C416
+            else:
+                tokens = [token]
+                embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]
 
-        if is_multi_vector:
-            tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
-            embeddings = [e for e in embedding]  # noqa: C416
-        else:
-            tokens = [token]
-            embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]
+            # add tokens and get ids
+            self.tokenizer.add_tokens(tokens)
+            token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
+            token_ids_and_embeddings += zip(token_ids, embeddings)
 
-        # add tokens and get ids
-        self.tokenizer.add_tokens(tokens)
-        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
+            logger.info(f"Loaded textual inversion embedding for {token}.")
 
-        # resize token embeddings and set new embeddings
+        # resize token embeddings and set all new embeddings
         self.text_encoder.resize_token_embeddings(len(self.tokenizer))
-        for token_id, embedding in zip(token_ids, embeddings):
+        for token_id, embedding in token_ids_and_embeddings:
             self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
 
-        logger.info(f"Loaded textual inversion embedding for {token}.")
-
 
 class LoraLoaderMixin:
     r"""
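For reference, a sketch of the two single-file layouts that the parsing step above distinguishes (the file names, the `<my-token>` placeholder, and the 768-dim width are illustrative assumptions; real A1111 checkpoints also carry extra metadata omitted here):

    import torch

    # diffusers format: a dict with a single token -> embedding entry
    torch.save({"<my-token>": torch.randn(768)}, "diffusers_embedding.bin")

    # A1111 format: the token sits under "name", the vectors under "string_to_param"["*"]
    torch.save(
        {"name": "<my-token>", "string_to_param": {"*": torch.randn(2, 768)}},  # 2 rows -> multi-vector token
        "a1111_embedding.pt",
    )

A multi-vector embedding of shape `[n, dim]` (like the two-row A1111 tensor above) is expanded into `n` tokens named `<my-token>`, `<my-token>_1`, ..., `<my-token>_{n-1}`, each receiving one row of the tensor.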