From ea1fcc28a458739771f5112767f70d281511d2a2 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 4 Aug 2023 20:06:38 +0200 Subject: [PATCH] [SDXL] Allow SDXL LoRA to be run with less than 16GB of VRAM (#4470) * correct * correct blocks * finish * finish * finish * Apply suggestions from code review * fix * up * up * up * Update examples/dreambooth/README_sdxl.md Co-authored-by: Sayak Paul * Apply suggestions from code review --------- Co-authored-by: Sayak Paul --- examples/controlnet/train_controlnet_sdxl.py | 1 + examples/dreambooth/README_sdxl.md | 18 +++++++ examples/dreambooth/train_dreambooth_lora.py | 5 ++ .../dreambooth/train_dreambooth_lora_sdxl.py | 9 +++- src/diffusers/models/transformer_2d.py | 33 ++++++++---- src/diffusers/models/unet_2d_blocks.py | 50 ++++++++++++++---- src/diffusers/models/unet_2d_condition.py | 6 +-- .../versatile_diffusion/modeling_text_unet.py | 52 +++++++++++++++---- tests/models/test_models_unet_2d_condition.py | 36 +++++++++++++ 9 files changed, 175 insertions(+), 35 deletions(-) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index 6be07a3805..cf53258027 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -899,6 +899,7 @@ def main(args): if args.gradient_checkpointing: controlnet.enable_gradient_checkpointing() + unet.enable_gradient_checkpointing() # Check that all trainable models are in full precision low_precision_error_string = ( diff --git a/examples/dreambooth/README_sdxl.md b/examples/dreambooth/README_sdxl.md index c1a7b73345..65481dfcc3 100644 --- a/examples/dreambooth/README_sdxl.md +++ b/examples/dreambooth/README_sdxl.md @@ -101,6 +101,24 @@ To better track our training experiments, we're using the following flags in the Our experiments were conducted on a single 40GB A100 GPU. +### Dog toy example with < 16GB VRAM + +By making use of [`gradient_checkpointing`](https://pytorch.org/docs/stable/checkpoint.html) (which is natively supported in Diffusers), [`xformers`](https://github.com/facebookresearch/xformers), and [`bitsandbytes`](https://github.com/TimDettmers/bitsandbytes) libraries, you can train SDXL LoRAs with less than 16GB of VRAM by adding the following flags to your accelerate launch command: + +```diff ++ --enable_xformers_memory_efficient_attention \ ++ --gradient_checkpointing \ ++ --use_8bit_adam \ ++ --mixed_precision="fp16" \ +``` + +and making sure that you have the following libraries installed: + +``` +bitsandbytes>=0.40.0 +xformers>=0.0.20 +``` + ### Inference Once training is done, we can perform inference like so: diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index a7d135b1f2..72d4ab77e0 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -839,6 +839,11 @@ def main(args): else: raise ValueError("xformers is not available. 
Make sure it is installed correctly") + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + if args.train_text_encoder: + text_encoder.gradient_checkpointing_enable() + # now we will add new LoRA weights to the attention layers # It's important to realize here how many attention weights will be added and of which sizes # The sizes of the attention layers consist only of two different variables: diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 423dee1568..a2b6e4a382 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -215,7 +215,7 @@ def parse_args(input_args=None): parser.add_argument( "--resolution", type=int, - default=512, + default=1024, help=( "The resolution for input images, all the images in the train/validation dataset will be resized to this" " resolution" @@ -644,7 +644,6 @@ def main(args): pipeline = StableDiffusionXLPipeline.from_pretrained( args.pretrained_model_name_or_path, torch_dtype=torch_dtype, - safety_checker=None, revision=args.revision, ) pipeline.set_progress_bar_config(disable=True) @@ -755,6 +754,12 @@ def main(args): else: raise ValueError("xformers is not available. Make sure it is installed correctly") + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + if args.train_text_encoder: + text_encoder_one.gradient_checkpointing_enable() + text_encoder_two.gradient_checkpointing_enable() + # now we will add new LoRA weights to the attention layers # Set correct lora layers unet_lora_attn_procs = {} diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index 998535c58a..344a9441ce 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -204,6 +204,8 @@ class Transformer2DModel(ModelMixin, ConfigMixin): self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + self.gradient_checkpointing = False + def forward( self, hidden_states: torch.Tensor, @@ -289,15 +291,28 @@ class Transformer2DModel(ModelMixin, ConfigMixin): # 2. Blocks for block in self.transformer_blocks: - hidden_states = block( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - timestep=timestep, - cross_attention_kwargs=cross_attention_kwargs, - class_labels=class_labels, - ) + if self.training and self.gradient_checkpointing: + hidden_states = torch.utils.checkpoint.checkpoint( + block, + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + cross_attention_kwargs, + class_labels, + use_reentrant=False, + ) + else: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) # 3. 
Output if self.is_input_continuous: diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 8d7e864dfc..6f3037d624 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -623,6 +623,8 @@ class UNetMidBlock2DCrossAttn(nn.Module): self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) + self.gradient_checkpointing = False + def forward( self, hidden_states: torch.FloatTensor, @@ -634,15 +636,45 @@ class UNetMidBlock2DCrossAttn(nn.Module): ) -> torch.FloatTensor: hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - hidden_states = resnet(hidden_states, temb) + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + **ckpt_kwargs, + )[0] + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = resnet(hidden_states, temb) return hidden_states diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index cede2ed9d3..fea1b4cd78 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -36,12 +36,8 @@ from .embeddings import ( ) from .modeling_utils import ModelMixin from .unet_2d_blocks import ( - CrossAttnDownBlock2D, - CrossAttnUpBlock2D, - DownBlock2D, UNetMidBlock2DCrossAttn, UNetMidBlock2DSimpleCrossAttn, - UpBlock2D, get_down_block, get_up_block, ) @@ -694,7 +690,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) fn_recursive_set_attention_slice(module, reversed_slice_size) def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)): + if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value def forward( diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 7a69a7908e..adb41a8dfd 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -800,7 +800,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): fn_recursive_set_attention_slice(module, reversed_slice_size) def _set_gradient_checkpointing(self, module, value=False): - 
if isinstance(module, (CrossAttnDownBlockFlat, DownBlockFlat, CrossAttnUpBlockFlat, UpBlockFlat)): + if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value def forward( @@ -1784,6 +1784,8 @@ class UNetMidBlockFlatCrossAttn(nn.Module): self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) + self.gradient_checkpointing = False + def forward( self, hidden_states: torch.FloatTensor, @@ -1795,15 +1797,45 @@ class UNetMidBlockFlatCrossAttn(nn.Module): ) -> torch.FloatTensor: hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - hidden_states = resnet(hidden_states, temb) + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + **ckpt_kwargs, + )[0] + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = resnet(hidden_states, temb) return hidden_states diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py index 4eeb1b926b..bd0a89fcef 100644 --- a/tests/models/test_models_unet_2d_condition.py +++ b/tests/models/test_models_unet_2d_condition.py @@ -364,6 +364,42 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test for module in model.children(): check_sliceable_dim_attr(module) + def test_gradient_checkpointing_is_applied(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = (8, 16) + + model_class_copy = copy.copy(self.model_class) + + modules_with_gc_enabled = {} + + # now monkey patch the following function: + # def _set_gradient_checkpointing(self, module, value=False): + # if hasattr(module, "gradient_checkpointing"): + # module.gradient_checkpointing = value + + def _set_gradient_checkpointing_new(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + modules_with_gc_enabled[module.__class__.__name__] = True + + model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new + + model = model_class_copy(**init_dict) + model.enable_gradient_checkpointing() + + EXPECTED_SET = { + "CrossAttnUpBlock2D", + "CrossAttnDownBlock2D", + "UNetMidBlock2DCrossAttn", + "UpBlock2D", + "Transformer2DModel", + "DownBlock2D", + } + + assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET + assert 
all(modules_with_gc_enabled.values()), "All modules should be enabled" + def test_special_attn_proc(self): class AttnEasyProc(torch.nn.Module): def __init__(self, num):
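
For readers tracing the patch, below is a minimal, self-contained sketch (not part of the diff) of the gradient-checkpointing pattern the PR threads through `Transformer2DModel`, `UNetMidBlock2DCrossAttn`, and `UNetMidBlockFlatCrossAttn`: a per-module `gradient_checkpointing` flag that the hasattr-based `_set_gradient_checkpointing` hook can flip, plus a `torch.utils.checkpoint.checkpoint` call with `use_reentrant=False` (PyTorch >= 1.11) during training. `ToyBlock` and its layer sizes are invented purely for illustration and are not diffusers classes.

```python
# Illustrative sketch only; ToyBlock is a made-up module, not a diffusers class.
import torch
import torch.nn as nn
import torch.utils.checkpoint


class ToyBlock(nn.Module):
    def __init__(self, dim: int = 64):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(4)])
        # Flag flipped by a hasattr-based `_set_gradient_checkpointing`, as in the patch.
        self.gradient_checkpointing = False

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            if self.training and self.gradient_checkpointing:
                # Wrap the sub-module so plain positional args pass through checkpoint().
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                # Activations inside `layer` are dropped in the forward pass and
                # recomputed during backward, trading compute for memory.
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer),
                    hidden_states,
                    use_reentrant=False,  # the patch gates this kwarg on torch >= 1.11.0
                )
            else:
                hidden_states = layer(hidden_states)
        return hidden_states


block = ToyBlock().train()
# Mirrors what enable_gradient_checkpointing() does per module via _set_gradient_checkpointing.
block.gradient_checkpointing = True
out = block(torch.randn(2, 64))
out.sum().backward()  # recomputation of the checkpointed activations happens here
```

The same trade-off is what makes the "< 16GB VRAM" recipe added to `README_sdxl.md` work: gradient checkpointing shrinks the stored activations at the cost of extra compute in the backward pass, while `--enable_xformers_memory_efficient_attention` reduces attention memory and `--use_8bit_adam` shrinks the optimizer state.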