1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[SDXL] Allow SDXL LoRA to be run with less than 16GB of VRAM (#4470)

* correct

* correct blocks

* finish

* finish

* finish

* Apply suggestions from code review

* fix

* up

* up

* up

* Update examples/dreambooth/README_sdxl.md

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Apply suggestions from code review

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
Patrick von Platen
2023-08-04 20:06:38 +02:00
committed by GitHub
parent 66de221409
commit ea1fcc28a4
9 changed files with 175 additions and 35 deletions

View File

@@ -899,6 +899,7 @@ def main(args):
if args.gradient_checkpointing:
controlnet.enable_gradient_checkpointing()
unet.enable_gradient_checkpointing()
# Check that all trainable models are in full precision
low_precision_error_string = (

View File

@@ -101,6 +101,24 @@ To better track our training experiments, we're using the following flags in the
Our experiments were conducted on a single 40GB A100 GPU.
### Dog toy example with < 16GB VRAM
By making use of [`gradient_checkpointing`](https://pytorch.org/docs/stable/checkpoint.html) (which is natively supported in Diffusers), [`xformers`](https://github.com/facebookresearch/xformers), and [`bitsandbytes`](https://github.com/TimDettmers/bitsandbytes) libraries, you can train SDXL LoRAs with less than 16GB of VRAM by adding the following flags to your accelerate launch command:
```diff
+ --enable_xformers_memory_efficient_attention \
+ --gradient_checkpointing \
+ --use_8bit_adam \
+ --mixed_precision="fp16" \
```
and making sure that you have the following libraries installed:
```
bitsandbytes>=0.40.0
xformers>=0.0.20
```
### Inference
Once training is done, we can perform inference like so:

View File

@@ -839,6 +839,11 @@ def main(args):
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
if args.train_text_encoder:
text_encoder.gradient_checkpointing_enable()
# now we will add new LoRA weights to the attention layers
# It's important to realize here how many attention weights will be added and of which sizes
# The sizes of the attention layers consist only of two different variables:

View File

@@ -215,7 +215,7 @@ def parse_args(input_args=None):
parser.add_argument(
"--resolution",
type=int,
default=512,
default=1024,
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
@@ -644,7 +644,6 @@ def main(args):
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path,
torch_dtype=torch_dtype,
safety_checker=None,
revision=args.revision,
)
pipeline.set_progress_bar_config(disable=True)
@@ -755,6 +754,12 @@ def main(args):
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
if args.train_text_encoder:
text_encoder_one.gradient_checkpointing_enable()
text_encoder_two.gradient_checkpointing_enable()
# now we will add new LoRA weights to the attention layers
# Set correct lora layers
unet_lora_attn_procs = {}

View File

@@ -204,6 +204,8 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
@@ -289,15 +291,28 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
# 2. Blocks
for block in self.transformer_blocks:
hidden_states = block(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=class_labels,
)
if self.training and self.gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(
block,
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
timestep,
cross_attention_kwargs,
class_labels,
use_reentrant=False,
)
else:
hidden_states = block(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=class_labels,
)
# 3. Output
if self.is_input_continuous:

View File

@@ -623,6 +623,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.FloatTensor,
@@ -634,15 +636,45 @@ class UNetMidBlock2DCrossAttn(nn.Module):
) -> torch.FloatTensor:
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = resnet(hidden_states, temb)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
None, # timestep
None, # class_labels
cross_attention_kwargs,
attention_mask,
encoder_attention_mask,
**ckpt_kwargs,
)[0]
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
else:
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = resnet(hidden_states, temb)
return hidden_states

View File

@@ -36,12 +36,8 @@ from .embeddings import (
)
from .modeling_utils import ModelMixin
from .unet_2d_blocks import (
CrossAttnDownBlock2D,
CrossAttnUpBlock2D,
DownBlock2D,
UNetMidBlock2DCrossAttn,
UNetMidBlock2DSimpleCrossAttn,
UpBlock2D,
get_down_block,
get_up_block,
)
@@ -694,7 +690,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
fn_recursive_set_attention_slice(module, reversed_slice_size)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(

View File

@@ -800,7 +800,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
fn_recursive_set_attention_slice(module, reversed_slice_size)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CrossAttnDownBlockFlat, DownBlockFlat, CrossAttnUpBlockFlat, UpBlockFlat)):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
@@ -1784,6 +1784,8 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.FloatTensor,
@@ -1795,15 +1797,45 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
) -> torch.FloatTensor:
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = resnet(hidden_states, temb)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
None, # timestep
None, # class_labels
cross_attention_kwargs,
attention_mask,
encoder_attention_mask,
**ckpt_kwargs,
)[0]
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
else:
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = resnet(hidden_states, temb)
return hidden_states

View File

@@ -364,6 +364,42 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
for module in model.children():
check_sliceable_dim_attr(module)
def test_gradient_checkpointing_is_applied(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = (8, 16)
model_class_copy = copy.copy(self.model_class)
modules_with_gc_enabled = {}
# now monkey patch the following function:
# def _set_gradient_checkpointing(self, module, value=False):
# if hasattr(module, "gradient_checkpointing"):
# module.gradient_checkpointing = value
def _set_gradient_checkpointing_new(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
modules_with_gc_enabled[module.__class__.__name__] = True
model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
model = model_class_copy(**init_dict)
model.enable_gradient_checkpointing()
EXPECTED_SET = {
"CrossAttnUpBlock2D",
"CrossAttnDownBlock2D",
"UNetMidBlock2DCrossAttn",
"UpBlock2D",
"Transformer2DModel",
"DownBlock2D",
}
assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET
assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
def test_special_attn_proc(self):
class AttnEasyProc(torch.nn.Module):
def __init__(self, num):