From 13d73d9303583f430763357975fcb2398c009a50 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 21 Nov 2023 18:58:37 +0100 Subject: [PATCH] [Lora] Seperate logic (#5809) * [Lora] Seperate logic * [Lora] Seperate logic * [Lora] Seperate logic * add comments to explain the code better * add comments to explain the code better --- examples/dreambooth/train_dreambooth_lora.py | 35 ++++++- .../dreambooth/train_dreambooth_lora_sdxl.py | 35 ++++++- .../text_to_image/train_text_to_image_lora.py | 96 ++++++++++++++----- .../train_text_to_image_lora_sdxl.py | 35 ++++++- src/diffusers/loaders/__init__.py | 5 +- src/diffusers/loaders/lora.py | 37 ++++++- src/diffusers/models/lora.py | 28 ++---- 7 files changed, 218 insertions(+), 53 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 9250865a3a..b82dfa38c1 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -57,7 +57,7 @@ from diffusers.models.attention_processor import ( AttnAddedKVProcessor2_0, SlicedAttnAddedKVProcessor, ) -from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict +from diffusers.models.lora import LoRALinearLayer from diffusers.optimization import get_scheduler from diffusers.training_utils import unet_lora_state_dict from diffusers.utils import check_min_version, is_wandb_available @@ -70,6 +70,39 @@ check_min_version("0.24.0.dev0") logger = get_logger(__name__) +# TODO: This function should be removed once training scripts are rewritten in PEFT +def text_encoder_lora_state_dict(text_encoder): + state_dict = {} + + def text_encoder_attn_modules(text_encoder): + from transformers import CLIPTextModel, CLIPTextModelWithProjection + + attn_modules = [] + + if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): + for i, layer in enumerate(text_encoder.text_model.encoder.layers): + name = f"text_model.encoder.layers.{i}.self_attn" + mod = layer.self_attn + attn_modules.append((name, mod)) + + return attn_modules + + for name, module in text_encoder_attn_modules(text_encoder): + for k, v in module.q_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v + + for k, v in module.k_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v + + for k, v in module.v_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v + + for k, v in module.out_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v + + return state_dict + + def save_model_card( repo_id: str, images=None, diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 97b60c8f52..dd7b29ca88 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -50,7 +50,7 @@ from diffusers import ( UNet2DConditionModel, ) from diffusers.loaders import LoraLoaderMixin -from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict +from diffusers.models.lora import LoRALinearLayer from diffusers.optimization import get_scheduler from diffusers.training_utils import compute_snr, unet_lora_state_dict from diffusers.utils import check_min_version, is_wandb_available @@ -63,6 +63,39 @@ check_min_version("0.24.0.dev0") logger = get_logger(__name__) +# TODO: This function should be removed once 
training scripts are rewritten in PEFT +def text_encoder_lora_state_dict(text_encoder): + state_dict = {} + + def text_encoder_attn_modules(text_encoder): + from transformers import CLIPTextModel, CLIPTextModelWithProjection + + attn_modules = [] + + if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): + for i, layer in enumerate(text_encoder.text_model.encoder.layers): + name = f"text_model.encoder.layers.{i}.self_attn" + mod = layer.self_attn + attn_modules.append((name, mod)) + + return attn_modules + + for name, module in text_encoder_attn_modules(text_encoder): + for k, v in module.q_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v + + for k, v in module.k_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v + + for k, v in module.v_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v + + for k, v in module.out_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v + + return state_dict + + def save_model_card( repo_id: str, images=None, diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 78b443d149..b7309196de 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -40,8 +40,7 @@ from transformers import CLIPTextModel, CLIPTokenizer import diffusers from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel -from diffusers.loaders import AttnProcsLayers -from diffusers.models.attention_processor import LoRAAttnProcessor +from diffusers.models.lora import LoRALinearLayer from diffusers.optimization import get_scheduler from diffusers.training_utils import compute_snr from diffusers.utils import check_min_version, is_wandb_available @@ -54,6 +53,39 @@ check_min_version("0.24.0.dev0") logger = get_logger(__name__, log_level="INFO") +# TODO: This function should be removed once training scripts are rewritten in PEFT +def text_encoder_lora_state_dict(text_encoder): + state_dict = {} + + def text_encoder_attn_modules(text_encoder): + from transformers import CLIPTextModel, CLIPTextModelWithProjection + + attn_modules = [] + + if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): + for i, layer in enumerate(text_encoder.text_model.encoder.layers): + name = f"text_model.encoder.layers.{i}.self_attn" + mod = layer.self_attn + attn_modules.append((name, mod)) + + return attn_modules + + for name, module in text_encoder_attn_modules(text_encoder): + for k, v in module.q_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v + + for k, v in module.k_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v + + for k, v in module.v_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v + + for k, v in module.out_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v + + return state_dict + + def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None): img_str = "" for i, image in enumerate(images): @@ -458,25 +490,43 @@ def main(): # => 32 layers # Set correct lora layers - lora_attn_procs = {} - for name in unet.attn_processors.keys(): - cross_attention_dim = None if 
name.endswith("attn1.processor") else unet.config.cross_attention_dim - if name.startswith("mid_block"): - hidden_size = unet.config.block_out_channels[-1] - elif name.startswith("up_blocks"): - block_id = int(name[len("up_blocks.")]) - hidden_size = list(reversed(unet.config.block_out_channels))[block_id] - elif name.startswith("down_blocks"): - block_id = int(name[len("down_blocks.")]) - hidden_size = unet.config.block_out_channels[block_id] + unet_lora_parameters = [] + for attn_processor_name, attn_processor in unet.attn_processors.items(): + # Parse the attention module. + attn_module = unet + for n in attn_processor_name.split(".")[:-1]: + attn_module = getattr(attn_module, n) - lora_attn_procs[name] = LoRAAttnProcessor( - hidden_size=hidden_size, - cross_attention_dim=cross_attention_dim, - rank=args.rank, + # Set the `lora_layer` attribute of the attention-related matrices. + attn_module.to_q.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank + ) + ) + attn_module.to_k.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank + ) ) - unet.set_attn_processor(lora_attn_procs) + attn_module.to_v.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank + ) + ) + attn_module.to_out[0].set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_out[0].in_features, + out_features=attn_module.to_out[0].out_features, + rank=args.rank, + ) + ) + + # Accumulate the LoRA params to optimize. + unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters()) + unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters()) + unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters()) + unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters()) if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): @@ -491,8 +541,6 @@ def main(): else: raise ValueError("xformers is not available. Make sure it is installed correctly") - lora_layers = AttnProcsLayers(unet.attn_processors) - # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices if args.allow_tf32: @@ -517,7 +565,7 @@ def main(): optimizer_cls = torch.optim.AdamW optimizer = optimizer_cls( - lora_layers.parameters(), + unet_lora_parameters, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, @@ -644,8 +692,8 @@ def main(): ) # Prepare everything with our `accelerator`. - lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - lora_layers, optimizer, train_dataloader, lr_scheduler + unet_lora_parameters, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet_lora_parameters, optimizer, train_dataloader, lr_scheduler ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
@@ -777,7 +825,7 @@ def main(): # Backpropagate accelerator.backward(loss) if accelerator.sync_gradients: - params_to_clip = lora_layers.parameters() + params_to_clip = unet_lora_parameters accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() lr_scheduler.step() diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index bff928541f..96bfe9e167 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -50,7 +50,7 @@ from diffusers import ( UNet2DConditionModel, ) from diffusers.loaders import LoraLoaderMixin -from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict +from diffusers.models.lora import LoRALinearLayer from diffusers.optimization import get_scheduler from diffusers.training_utils import compute_snr from diffusers.utils import check_min_version, is_wandb_available @@ -63,6 +63,39 @@ check_min_version("0.24.0.dev0") logger = get_logger(__name__) +# TODO: This function should be removed once training scripts are rewritten in PEFT +def text_encoder_lora_state_dict(text_encoder): + state_dict = {} + + def text_encoder_attn_modules(text_encoder): + from transformers import CLIPTextModel, CLIPTextModelWithProjection + + attn_modules = [] + + if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): + for i, layer in enumerate(text_encoder.text_model.encoder.layers): + name = f"text_model.encoder.layers.{i}.self_attn" + mod = layer.self_attn + attn_modules.append((name, mod)) + + return attn_modules + + for name, module in text_encoder_attn_modules(text_encoder): + for k, v in module.q_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v + + for k, v in module.k_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v + + for k, v in module.v_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v + + for k, v in module.out_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v + + return state_dict + + def save_model_card( repo_id: str, images=None, diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 6847368560..45c8c97c76 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -8,7 +8,7 @@ def text_encoder_lora_state_dict(text_encoder): deprecate( "text_encoder_load_state_dict in `models`", "0.27.0", - "`text_encoder_lora_state_dict` has been moved to `diffusers.models.lora`. Please make sure to import it via `from diffusers.models.lora import text_encoder_lora_state_dict`.", + "`text_encoder_lora_state_dict` is deprecated and will be removed in 0.27.0. Make sure to retrieve the weights using `get_peft_model`. See https://huggingface.co/docs/peft/v0.6.2/en/quicktour#peftmodel for more information.", ) state_dict = {} @@ -34,7 +34,7 @@ if is_transformers_available(): deprecate( "text_encoder_attn_modules in `models`", "0.27.0", - "`text_encoder_lora_state_dict` has been moved to `diffusers.models.lora`. Please make sure to import it via `from diffusers.models.lora import text_encoder_lora_state_dict`.", + "`text_encoder_lora_state_dict` is deprecated and will be removed in 0.27.0. Make sure to retrieve the weights using `get_peft_model`. 
See https://huggingface.co/docs/peft/v0.6.2/en/quicktour#peftmodel for more information.", ) from transformers import CLIPTextModel, CLIPTextModelWithProjection @@ -67,7 +67,6 @@ if is_torch_available(): if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if is_torch_available(): - from ..models.lora import text_encoder_lora_state_dict from .single_file import FromOriginalControlnetMixin, FromOriginalVAEMixin from .unet import UNet2DConditionLoadersMixin from .utils import AttnProcsLayers diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index ab5d0ffd01..06eb3af05e 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -47,9 +47,10 @@ from ..utils import ( if is_transformers_available(): - from transformers import PreTrainedModel + from transformers import CLIPTextModel, CLIPTextModelWithProjection - from ..models.lora import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules + # To be deprecated soon + from ..models.lora import PatchedLoraProjection if is_accelerate_available(): from accelerate import init_empty_weights @@ -66,6 +67,34 @@ LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" LORA_DEPRECATION_MESSAGE = "You are using an old version of LoRA backend. This will be deprecated in the next releases in favor of PEFT make sure to install the latest PEFT and transformers packages in the future." +def text_encoder_attn_modules(text_encoder): + attn_modules = [] + + if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): + for i, layer in enumerate(text_encoder.text_model.encoder.layers): + name = f"text_model.encoder.layers.{i}.self_attn" + mod = layer.self_attn + attn_modules.append((name, mod)) + else: + raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}") + + return attn_modules + + +def text_encoder_mlp_modules(text_encoder): + mlp_modules = [] + + if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): + for i, layer in enumerate(text_encoder.text_model.encoder.layers): + mlp_mod = layer.mlp + name = f"text_model.encoder.layers.{i}.mlp" + mlp_modules.append((name, mlp_mod)) + else: + raise ValueError(f"do not know how to get mlp modules for: {text_encoder.__class__.__name__}") + + return mlp_modules + + class LoraLoaderMixin: r""" Load LoRA layers into [`UNet2DConditionModel`] and [`~transformers.CLIPTextModel`]. @@ -1415,7 +1444,7 @@ class LoraLoaderMixin: ) set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights) - def disable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None): + def disable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None): # noqa: F821 """ Disable the text encoder's LoRA layers. @@ -1445,7 +1474,7 @@ class LoraLoaderMixin: raise ValueError("Text Encoder not found.") set_adapter_layers(text_encoder, enabled=False) - def enable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None): + def enable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None): # noqa: F821 """ Enables the text encoder's LoRA layers. diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index 9edec19a3a..daac8f902c 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -12,6 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+ +# IMPORTANT: # +################################################################### +# ----------------------------------------------------------------# +# This file is deprecated and will be removed soon # +# (as soon as PEFT will become a required dependency for LoRA) # +# ----------------------------------------------------------------# +################################################################### + from typing import Optional, Tuple, Union import torch @@ -57,25 +66,6 @@ def text_encoder_mlp_modules(text_encoder): return mlp_modules -def text_encoder_lora_state_dict(text_encoder): - state_dict = {} - - for name, module in text_encoder_attn_modules(text_encoder): - for k, v in module.q_proj.lora_linear_layer.state_dict().items(): - state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v - - for k, v in module.k_proj.lora_linear_layer.state_dict().items(): - state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v - - for k, v in module.v_proj.lora_linear_layer.state_dict().items(): - state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v - - for k, v in module.out_proj.lora_linear_layer.state_dict().items(): - state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v - - return state_dict - - def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0): for _, attn_module in text_encoder_attn_modules(text_encoder): if isinstance(attn_module.q_proj, PatchedLoraProjection):
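
For reference, the wiring pattern that replaces the old `LoRAAttnProcessor` / `AttnProcsLayers` setup in train_text_to_image_lora.py can be exercised on its own roughly as follows. This is a minimal sketch, assuming diffusers at the 0.24.0.dev0 state of this patch; the model id, the `rank` value, and the optimizer hyperparameters are illustrative stand-ins, not values taken from the scripts.

import torch
from diffusers import UNet2DConditionModel
from diffusers.models.lora import LoRALinearLayer

# Illustrative model id; the training scripts take this from args.pretrained_model_name_or_path.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)
rank = 4  # illustrative; the scripts read this from args.rank

unet_lora_parameters = []
for attn_processor_name, attn_processor in unet.attn_processors.items():
    # Walk from the processor key (e.g. "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor")
    # to the attention module that owns the to_q/to_k/to_v/to_out projections.
    attn_module = unet
    for n in attn_processor_name.split(".")[:-1]:
        attn_module = getattr(attn_module, n)

    # Attach a LoRA layer to each projection; set_lora_layer stores it on the
    # LoRA-compatible linear so its forward pass adds the low-rank update on top
    # of the frozen base weight.
    for proj in (attn_module.to_q, attn_module.to_k, attn_module.to_v, attn_module.to_out[0]):
        proj.set_lora_layer(
            LoRALinearLayer(in_features=proj.in_features, out_features=proj.out_features, rank=rank)
        )
        # Accumulate only the LoRA params for the optimizer, as the patched script does.
        unet_lora_parameters.extend(proj.lora_layer.parameters())

optimizer = torch.optim.AdamW(unet_lora_parameters, lr=1e-4)

Collecting `unet_lora_parameters` directly is what lets the script drop the `AttnProcsLayers(unet.attn_processors)` wrapper before `accelerator.prepare`, clipping, and optimizer construction. The copies of `text_encoder_lora_state_dict` added to each training script are a stop-gap: per the TODO comments, they disappear once the training scripts are rewritten on top of PEFT, at which point the deprecated helpers in `diffusers.models.lora` and `diffusers.loaders` can be removed as well.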