
[SD 3.5 Dreambooth LoRA] support configurable training block & layers (#9762)

* configurable layers

* configurable layers

* update README

* style

* add test

* style

* add layer test, update readme, add nargs

* readme

* test style

* remove print, change nargs

* test arg change

* style

* revert nargs 2/2

* address sayaks comments

* style

* address sayaks comments
Linoy Tsaban
2024-10-28 16:07:54 +02:00
committed by GitHub
parent 493aa74312
commit db5b6a9630
3 changed files with 143 additions and 1 deletion

examples/dreambooth/README_SD3.md

@@ -147,6 +147,40 @@ accelerate launch train_dreambooth_lora_sd3.py \
--push_to_hub
```
### Targeting Specific Blocks & Layers
As image generation models get bigger and more powerful, many fine-tuners find that training only a subset of the
transformer blocks (sometimes as few as two) can be enough to get great results.
In some cases, it can be even better to keep some of the blocks/layers frozen.
For **SD3.5-Large** specifically, you may find this information useful (taken from the [Stable Diffusion 3.5 Large Fine-tuning Tutorial](https://stabilityai.notion.site/Stable-Diffusion-3-5-Large-Fine-tuning-Tutorial-11a61cdcd1968027a15bdbd7c40be8c6#12461cdcd19680788a23c650dab26b93)):
> [!NOTE]
> A commonly believed heuristic that we verified once again during the construction of the SD3.5 family of models is that later/higher layers (i.e. `30 - 37`)* impact tertiary details more heavily. Conversely, earlier layers (i.e. `12 - 24` )* influence the overall composition/primary form more.
> So, freezing other layers/targeting specific layers is a viable approach.
> `*`These suggested layers are speculative and not 100% guaranteed. The tips here are more or less a general idea for next steps.
> **Photorealism**
> In preliminary testing, we observed that freezing the last few layers of the architecture significantly improved model training when using a photorealistic dataset, preventing the detail degradation introduced by a small dataset.
> **Anatomy preservation**
> To dampen any possible degradation of anatomy, training only the attention layers and **not** the adaptive linear layers could help. For reference, below is one of the transformer blocks.
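The tutorial shows a diagram of a single transformer block at this point, which isn't reproduced here. As a rough substitute, the sketch below (assuming you have access to the SD3.5 weights, or the tiny `hf-internal-testing/tiny-sd3-pipe` checkpoint used by the example tests) prints the submodules of one block so you can see which names are attention projections and which hold the adaptive linear layers:

```python
from diffusers import SD3Transformer2DModel

# Tiny SD3 test checkpoint used by the example tests; swap in
# "stabilityai/stable-diffusion-3.5-large" (gated) to inspect the full model.
transformer = SD3Transformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-sd3-pipe", subfolder="transformer"
)

# The `attn.*` entries are the attention projections; the adaptive linear
# layers mentioned above live inside the `norm*` modules.
for name, module in transformer.transformer_blocks[0].named_modules():
    if name:
        print(f"{name}: {module.__class__.__name__}")
```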
We've added `--lora_layers` and `--lora_blocks` to make LoRA training modules configurable.
- with `--lora_blocks` you can specify the block numbers for training. E.g. passing -
```diff
--lora_blocks "12,13,14,15,16,17,18,19,20,21,22,23,24,30,31,32,33,34,35,36,37"
```
will trigger LoRA training of transformer blocks 12-24 and 30-37. By default, all blocks are trained.
- with `--lora_layers` you can specify the types of layers you wish to train.
By default, the trained layers are -
`attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,attn.to_k,attn.to_out.0,attn.to_q,attn.to_v`
If you wish to have a leaner LoRA, or to train more blocks rather than more layer types, you could pass -
```diff
+ --lora_layers attn.to_k,attn.to_q,attn.to_v,attn.to_out.0
```
This will reduce LoRA size by roughly 50% for the same rank compared to the default.
However, if you're after compact LoRAs, it's our impression that keeping the default setting for `--lora_layers` and
freezing some of the early blocks is usually better.
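
Under the hood, the script turns these two flags into the `target_modules` list passed to PEFT's `LoraConfig` (see the `train_dreambooth_lora_sd3.py` changes below). A minimal sketch of that mapping, with illustrative flag values:

```python
def build_target_modules(lora_layers=None, lora_blocks=None):
    # Mirrors the logic added to train_dreambooth_lora_sd3.py: start from the
    # requested (or default) layer types, then optionally scope them to blocks.
    if lora_layers is not None:
        target_modules = [layer.strip() for layer in lora_layers.split(",")]
    else:
        target_modules = [
            "attn.add_k_proj", "attn.add_q_proj", "attn.add_v_proj", "attn.to_add_out",
            "attn.to_k", "attn.to_out.0", "attn.to_q", "attn.to_v",
        ]
    if lora_blocks is not None:
        blocks = [int(block.strip()) for block in lora_blocks.split(",")]
        target_modules = [f"transformer_blocks.{b}.{m}" for b in blocks for m in target_modules]
    return target_modules


print(build_target_modules(lora_layers="attn.to_k,attn.to_q", lora_blocks="30,31"))
# ['transformer_blocks.30.attn.to_k', 'transformer_blocks.30.attn.to_q',
#  'transformer_blocks.31.attn.to_k', 'transformer_blocks.31.attn.to_q']
```

PEFT matches these strings against the ends of the transformer's module names, so only the listed projections in the listed blocks receive LoRA adapters.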
### Text Encoder Training
Alongside the transformer, LoRA fine-tuning of the CLIP text encoders is now also supported.
To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind:

examples/dreambooth/test_dreambooth_lora_sd3.py

@@ -38,6 +38,9 @@ class DreamBoothLoRASD3(ExamplesTestsAccelerate):
    pretrained_model_name_or_path = "hf-internal-testing/tiny-sd3-pipe"
    script_path = "examples/dreambooth/train_dreambooth_lora_sd3.py"
    transformer_block_idx = 0
    layer_type = "attn.to_k"

    def test_dreambooth_lora_sd3(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
@@ -136,6 +139,74 @@ class DreamBoothLoRASD3(ExamplesTestsAccelerate):
            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_block(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --lora_blocks {self.transformer_block_idx}
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"transformer"` in their names.
            # In this test, only params of transformer block 0 should be in the state dict.
            starts_with_transformer = all(
                key.startswith("transformer.transformer_blocks.0") for key in lora_state_dict.keys()
            )
            self.assertTrue(starts_with_transformer)
    def test_dreambooth_lora_layer(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --lora_layers {self.layer_type}
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # In this test, only transformer params of attention layers `attn.to_k` should be in the state dict.
            starts_with_transformer = all("attn.to_k" in key for key in lora_state_dict.keys())
            self.assertTrue(starts_with_transformer)
    def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""

examples/dreambooth/train_dreambooth_lora_sd3.py

@@ -571,6 +571,25 @@ def parse_args(input_args=None):
"--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
)
parser.add_argument(
"--lora_layers",
type=str,
default=None,
help=(
"The transformer block layers to apply LoRA training on. Please specify the layers in a comma seperated string."
"For examples refer to https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_SD3.md"
),
)
parser.add_argument(
"--lora_blocks",
type=str,
default=None,
help=(
"The transformer blocks to apply LoRA training on. Please specify the block numbers in a comma seperated manner."
'E.g. - "--lora_blocks 12,30" will result in lora training of transformer blocks 12 and 30. For more examples refer to https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_SD3.md'
),
)
parser.add_argument(
"--adam_epsilon",
type=float,
@@ -1222,13 +1241,31 @@ def main(args):
        if args.train_text_encoder:
            text_encoder_one.gradient_checkpointing_enable()
            text_encoder_two.gradient_checkpointing_enable()

    if args.lora_layers is not None:
        target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
    else:
        target_modules = [
            "attn.add_k_proj",
            "attn.add_q_proj",
            "attn.add_v_proj",
            "attn.to_add_out",
            "attn.to_k",
            "attn.to_out.0",
            "attn.to_q",
            "attn.to_v",
        ]
    if args.lora_blocks is not None:
        target_blocks = [int(block.strip()) for block in args.lora_blocks.split(",")]
        target_modules = [
            f"transformer_blocks.{block}.{module}" for block in target_blocks for module in target_modules
        ]

    # now we will add new LoRA weights to the attention layers
    transformer_lora_config = LoraConfig(
        r=args.rank,
        lora_alpha=args.rank,
        init_lora_weights="gaussian",
-        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+        target_modules=target_modules,
    )
    transformer.add_adapter(transformer_lora_config)
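
As a quick sanity check that block-scoped `target_modules` behave as intended, here is a small sketch (using the tiny `hf-internal-testing/tiny-sd3-pipe` checkpoint from the tests above; the flag values and module names are illustrative) that adds an adapter and lists which parameters gained LoRA weights:

```python
from diffusers import SD3Transformer2DModel
from peft import LoraConfig

transformer = SD3Transformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-sd3-pipe", subfolder="transformer"
)
config = LoraConfig(
    r=4,
    lora_alpha=4,
    init_lora_weights="gaussian",
    # equivalent to --lora_blocks 0 --lora_layers attn.to_k,attn.to_q
    target_modules=["transformer_blocks.0.attn.to_k", "transformer_blocks.0.attn.to_q"],
)
transformer.add_adapter(config)

# Only the block-0 k/q projections should now carry LoRA parameters.
print([name for name, _ in transformer.named_parameters() if "lora" in name])
```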