diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index 45c866c1aa..1b8afdce2f 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -623,6 +623,81 @@ class UNet2DConditionLoadersMixin:
             module._unfuse_lora()
 
 
+def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
+    cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+    force_download = kwargs.pop("force_download", False)
+    resume_download = kwargs.pop("resume_download", False)
+    proxies = kwargs.pop("proxies", None)
+    local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+    use_auth_token = kwargs.pop("use_auth_token", None)
+    revision = kwargs.pop("revision", None)
+    subfolder = kwargs.pop("subfolder", None)
+    weight_name = kwargs.pop("weight_name", None)
+    use_safetensors = kwargs.pop("use_safetensors", None)
+
+    allow_pickle = False
+    if use_safetensors is None:
+        use_safetensors = True
+        allow_pickle = True
+
+    user_agent = {
+        "file_type": "text_inversion",
+        "framework": "pytorch",
+    }
+    state_dicts = []
+    for pretrained_model_name_or_path in pretrained_model_name_or_paths:
+        if not isinstance(pretrained_model_name_or_path, (dict, torch.Tensor)):
+            # 3.1. Load textual inversion file
+            model_file = None
+
+            # Let's first try to load .safetensors weights
+            if (use_safetensors and weight_name is None) or (
+                weight_name is not None and weight_name.endswith(".safetensors")
+            ):
+                try:
+                    model_file = _get_model_file(
+                        pretrained_model_name_or_path,
+                        weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
+                        cache_dir=cache_dir,
+                        force_download=force_download,
+                        resume_download=resume_download,
+                        proxies=proxies,
+                        local_files_only=local_files_only,
+                        use_auth_token=use_auth_token,
+                        revision=revision,
+                        subfolder=subfolder,
+                        user_agent=user_agent,
+                    )
+                    state_dict = safetensors.torch.load_file(model_file, device="cpu")
+                except Exception as e:
+                    if not allow_pickle:
+                        raise e
+
+                    model_file = None
+
+            if model_file is None:
+                model_file = _get_model_file(
+                    pretrained_model_name_or_path,
+                    weights_name=weight_name or TEXT_INVERSION_NAME,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                )
+                state_dict = torch.load(model_file, map_location="cpu")
+        else:
+            state_dict = pretrained_model_name_or_path
+
+        state_dicts.append(state_dict)
+
+    return state_dicts
+
+
 class TextualInversionLoaderMixin:
     r"""
     Load textual inversion tokens and embeddings to the tokenizer and text encoder.
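This first hunk lifts the file-loading logic out of `load_textual_inversion` into a standalone module-level helper. Unless the caller explicitly sets `use_safetensors`, it tries the `.safetensors` weights first and only falls back to the pickled `.bin` file if that fails. Below is a minimal sketch of that same fallback pattern, assuming the conventional `learned_embeds.safetensors` / `learned_embeds.bin` file names and a direct `hf_hub_download` call instead of the internal `_get_model_file` path resolution:

```python
# Illustrative sketch only: the real helper goes through `_get_model_file`; here we
# assume the conventional file names and fetch straight from the Hub.
import torch
import safetensors.torch
from huggingface_hub import hf_hub_download


def load_embedding_state_dict(repo_id, use_safetensors=None):
    # only fall back to pickle when the caller did not explicitly require safetensors
    allow_pickle = use_safetensors is None
    if use_safetensors or use_safetensors is None:
        try:
            path = hf_hub_download(repo_id, "learned_embeds.safetensors")
            return safetensors.torch.load_file(path, device="cpu")
        except Exception:
            if not allow_pickle:
                raise
    return torch.load(hf_hub_download(repo_id, "learned_embeds.bin"), map_location="cpu")
```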
@@ -685,6 +760,97 @@ class TextualInversionLoaderMixin:
 
         return prompt
 
+    def _check_text_inv_inputs(self, tokenizer, text_encoder, pretrained_model_name_or_paths, tokens):
+        if tokenizer is None:
+            raise ValueError(
+                f"{self.__class__.__name__} requires `self.tokenizer` or passing a `tokenizer` of type `PreTrainedTokenizer` for calling"
+                f" `{self.load_textual_inversion.__name__}`"
+            )
+
+        if text_encoder is None:
+            raise ValueError(
+                f"{self.__class__.__name__} requires `self.text_encoder` or passing a `text_encoder` of type `PreTrainedModel` for calling"
+                f" `{self.load_textual_inversion.__name__}`"
+            )
+
+        if len(pretrained_model_name_or_paths) != len(tokens):
+            raise ValueError(
+                f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and list of tokens of length {len(tokens)} "
+                f"Make sure both lists have the same length."
+            )
+
+        valid_tokens = [t for t in tokens if t is not None]
+        if len(set(valid_tokens)) < len(valid_tokens):
+            raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")
+
+    @staticmethod
+    def _retrieve_tokens_and_embeddings(tokens, state_dicts, tokenizer):
+        all_tokens = []
+        all_embeddings = []
+        for state_dict, token in zip(state_dicts, tokens):
+            if isinstance(state_dict, torch.Tensor):
+                if token is None:
+                    raise ValueError(
+                        "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
+                    )
+                loaded_token = token
+                embedding = state_dict
+            elif len(state_dict) == 1:
+                # diffusers
+                loaded_token, embedding = next(iter(state_dict.items()))
+            elif "string_to_param" in state_dict:
+                # A1111
+                loaded_token = state_dict["name"]
+                embedding = state_dict["string_to_param"]["*"]
+            else:
+                raise ValueError(
+                    f"Loaded state dictonary is incorrect: {state_dict}. \n\n"
+                    "Please verify that the loaded state dictionary of the textual embedding either only has a single key or includes the `string_to_param`"
+                    " input key."
+                )
+
+            if token is not None and loaded_token != token:
+                logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
+            else:
+                token = loaded_token
+
+            if token in tokenizer.get_vocab():
+                raise ValueError(
+                    f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
+                )
+
+            all_tokens.append(token)
+            all_embeddings.append(embedding)
+
+        return all_tokens, all_embeddings
+
+    @staticmethod
+    def _extend_tokens_and_embeddings(tokens, embeddings, tokenizer):
+        all_tokens = []
+        all_embeddings = []
+
+        for embedding, token in zip(embeddings, tokens):
+            if f"{token}_1" in tokenizer.get_vocab():
+                multi_vector_tokens = [token]
+                i = 1
+                while f"{token}_{i}" in tokenizer.added_tokens_encoder:
+                    multi_vector_tokens.append(f"{token}_{i}")
+                    i += 1
+
+                raise ValueError(
+                    f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
+                )
+
+            is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
+            if is_multi_vector:
+                all_tokens += [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
+                all_embeddings += [e for e in embedding]  # noqa: C416
+            else:
+                all_tokens += [token]
+                all_embeddings += [embedding[0]] if len(embedding.shape) > 1 else [embedding]
+
+        return all_tokens, all_embeddings
+
     def load_textual_inversion(
         self,
         pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
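The second hunk adds the parsing helpers: `_retrieve_tokens_and_embeddings` accepts a raw tensor (in which case the token must be passed explicitly), a diffusers-style single-key dict, or an A1111-style dict with a `string_to_param` entry, and `_extend_tokens_and_embeddings` splits multi-vector embeddings into numbered tokens. A small sketch of the supported layouts, using hypothetical token names and a 768-dimensional embedding:

```python
import torch

# the two dict layouts the loader recognizes (values made up for illustration)
diffusers_style = {"<cat-toy>": torch.randn(768)}  # single key: token -> embedding
a1111_style = {"name": "<style>", "string_to_param": {"*": torch.randn(4, 768)}}

# a (4, 768) embedding is expanded to "<style>", "<style>_1", "<style>_2", "<style>_3",
# one tokenizer token per row, roughly like this:
embedding = a1111_style["string_to_param"]["*"]
tokens = ["<style>"] + [f"<style>_{i}" for i in range(1, embedding.shape[0])]
embeddings = list(embedding)
```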
@@ -790,25 +956,44 @@ class TextualInversionLoaderMixin:
         ```
         """
+        # 1. Set correct tokenizer and text encoder
         tokenizer = tokenizer or getattr(self, "tokenizer", None)
         text_encoder = text_encoder or getattr(self, "text_encoder", None)
 
-        if tokenizer is None:
+        # 2. Normalize inputs
+        pretrained_model_name_or_paths = (
+            [pretrained_model_name_or_path]
+            if not isinstance(pretrained_model_name_or_path, list)
+            else pretrained_model_name_or_path
+        )
+        tokens = len(pretrained_model_name_or_paths) * [token] if (isinstance(token, str) or token is None) else token
+
+        # 3. Check inputs
+        self._check_text_inv_inputs(tokenizer, text_encoder, pretrained_model_name_or_paths, tokens)
+
+        # 4. Load state dicts of textual embeddings
+        state_dicts = load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
+
+        # 4. Retrieve tokens and embeddings
+        tokens, embeddings = self._retrieve_tokens_and_embeddings(tokens, state_dicts, tokenizer)
+
+        # 5. Extend tokens and embeddings for multi vector
+        tokens, embeddings = self._extend_tokens_and_embeddings(tokens, embeddings, tokenizer)
+
+        # 6. Make sure all embeddings have the correct size
+        expected_emb_dim = text_encoder.get_input_embeddings().weight.shape[-1]
+        if any(expected_emb_dim != emb.shape[-1] for emb in embeddings):
             raise ValueError(
-                f"{self.__class__.__name__} requires `self.tokenizer` or passing a `tokenizer` of type `PreTrainedTokenizer` for calling"
-                f" `{self.load_textual_inversion.__name__}`"
+                "Loaded embeddings are of incorrect shape. Expected each textual inversion embedding "
+                "to be of shape {input_embeddings.shape[-1]}, but are {embeddings.shape[-1]} "
             )
 
-        if text_encoder is None:
-            raise ValueError(
-                f"{self.__class__.__name__} requires `self.text_encoder` or passing a `text_encoder` of type `PreTrainedModel` for calling"
-                f" `{self.load_textual_inversion.__name__}`"
-            )
+        # 7. Now we can be sure that loading the embedding matrix works
+        # < Unsafe code:
 
-        # Remove any existing hooks.
+        # 7.1 Offload all hooks in case the pipeline was cpu offloaded before make sure, we offload and onload again
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False
-        recursive = False
         for _, component in self.components.items():
             if isinstance(component, nn.Module):
                 if hasattr(component, "_hf_hook"):
@@ -817,168 +1002,34 @@ class TextualInversionLoaderMixin:
                     logger.info(
                         "Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
                     )
-                    recursive = is_sequential_cpu_offload
-                    remove_hook_from_module(component, recurse=recursive)
+                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
 
-        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
-        force_download = kwargs.pop("force_download", False)
-        resume_download = kwargs.pop("resume_download", False)
-        proxies = kwargs.pop("proxies", None)
-        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
-        use_auth_token = kwargs.pop("use_auth_token", None)
-        revision = kwargs.pop("revision", None)
-        subfolder = kwargs.pop("subfolder", None)
-        weight_name = kwargs.pop("weight_name", None)
-        use_safetensors = kwargs.pop("use_safetensors", None)
+        # 7.2 save expected device and dtype
+        device = text_encoder.device
+        dtype = text_encoder.dtype
 
-        allow_pickle = False
-        if use_safetensors is None:
-            use_safetensors = True
-            allow_pickle = True
-
-        user_agent = {
-            "file_type": "text_inversion",
-            "framework": "pytorch",
-        }
-
-        if not isinstance(pretrained_model_name_or_path, list):
-            pretrained_model_name_or_paths = [pretrained_model_name_or_path]
-        else:
-            pretrained_model_name_or_paths = pretrained_model_name_or_path
-
-        if isinstance(token, str):
-            tokens = [token]
-        elif token is None:
-            tokens = [None] * len(pretrained_model_name_or_paths)
-        else:
-            tokens = token
-
-        if len(pretrained_model_name_or_paths) != len(tokens):
-            raise ValueError(
-                f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and list of tokens of length {len(tokens)}"
-                f"Make sure both lists have the same length."
-            )
-
-        valid_tokens = [t for t in tokens if t is not None]
-        if len(set(valid_tokens)) < len(valid_tokens):
-            raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")
-
-        token_ids_and_embeddings = []
-
-        for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens):
-            if not isinstance(pretrained_model_name_or_path, (dict, torch.Tensor)):
-                # 1. Load textual inversion file
-                model_file = None
-                # Let's first try to load .safetensors weights
-                if (use_safetensors and weight_name is None) or (
-                    weight_name is not None and weight_name.endswith(".safetensors")
-                ):
-                    try:
-                        model_file = _get_model_file(
-                            pretrained_model_name_or_path,
-                            weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
-                            cache_dir=cache_dir,
-                            force_download=force_download,
-                            resume_download=resume_download,
-                            proxies=proxies,
-                            local_files_only=local_files_only,
-                            use_auth_token=use_auth_token,
-                            revision=revision,
-                            subfolder=subfolder,
-                            user_agent=user_agent,
-                        )
-                        state_dict = safetensors.torch.load_file(model_file, device="cpu")
-                    except Exception as e:
-                        if not allow_pickle:
-                            raise e
-
-                        model_file = None
-
-                if model_file is None:
-                    model_file = _get_model_file(
-                        pretrained_model_name_or_path,
-                        weights_name=weight_name or TEXT_INVERSION_NAME,
-                        cache_dir=cache_dir,
-                        force_download=force_download,
-                        resume_download=resume_download,
-                        proxies=proxies,
-                        local_files_only=local_files_only,
-                        use_auth_token=use_auth_token,
-                        revision=revision,
-                        subfolder=subfolder,
-                        user_agent=user_agent,
-                    )
-                    state_dict = torch.load(model_file, map_location="cpu")
-            else:
-                state_dict = pretrained_model_name_or_path
-
-            # 2. Load token and embedding correcly from file
-            loaded_token = None
-            if isinstance(state_dict, torch.Tensor):
-                if token is None:
-                    raise ValueError(
-                        "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
-                    )
-                embedding = state_dict
-            elif len(state_dict) == 1:
-                # diffusers
-                loaded_token, embedding = next(iter(state_dict.items()))
-            elif "string_to_param" in state_dict:
-                # A1111
-                loaded_token = state_dict["name"]
-                embedding = state_dict["string_to_param"]["*"]
-
-            if token is not None and loaded_token != token:
-                logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
-            else:
-                token = loaded_token
-
-            embedding = embedding.to(dtype=text_encoder.dtype, device=text_encoder.device)
-
-            # 3. Make sure we don't mess up the tokenizer or text encoder
-            vocab = tokenizer.get_vocab()
-            if token in vocab:
-                raise ValueError(
-                    f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
-                )
-            elif f"{token}_1" in vocab:
-                multi_vector_tokens = [token]
-                i = 1
-                while f"{token}_{i}" in tokenizer.added_tokens_encoder:
-                    multi_vector_tokens.append(f"{token}_{i}")
-                    i += 1
-
-                raise ValueError(
-                    f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
-                )
-
-            is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
-
-            if is_multi_vector:
-                tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
-                embeddings = [e for e in embedding]  # noqa: C416
-            else:
-                tokens = [token]
-                embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]
+        # 7.3 Increase token embedding matrix
+        text_encoder.resize_token_embeddings(len(tokenizer) + len(tokens))
+        input_embeddings = text_encoder.get_input_embeddings().weight
 
+        # 7.4 Load token and embedding
+        for token, embedding in zip(tokens, embeddings):
             # add tokens and get ids
-            tokenizer.add_tokens(tokens)
-            token_ids = tokenizer.convert_tokens_to_ids(tokens)
-            token_ids_and_embeddings += zip(token_ids, embeddings)
-
+            tokenizer.add_tokens(token)
+            token_id = tokenizer.convert_tokens_to_ids(token)
+            input_embeddings.data[token_id] = embedding
             logger.info(f"Loaded textual inversion embedding for {token}.")
 
-        # resize token embeddings and set all new embeddings
-        text_encoder.resize_token_embeddings(len(tokenizer))
-        for token_id, embedding in token_ids_and_embeddings:
-            text_encoder.get_input_embeddings().weight.data[token_id] = embedding
+        input_embeddings.to(dtype=dtype, device=device)
 
-        # offload back
+        # 7.5 Offload the model again
        if is_model_cpu_offload:
            self.enable_model_cpu_offload()
        elif is_sequential_cpu_offload:
            self.enable_sequential_cpu_offload()
+        # / Unsafe Code >
+
 
 
 class LoraLoaderMixin:
     r"""
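The rewritten body resizes the token embedding matrix once for all new tokens and writes the rows in place, instead of collecting `(token_id, embedding)` pairs and resizing after the loop; afterwards the embedding weights are moved back to the text encoder's saved dtype and device and any CPU-offload hooks are re-enabled. A rough standalone equivalent of steps 7.3/7.4, assuming a CLIP tokenizer and text encoder from `openai/clip-vit-large-patch14` and two hypothetical single-vector embeddings:

```python
import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

new_tokens = ["<cat-toy>", "<style>"]                   # hypothetical placeholder tokens
new_embeddings = [torch.randn(768), torch.randn(768)]   # 768 = CLIP-L hidden size

# grow the embedding matrix once, then fill the freshly allocated rows in place
text_encoder.resize_token_embeddings(len(tokenizer) + len(new_tokens))
input_embeddings = text_encoder.get_input_embeddings().weight
for token, embedding in zip(new_tokens, new_embeddings):
    tokenizer.add_tokens(token)
    token_id = tokenizer.convert_tokens_to_ids(token)
    input_embeddings.data[token_id] = embedding
```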
@@ -2598,7 +2649,6 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
 
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False
-        recursive = False
         for _, component in self.components.items():
             if isinstance(component, torch.nn.Module):
                 if hasattr(component, "_hf_hook"):
@@ -2607,8 +2657,8 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
                     logger.info(
                         "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
                     )
-                    recursive = is_sequential_cpu_offload
-                    remove_hook_from_module(component, recurse=recursive)
+                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+
         state_dict, network_alphas = self.lora_state_dict(
             pretrained_model_name_or_path_or_dict,
             unet_config=self.unet.config,
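For reference, a usage sketch of the public API after this refactor; the checkpoint and the `sd-concepts-library/cat-toy` embedding are the ones used in the diffusers docs, while the local `.safetensors` paths and the extra token names are placeholders:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# single embedding, token name taken from the file
pipe.load_textual_inversion("sd-concepts-library/cat-toy")

# several embeddings at once; the path list and token list must have the same length
pipe.load_textual_inversion(
    ["./embeddings/first.safetensors", "./embeddings/second.safetensors"],
    token=["<first>", "<second>"],
)

image = pipe("A <cat-toy> next to <first> in the style of <second>").images[0]
```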