[SD 3.5 Dreambooth LoRA] support configurable training block & layers (#9762)
* configurable layers
* configurable layers
* update README
* style
* add test
* style
* add layer test, update readme, add nargs
* readme
* test style
* remove print, change nargs
* test arg change
* style
* revert nargs 2/2
* address sayaks comments
* style
* address sayaks comments
@@ -147,6 +147,40 @@ accelerate launch train_dreambooth_lora_sd3.py \
  --push_to_hub
```

### Targeting Specific Blocks & Layers

As image generation models get bigger & more powerful, more fine-tuners are finding that training only part of the
transformer blocks (sometimes as few as two) can be enough to get great results.
In some cases, it can be even better to keep some of the blocks/layers frozen.

For **SD3.5-Large** specifically, you may find this information useful (taken from: [Stable Diffusion 3.5 Large Fine-tuning Tutorial](https://stabilityai.notion.site/Stable-Diffusion-3-5-Large-Fine-tuning-Tutorial-11a61cdcd1968027a15bdbd7c40be8c6)):

> [!NOTE]
> A commonly believed heuristic that we verified once again during the construction of the SD3.5 family of models is that later/higher layers (i.e. `30 - 37`)* impact tertiary details more heavily. Conversely, earlier layers (i.e. `12 - 24`)* influence the overall composition/primary form more.
>
> So, freezing other layers/targeting specific layers is a viable approach.
>
> `*`These suggested layers are speculative and not 100% guaranteed. The tips here are more or less a general idea for next steps.
>
> **Photorealism**
> In preliminary testing, we observed that freezing the last few layers of the architecture significantly improved model training when using a photorealistic dataset, preventing the detail degradation introduced by a small dataset.
>
> **Anatomy preservation**
> To dampen any possible degradation of anatomy, training only the attention layers and **not** the adaptive linear layers could help. For reference, below is one of the transformer blocks.

We've added `--lora_layers` and `--lora_blocks` to make LoRA training modules configurable.

- with `--lora_blocks` you can specify the block numbers for training. E.g. passing -

```diff
--lora_blocks "12,13,14,15,16,17,18,19,20,21,22,23,24,30,31,32,33,34,35,36,37"
```

will trigger LoRA training of transformer blocks 12-24 and 30-37. By default, all blocks are trained.
- with `--lora_layers` you can specify the types of layers you wish to train.
By default, the trained layers are -
`attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,attn.to_k,attn.to_out.0,attn.to_q,attn.to_v`
If you wish to have a leaner LoRA / prefer training more blocks over more layer types, you could pass -

```diff
+ --lora_layers attn.to_k,attn.to_q,attn.to_v,attn.to_out.0
```

This will reduce LoRA size by roughly 50% for the same rank compared to the default.
However, if you're after compact LoRAs, it's our impression that maintaining the default setting for `--lora_layers` and
freezing some of the early & later blocks is usually better.

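For example, the two flags can be combined in a single launch command. The sketch below is illustrative only: the model id, data directory, prompt, output directory, and hyperparameter values are placeholders you would replace with your own; only `--lora_blocks` and `--lora_layers` are the point here, following the SD3.5-Large notes above.

```bash
# Illustrative sketch: train LoRA only on the attention projections of blocks 12-24,
# keeping every other block/layer frozen. Paths, prompt, and hyperparameters are placeholders.
accelerate launch train_dreambooth_lora_sd3.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-3.5-large" \
  --instance_data_dir="dog" \
  --instance_prompt="a photo of sks dog" \
  --output_dir="sd35-lora-blocks-12-24" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-4 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --lora_blocks="12,13,14,15,16,17,18,19,20,21,22,23,24" \
  --lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"
```
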
### Text Encoder Training

Alongside the transformer, LoRA fine-tuning of the CLIP text encoders is now also supported.
To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind:

@@ -38,6 +38,9 @@ class DreamBoothLoRASD3(ExamplesTestsAccelerate):
    pretrained_model_name_or_path = "hf-internal-testing/tiny-sd3-pipe"
    script_path = "examples/dreambooth/train_dreambooth_lora_sd3.py"

    transformer_block_idx = 0
    layer_type = "attn.to_k"

    def test_dreambooth_lora_sd3(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""

@@ -136,6 +139,74 @@ class DreamBoothLoRASD3(ExamplesTestsAccelerate):
            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_block(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --lora_blocks {self.transformer_block_idx}
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"transformer"` in their names.
            # In this test, only params of transformer block 0 should be in the state dict
            starts_with_transformer = all(
                key.startswith("transformer.transformer_blocks.0") for key in lora_state_dict.keys()
            )
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_layer(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --lora_layers {self.layer_type}
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # In this test, only transformer params of attention layers `attn.to_k` should be in the state dict
            starts_with_transformer = all("attn.to_k" in key for key in lora_state_dict.keys())
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""

@@ -571,6 +571,25 @@ def parse_args(input_args=None):
        "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
    )

    parser.add_argument(
        "--lora_layers",
        type=str,
        default=None,
        help=(
            "The transformer block layers to apply LoRA training on. Please specify the layers in a comma separated string. "
            "For examples refer to https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_SD3.md"
        ),
    )
    parser.add_argument(
        "--lora_blocks",
        type=str,
        default=None,
        help=(
            "The transformer blocks to apply LoRA training on. Please specify the block numbers in a comma separated manner. "
            'E.g. - "--lora_blocks 12,30" will result in lora training of transformer blocks 12 and 30. For more examples refer to https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_SD3.md'
        ),
    )

    parser.add_argument(
        "--adam_epsilon",
        type=float,

@@ -1222,13 +1241,31 @@ def main(args):
        if args.train_text_encoder:
            text_encoder_one.gradient_checkpointing_enable()
            text_encoder_two.gradient_checkpointing_enable()

    if args.lora_layers is not None:
        target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
    else:
        target_modules = [
            "attn.add_k_proj",
            "attn.add_q_proj",
            "attn.add_v_proj",
            "attn.to_add_out",
            "attn.to_k",
            "attn.to_out.0",
            "attn.to_q",
            "attn.to_v",
        ]
    if args.lora_blocks is not None:
        target_blocks = [int(block.strip()) for block in args.lora_blocks.split(",")]
        target_modules = [
            f"transformer_blocks.{block}.{module}" for block in target_blocks for module in target_modules
        ]

    # now we will add new LoRA weights to the attention layers
    transformer_lora_config = LoraConfig(
        r=args.rank,
        lora_alpha=args.rank,
        init_lora_weights="gaussian",
-       target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+       target_modules=target_modules,
    )
    transformer.add_adapter(transformer_lora_config)
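
To make the block/layer selection concrete, here is a small standalone sketch (not part of the diff; the flag values are hypothetical) of what the logic above produces as target-module patterns:

```python
# Standalone illustration of the target_modules construction above.
# The strings stand in for hypothetical --lora_layers / --lora_blocks inputs.
lora_layers = "attn.to_k, attn.to_q"
lora_blocks = "0, 12"

target_modules = [layer.strip() for layer in lora_layers.split(",")]
target_blocks = [int(block.strip()) for block in lora_blocks.split(",")]
target_modules = [
    f"transformer_blocks.{block}.{module}" for block in target_blocks for module in target_modules
]
print(target_modules)
# ['transformer_blocks.0.attn.to_k', 'transformer_blocks.0.attn.to_q',
#  'transformer_blocks.12.attn.to_k', 'transformer_blocks.12.attn.to_q']
```

Since PEFT matches `target_modules` entries as module-name suffixes, prefixing each layer pattern with `transformer_blocks.<idx>.` is what restricts the LoRA adapters to the chosen blocks.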