diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py
index 06b6e9edab..137f3222f6 100644
--- a/examples/cogvideo/train_cogvideox_lora.py
+++ b/examples/cogvideo/train_cogvideox_lora.py
@@ -864,7 +864,7 @@ def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
 
     # Optimizer creation
     supported_optimizers = ["adam", "adamw", "prodigy"]
-    if args.optimizer not in ["adam", "adamw", "prodigy"]:
+    if args.optimizer not in supported_optimizers:
         logger.warning(
             f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include {supported_optimizers}. Defaulting to AdamW"
         )
@@ -1463,7 +1463,6 @@ def main(args):
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         transformer = unwrap_model(transformer)
-        # transformer = transformer.to(torch.float32)
         dtype = (
             torch.float16
             if args.mixed_precision == "fp16"
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index f002538362..4747d1717e 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -2278,14 +2278,11 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
 
 class CogVideoXLoraLoaderMixin(LoraBaseMixin):
     r"""
-    Load LoRA layers into [`CogVideoXTransformer3DModel`],
-    [`T5EncoderModel`](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel). Specific to
-    [`CogVideoX`].
+    Load LoRA layers into [`CogVideoXTransformer3DModel`]. Specific to [`CogVideoX`].
     """
 
-    _lora_loadable_modules = ["transformer", "text_encoder"]
+    _lora_loadable_modules = ["transformer"]
     transformer_name = TRANSFORMER_NAME
-    text_encoder_name = TEXT_ENCODER_NAME
 
     @classmethod
     @validate_hf_hub_args
@@ -2419,18 +2416,6 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
             _pipeline=self,
         )
 
-        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
-        if len(text_encoder_state_dict) > 0:
-            self.load_lora_into_text_encoder(
-                text_encoder_state_dict,
-                network_alphas=None,
-                text_encoder=self.text_encoder,
-                prefix="text_encoder",
-                lora_scale=self.lora_scale,
-                adapter_name=adapter_name,
-                _pipeline=self,
-            )
-
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer
     def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
@@ -2511,133 +2496,11 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
         # Unsafe code />
 
     @classmethod
-    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
-    def load_lora_into_text_encoder(
-        cls,
-        state_dict,
-        network_alphas,
-        text_encoder,
-        prefix=None,
-        lora_scale=1.0,
-        adapter_name=None,
-        _pipeline=None,
-    ):
-        """
-        This will load the LoRA layers specified in `state_dict` into `text_encoder`
-
-        Parameters:
-            state_dict (`dict`):
-                A standard state dict containing the lora layer parameters. The key should be prefixed with an
-                additional `text_encoder` to distinguish between unet lora layers.
-            network_alphas (`Dict[str, float]`):
-                The value of the network alpha used for stable learning and preventing underflow. This value has the
-                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
-                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
-            text_encoder (`CLIPTextModel`):
-                The text encoder model to load the LoRA layers into.
-            prefix (`str`):
-                Expected prefix of the `text_encoder` in the `state_dict`.
-            lora_scale (`float`):
-                How much to scale the output of the lora linear layer before it is added with the output of the regular
-                lora layer.
-            adapter_name (`str`, *optional*):
-                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
-                `default_{i}` where i is the total number of adapters being loaded.
-        """
-        if not USE_PEFT_BACKEND:
-            raise ValueError("PEFT backend is required for this method.")
-
-        from peft import LoraConfig
-
-        # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
-        # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
-        # their prefixes.
-        keys = list(state_dict.keys())
-        prefix = cls.text_encoder_name if prefix is None else prefix
-
-        # Safe prefix to check with.
-        if any(cls.text_encoder_name in key for key in keys):
-            # Load the layers corresponding to text encoder and make necessary adjustments.
-            text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
-            text_encoder_lora_state_dict = {
-                k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
-            }
-
-            if len(text_encoder_lora_state_dict) > 0:
-                logger.info(f"Loading {prefix}.")
-                rank = {}
-                text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
-
-                # convert state dict
-                text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
-
-                for name, _ in text_encoder_attn_modules(text_encoder):
-                    for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
-                        rank_key = f"{name}.{module}.lora_B.weight"
-                        if rank_key not in text_encoder_lora_state_dict:
-                            continue
-                        rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
-
-                for name, _ in text_encoder_mlp_modules(text_encoder):
-                    for module in ("fc1", "fc2"):
-                        rank_key = f"{name}.{module}.lora_B.weight"
-                        if rank_key not in text_encoder_lora_state_dict:
-                            continue
-                        rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
-
-                if network_alphas is not None:
-                    alpha_keys = [
-                        k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
-                    ]
-                    network_alphas = {
-                        k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
-                    }
-
-                lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
-                if "use_dora" in lora_config_kwargs:
-                    if lora_config_kwargs["use_dora"]:
-                        if is_peft_version("<", "0.9.0"):
-                            raise ValueError(
-                                "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                            )
-                    else:
-                        if is_peft_version("<", "0.9.0"):
-                            lora_config_kwargs.pop("use_dora")
-                lora_config = LoraConfig(**lora_config_kwargs)
-
-                # adapter_name
-                if adapter_name is None:
-                    adapter_name = get_adapter_name(text_encoder)
-
-                is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
-                # inject LoRA layers and load the state dict
-                # in transformers we automatically check whether the adapter name is already in use or not
-                text_encoder.load_adapter(
-                    adapter_name=adapter_name,
-                    adapter_state_dict=text_encoder_lora_state_dict,
-                    peft_config=lora_config,
-                )
-
-                # scale LoRA layers with `lora_scale`
-                scale_lora_layers(text_encoder, weight=lora_scale)
-
-                text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
-
-                # Offload back.
-                if is_model_cpu_offload:
-                    _pipeline.enable_model_cpu_offload()
-                elif is_sequential_cpu_offload:
-                    _pipeline.enable_sequential_cpu_offload()
-                # Unsafe code />
-
-    @classmethod
-    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer
+    # Adapted from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights by removing text encoder related changes
     def save_lora_weights(
         cls,
         save_directory: Union[str, os.PathLike],
         transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
-        text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
         is_main_process: bool = True,
         weight_name: str = None,
         save_function: Callable = None,
@@ -2651,9 +2514,6 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
                 Directory to save LoRA parameters to. Will be created if it doesn't exist.
             transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                 State dict of the LoRA layers corresponding to the `transformer`.
-            text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
-                State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
-                encoder LoRA state dict because it comes from 🤗 Transformers.
             is_main_process (`bool`, *optional*, defaults to `True`):
                 Whether the process calling this is the main process or not. Useful during distributed training and you
                 need to call this function on all processes. In this case, set `is_main_process=True` only on the main
@@ -2667,15 +2527,12 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
         """
         state_dict = {}
 
-        if not (transformer_lora_layers or text_encoder_lora_layers):
-            raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.")
+        if not transformer_lora_layers:
+            raise ValueError("You must pass `transformer_lora_layers`.")
 
         if transformer_lora_layers:
            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
 
-        if text_encoder_lora_layers:
-            state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
-
         # Save the model
         cls.write_lora_layers(
             state_dict=state_dict,
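
For reference, a minimal sketch of the call pattern after this change (not part of the patch). It assumes a CogVideoXPipeline whose transformer already has PEFT LoRA adapters attached; the checkpoint id, save directory, and adapter name below are illustrative placeholders:

import torch
from diffusers import CogVideoXPipeline
from peft.utils import get_peft_model_state_dict

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16)

# Saving: only transformer LoRA layers are accepted now. Passing the removed
# `text_encoder_lora_layers` argument would raise a TypeError after this patch.
transformer_lora_layers = get_peft_model_state_dict(pipe.transformer)
pipe.save_lora_weights(
    save_directory="./cogvideox-lora",
    transformer_lora_layers=transformer_lora_layers,
)

# Loading: keys prefixed with "transformer." are routed into
# CogVideoXTransformer3DModel; "text_encoder." keys are no longer handled.
pipe.load_lora_weights("./cogvideox-lora", adapter_name="cogvideox-lora")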