diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index 52c140a678..16eabb0077 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import copy
 import os
 import re
 import warnings
@@ -27,6 +26,7 @@ import torch
 from huggingface_hub import hf_hub_download, model_info
 from torch import nn

+from .models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
 from .utils import (
     DIFFUSERS_CACHE,
     HF_HUB_OFFLINE,
@@ -46,7 +46,6 @@ if is_transformers_available():
 if is_accelerate_available():
     from accelerate import init_empty_weights
     from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
-    from accelerate.utils import set_module_tensor_to_device

 logger = logging.get_logger(__name__)

@@ -137,7 +136,6 @@ class PatchedLoraProjection(nn.Module):
         self.w_down = None

     def forward(self, input):
-        # print(f"{self.__class__.__name__} has a lora_scale of {self.lora_scale}")
         if self.lora_scale is None:
             self.lora_scale = 1.0
         if self.lora_linear_layer is None:
@@ -274,6 +272,11 @@ class UNet2DConditionLoadersMixin:
             use_auth_token (`str` or *bool*, *optional*):
                 The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                 `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
             revision (`str`, *optional*, defaults to `"main"`):
                 The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                 allowed by Git.
@@ -300,6 +303,7 @@ class UNet2DConditionLoadersMixin:
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
         # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
         network_alphas = kwargs.pop("network_alphas", None)
@@ -316,6 +320,15 @@ class UNet2DConditionLoadersMixin:
             "framework": "pytorch",
         }

+        if low_cpu_mem_usage and not is_accelerate_available():
+            low_cpu_mem_usage = False
+            logger.warning(
+                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                " install accelerate\n```\n."
+            )
+
         model_file = None
         if not isinstance(pretrained_model_name_or_path_or_dict, dict):
             # Let's first try to load .safetensors weights
@@ -370,6 +383,10 @@ class UNet2DConditionLoadersMixin:
             # correct keys
             state_dict, network_alphas = self.convert_state_dict_legacy_attn_format(state_dict, network_alphas)

+            if network_alphas is not None:
+                network_alphas_keys = list(network_alphas.keys())
+                used_network_alphas_keys = set()
+
             lora_grouped_dict = defaultdict(dict)
             mapped_network_alphas = {}

@@ -381,13 +398,13 @@ class UNet2DConditionLoadersMixin:

                 # Create another `mapped_network_alphas` dictionary so that we can properly map them.
                 if network_alphas is not None:
-                    network_alphas_ = copy.deepcopy(network_alphas)
-                    for k in network_alphas_:
+                    for k in network_alphas_keys:
                         if k.replace(".alpha", "") in key:
-                            mapped_network_alphas.update({attn_processor_key: network_alphas.pop(k)})
+                            mapped_network_alphas.update({attn_processor_key: network_alphas.get(k)})
+                            used_network_alphas_keys.add(k)

             if not is_network_alphas_none:
-                if len(network_alphas) > 0:
+                if len(set(network_alphas_keys) - used_network_alphas_keys) > 0:
                     raise ValueError(
                         f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
                     )
@@ -411,29 +428,38 @@ class UNet2DConditionLoadersMixin:
                     out_features = attn_processor.out_channels
                     kernel_size = attn_processor.kernel_size

-                    lora = LoRAConv2dLayer(
-                        in_features=in_features,
-                        out_features=out_features,
-                        rank=rank,
-                        kernel_size=kernel_size,
-                        stride=attn_processor.stride,
-                        padding=attn_processor.padding,
-                        network_alpha=mapped_network_alphas.get(key),
-                    )
+                    ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
+                    with ctx():
+                        lora = LoRAConv2dLayer(
+                            in_features=in_features,
+                            out_features=out_features,
+                            rank=rank,
+                            kernel_size=kernel_size,
+                            stride=attn_processor.stride,
+                            padding=attn_processor.padding,
+                            network_alpha=mapped_network_alphas.get(key),
+                        )
                 elif isinstance(attn_processor, LoRACompatibleLinear):
-                    lora = LoRALinearLayer(
-                        attn_processor.in_features,
-                        attn_processor.out_features,
-                        rank,
-                        mapped_network_alphas.get(key),
-                    )
+                    ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
+                    with ctx():
+                        lora = LoRALinearLayer(
+                            attn_processor.in_features,
+                            attn_processor.out_features,
+                            rank,
+                            mapped_network_alphas.get(key),
+                        )
                 else:
                     raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")

                 value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
-                lora.load_state_dict(value_dict)
                 lora_layers_list.append((attn_processor, lora))
+                if low_cpu_mem_usage:
+                    device = next(iter(value_dict.values())).device
+                    dtype = next(iter(value_dict.values())).dtype
+                    load_model_dict_into_meta(lora, value_dict, device=device, dtype=dtype)
+                else:
+                    lora.load_state_dict(value_dict)

         elif is_custom_diffusion:
             attn_processors = {}
             custom_diffusion_grouped_dict = defaultdict(dict)
@@ -470,13 +496,12 @@ class UNet2DConditionLoadersMixin:
                 f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
             )

-        # set correct dtype & device
-        lora_layers_list = [(t, l.to(device=self.device, dtype=self.dtype)) for t, l in lora_layers_list]
-
         # set lora layers
         for target_module, lora_layer in lora_layers_list:
             target_module.set_lora_layer(lora_layer)

+        self.to(dtype=self.dtype, device=self.device)
+
     def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas):
         is_new_lora_format = all(
             key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
@@ -999,13 +1024,18 @@ class LoraLoaderMixin:
                     recurive = is_sequential_cpu_offload
                     remove_hook_from_module(component, recurse=recurive)

+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+
         state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
-        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
+        self.load_lora_into_unet(
+            state_dict, network_alphas=network_alphas, unet=self.unet, low_cpu_mem_usage=low_cpu_mem_usage
+        )
         self.load_lora_into_text_encoder(
             state_dict,
             network_alphas=network_alphas,
             text_encoder=self.text_encoder,
             lora_scale=self.lora_scale,
+            low_cpu_mem_usage=low_cpu_mem_usage,
         )

         # Offload back.
@@ -1065,6 +1095,11 @@ class LoraLoaderMixin:
                 allowed by Git.
             subfolder (`str`, *optional*, defaults to `""`):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
             mirror (`str`, *optional*):
                 Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
                 guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
@@ -1305,7 +1340,7 @@ class LoraLoaderMixin:
         return new_state_dict

     @classmethod
-    def load_lora_into_unet(cls, state_dict, network_alphas, unet):
+    def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.

@@ -1318,7 +1353,13 @@ class LoraLoaderMixin:
                 See `LoRALinearLayer` for more details.
             unet (`UNet2DConditionModel`):
                 The UNet model to load the LoRA layers into.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
         """
+        low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
         # their prefixes.
@@ -1343,11 +1384,12 @@ class LoraLoaderMixin:
             warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
             warnings.warn(warn_message)

-        # load loras into unet
-        unet.load_attn_procs(state_dict, network_alphas=network_alphas)
+        unet.load_attn_procs(state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage)

     @classmethod
-    def load_lora_into_text_encoder(cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0):
+    def load_lora_into_text_encoder(
+        cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0, low_cpu_mem_usage=None
+    ):
         """
         This will load the LoRA layers specified in `state_dict` into `text_encoder`

@@ -1364,7 +1406,13 @@ class LoraLoaderMixin:
             lora_scale (`float`):
                 How much to scale the output of the lora linear layer before it is added with the output of the
                 regular lora layer.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
         """
+        low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT

         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
@@ -1447,6 +1495,7 @@ class LoraLoaderMixin:
                     network_alphas,
                     rank=rank,
                     patch_mlp=patch_mlp,
+                    low_cpu_mem_usage=low_cpu_mem_usage,
                 )

                 # set correct dtype & device
@@ -1454,12 +1503,23 @@ class LoraLoaderMixin:
                     k: v.to(device=text_encoder.device, dtype=text_encoder.dtype)
                     for k, v in text_encoder_lora_state_dict.items()
                 }
-                load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False)
-                if len(load_state_dict_results.unexpected_keys) != 0:
+                if low_cpu_mem_usage:
+                    device = next(iter(text_encoder_lora_state_dict.values())).device
+                    dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
+                    unexpected_keys = load_model_dict_into_meta(
+                        text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype
+                    )
+                else:
+                    load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False)
+                    unexpected_keys = load_state_dict_results.unexpected_keys
+
+                if len(unexpected_keys) != 0:
                     raise ValueError(
                         f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
                     )

+                text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
+
     @property
     def lora_scale(self) -> float:
         # property function that returns the lora scale which can be set at run time by the pipeline.
@@ -1492,11 +1552,21 @@ class LoraLoaderMixin:
         rank: Union[Dict[str, int], int] = 4,
         dtype=None,
         patch_mlp=False,
+        low_cpu_mem_usage=False,
     ):
         r"""
         Monkey-patches the forward passes of attention modules of the text encoder.
         """

+        def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters):
+            linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model
+            ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
+            with ctx():
+                model = PatchedLoraProjection(linear_layer, lora_scale, network_alpha, rank, dtype=dtype)
+
+            lora_parameters.extend(model.lora_linear_layer.parameters())
+            return model
+
         # First, remove any monkey-patch that might have been applied before
         cls._remove_text_encoder_monkey_patch_classmethod(text_encoder)

@@ -1515,45 +1585,18 @@ class LoraLoaderMixin:
             else:
                 current_rank = rank

-            q_linear_layer = (
-                attn_module.q_proj.regular_linear_layer
-                if isinstance(attn_module.q_proj, PatchedLoraProjection)
-                else attn_module.q_proj
+            attn_module.q_proj = create_patched_linear_lora(
+                attn_module.q_proj, query_alpha, current_rank, dtype, lora_parameters
             )
-            attn_module.q_proj = PatchedLoraProjection(
-                q_linear_layer, lora_scale, network_alpha=query_alpha, rank=current_rank, dtype=dtype
+            attn_module.k_proj = create_patched_linear_lora(
+                attn_module.k_proj, key_alpha, current_rank, dtype, lora_parameters
             )
-            lora_parameters.extend(attn_module.q_proj.lora_linear_layer.parameters())
-
-            k_linear_layer = (
-                attn_module.k_proj.regular_linear_layer
-                if isinstance(attn_module.k_proj, PatchedLoraProjection)
-                else attn_module.k_proj
+            attn_module.v_proj = create_patched_linear_lora(
+                attn_module.v_proj, value_alpha, current_rank, dtype, lora_parameters
             )
-            attn_module.k_proj = PatchedLoraProjection(
-                k_linear_layer, lora_scale, network_alpha=key_alpha, rank=current_rank, dtype=dtype
+            attn_module.out_proj = create_patched_linear_lora(
+                attn_module.out_proj, out_alpha, current_rank, dtype, lora_parameters
             )
-            lora_parameters.extend(attn_module.k_proj.lora_linear_layer.parameters())
-
-            v_linear_layer = (
-                attn_module.v_proj.regular_linear_layer
-                if isinstance(attn_module.v_proj, PatchedLoraProjection)
-                else attn_module.v_proj
-            )
-            attn_module.v_proj = PatchedLoraProjection(
-                v_linear_layer, lora_scale, network_alpha=value_alpha, rank=current_rank, dtype=dtype
-            )
-            lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters())
-
-            out_linear_layer = (
-                attn_module.out_proj.regular_linear_layer
-                if isinstance(attn_module.out_proj, PatchedLoraProjection)
-                else attn_module.out_proj
-            )
-            attn_module.out_proj = PatchedLoraProjection(
-                out_linear_layer, lora_scale, network_alpha=out_alpha, rank=current_rank, dtype=dtype
-            )
-            lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())

         if patch_mlp:
             for name, mlp_module in text_encoder_mlp_modules(text_encoder):
@@ -1563,25 +1606,12 @@ class LoraLoaderMixin:
                 current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight")
                 current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight")

-                fc1_linear_layer = (
-                    mlp_module.fc1.regular_linear_layer
-                    if isinstance(mlp_module.fc1, PatchedLoraProjection)
-                    else mlp_module.fc1
+                mlp_module.fc1 = create_patched_linear_lora(
+                    mlp_module.fc1, fc1_alpha, current_rank_fc1, dtype, lora_parameters
                 )
-                mlp_module.fc1 = PatchedLoraProjection(
-                    fc1_linear_layer, lora_scale, network_alpha=fc1_alpha, rank=current_rank_fc1, dtype=dtype
+                mlp_module.fc2 = create_patched_linear_lora(
+                    mlp_module.fc2, fc2_alpha, current_rank_fc2, dtype, lora_parameters
                 )
-                lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters())
-
-                fc2_linear_layer = (
-                    mlp_module.fc2.regular_linear_layer
-                    if isinstance(mlp_module.fc2, PatchedLoraProjection)
-                    else mlp_module.fc2
-                )
-                mlp_module.fc2 = PatchedLoraProjection(
-                    fc2_linear_layer, lora_scale, network_alpha=fc2_alpha, rank=current_rank_fc2, dtype=dtype
-                )
-                lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())

         if is_network_alphas_populated and len(network_alphas) > 0:
             raise ValueError(
@@ -2375,8 +2405,7 @@ class FromOriginalVAEMixin:
         vae = AutoencoderKL(**vae_config)

         if is_accelerate_available():
-            for param_name, param in converted_vae_checkpoint.items():
-                set_module_tensor_to_device(vae, param_name, "cpu", value=param)
+            load_model_dict_into_meta(vae, converted_vae_checkpoint, device="cpu")
         else:
             vae.load_state_dict(converted_vae_checkpoint)

diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index e53fa7e528..67746ebace 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -128,6 +128,31 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
             )


+def load_model_dict_into_meta(model, state_dict, device=None, dtype=None, model_name_or_path=None):
+    device = device or torch.device("cpu")
+    dtype = dtype or torch.float32
+
+    unexpected_keys = []
+    empty_state_dict = model.state_dict()
+    for param_name, param in state_dict.items():
+        if param_name not in empty_state_dict:
+            unexpected_keys.append(param_name)
+            continue
+
+        if empty_state_dict[param_name].shape != param.shape:
+            model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
+            raise ValueError(
+                f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
+            )
+
+        accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
+        if accepts_dtype:
+            set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
+        else:
+            set_module_tensor_to_device(model, param_name, device, value=param)
+    return unexpected_keys
+
+
 def _load_state_dict_into_model(model_to_load, state_dict):
     # Convert old format to new format if needed from a PyTorch state_dict
     # copy state_dict so _load_from_state_dict can modify it
@@ -624,29 +649,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                         " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
                         " those weights or else make sure your checkpoint file is correct."
                     )
-                unexpected_keys = []
-
-                empty_state_dict = model.state_dict()
-                for param_name, param in state_dict.items():
-                    accepts_dtype = "dtype" in set(
-                        inspect.signature(set_module_tensor_to_device).parameters.keys()
-                    )
-
-                    if param_name not in empty_state_dict:
-                        unexpected_keys.append(param_name)
-                        continue
-
-                    if empty_state_dict[param_name].shape != param.shape:
-                        raise ValueError(
-                            f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
-                        )
-
-                    if accepts_dtype:
-                        set_module_tensor_to_device(
-                            model, param_name, param_device, value=param, dtype=torch_dtype
-                        )
-                    else:
-                        set_module_tensor_to_device(model, param_name, param_device, value=param)
+
+                unexpected_keys = load_model_dict_into_meta(
+                    model,
+                    state_dict,
+                    device=param_device,
+                    dtype=torch_dtype,
+                    model_name_or_path=pretrained_model_name_or_path,
+                )

                 if cls._keys_to_ignore_on_load_unexpected is not None:
                     for pat in cls._keys_to_ignore_on_load_unexpected:
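For reference, a minimal usage sketch of the `low_cpu_mem_usage` flag this patch threads through the LoRA-loading entry points. The base checkpoint and LoRA repo IDs below are placeholders for illustration, not part of the diff:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# With low_cpu_mem_usage=True (the default on PyTorch >= 1.9.0; it falls back to
# False with a warning if `accelerate` is missing), the LoRA layers are created
# under init_empty_weights() and the checkpoint tensors are loaded into them via
# load_model_dict_into_meta(), instead of being randomly initialized first and
# then overwritten by load_state_dict().
pipe.load_lora_weights("your-username/your-lora", low_cpu_mem_usage=True)
```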