Merge branch 'main' into cogvideox-lora-and-training
@@ -221,8 +221,12 @@ Instead, only a subset of these activations (the checkpoints) are stored and the
### 8-bit-Adam Optimizer

When training with `AdamW` (doesn't apply to `prodigy`), you can pass `--use_8bit_adam` to reduce the memory requirements of training.

Make sure to install `bitsandbytes` if you want to do so.
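For reference, here is a minimal sketch of how the 8-bit optimizer is typically selected; the variable names are illustrative stand-ins, not the script's exact code, and `bitsandbytes` must be installed:

```py
import torch

use_8bit_adam = True  # stands in for passing --use_8bit_adam
model = torch.nn.Linear(16, 16)  # stands in for the LoRA parameters being trained

if use_8bit_adam:
    import bitsandbytes as bnb  # pip install bitsandbytes

    optimizer_class = bnb.optim.AdamW8bit
else:
    optimizer_class = torch.optim.AdamW

optimizer = optimizer_class(model.parameters(), lr=1e-4, weight_decay=1e-4)
```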
### latent caching

### Latent caching

When training without validation runs, we can pre-encode the training images with the VAE and then delete it to free up some memory.

To enable `latent_caching`, first use the version in [this PR](https://github.com/huggingface/diffusers/blob/1b195933d04e4c8281a2634128c0d2d380893f73/examples/dreambooth/train_dreambooth_lora_flux.py), and then pass `--cache_latents`.

To enable `latent_caching`, simply pass `--cache_latents`.
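A minimal sketch of the idea behind latent caching follows; it is illustrative only, assuming a `train_dataloader` that yields `{"pixel_values": ...}` batches and a CUDA device, and does not reproduce the script's exact code:

```py
import torch
from diffusers import AutoencoderKL

# Assumed model id; any AutoencoderKL checkpoint works the same way.
vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16
).to("cuda")

# Encode every training image once and keep only the latent distributions.
latents_cache = []
for batch in train_dataloader:
    with torch.no_grad():
        pixel_values = batch["pixel_values"].to("cuda", dtype=torch.bfloat16)
        latents_cache.append(vae.encode(pixel_values).latent_dist)

# The VAE is no longer needed during training, so free its memory.
del vae
torch.cuda.empty_cache()

# At training step `i`, sample from the cached distribution instead of re-encoding:
# model_input = latents_cache[i].sample()
```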
### Precision of saved LoRA layers

By default, trained transformer layers are saved in the precision dtype in which training was performed, e.g. when mixed-precision training is enabled with `--mixed_precision="bf16"`, the final finetuned layers will be saved in `torch.bfloat16` as well.

This significantly reduces memory requirements without a noticeable loss in quality. Note that if you do wish to save the final layers in float32 at the expense of more memory usage, you can do so by passing `--upcast_before_saving`.
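A minimal sketch of what `--upcast_before_saving` controls (illustrative names, not the script's exact code):

```py
import torch

weight_dtype = torch.bfloat16   # dtype used for mixed-precision training
upcast_before_saving = False    # stands in for passing --upcast_before_saving

transformer = torch.nn.Linear(8, 8).to(weight_dtype)  # stands in for the trained transformer

if upcast_before_saving:
    transformer = transformer.to(torch.float32)  # larger file, full float32 precision
else:
    transformer = transformer.to(weight_dtype)   # default: keep the training dtype (e.g. bf16)

# The LoRA layers are then extracted (e.g. with get_peft_model_state_dict) and written
# to `pytorch_lora_weights.safetensors` in this dtype.
```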
## Other notes

Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️
@@ -103,6 +103,39 @@ class DreamBoothLoRAFlux(ExamplesTestsAccelerate):
            )
            self.assertTrue(starts_with_expected_prefix)

    def test_dreambooth_lora_latent_caching(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --cache_latents
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"transformer"` in their names.
            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
@@ -15,7 +15,6 @@

import argparse
import copy
import gc
import itertools
import logging
import math
@@ -56,6 +55,7 @@ from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
    _set_state_dict_into_text_encoder,
    cast_training_params,
    clear_objs_and_retain_memory,
    compute_density_for_timestep_sampling,
    compute_loss_weighting_for_sd3,
)
@@ -600,6 +600,12 @@ def parse_args(input_args=None):
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--cache_latents",
        action="store_true",
        default=False,
        help="Cache the VAE latents",
    )
    parser.add_argument(
        "--report_to",
        type=str,
@@ -620,6 +626,15 @@ def parse_args(input_args=None):
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--upcast_before_saving",
        action="store_true",
        default=False,
        help=(
            "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
            "Defaults to precision dtype used for training to save memory"
        ),
    )
    parser.add_argument(
        "--prior_generation_precision",
        type=str,
@@ -1422,12 +1437,7 @@ def main(args):

    # Clear the memory here
    if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
        del tokenizers, text_encoders
        # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
        del text_encoder_one, text_encoder_two
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        clear_objs_and_retain_memory([tokenizers, text_encoders, text_encoder_one, text_encoder_two])

    # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
    # pack the statically computed variables appropriately here. This is so that we don't
@@ -1457,6 +1467,21 @@ def main(args):
                tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
                tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)

    vae_config_shift_factor = vae.config.shift_factor
    vae_config_scaling_factor = vae.config.scaling_factor
    vae_config_block_out_channels = vae.config.block_out_channels
    if args.cache_latents:
        latents_cache = []
        for batch in tqdm(train_dataloader, desc="Caching latents"):
            with torch.no_grad():
                batch["pixel_values"] = batch["pixel_values"].to(
                    accelerator.device, non_blocking=True, dtype=weight_dtype
                )
                latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)

        if args.validation_prompt is None:
            clear_objs_and_retain_memory([vae])

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1579,7 +1604,6 @@ def main(args):
            if args.train_text_encoder:
                models_to_accumulate.extend([text_encoder_one])
            with accelerator.accumulate(models_to_accumulate):
                pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
                prompts = batch["prompts"]

                # encode batch prompts when custom prompts are provided for each image -
@@ -1613,11 +1637,15 @@ def main(args):
                )

                # Convert images to latent space
                model_input = vae.encode(pixel_values).latent_dist.sample()
                model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
                if args.cache_latents:
                    model_input = latents_cache[step].sample()
                else:
                    pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
                    model_input = vae.encode(pixel_values).latent_dist.sample()
                model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
                model_input = model_input.to(dtype=weight_dtype)

                vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
                vae_scale_factor = 2 ** (len(vae_config_block_out_channels))

                latent_image_ids = FluxPipeline._prepare_latent_image_ids(
                    model_input.shape[0],
@@ -1789,15 +1817,16 @@ def main(args):
                torch_dtype=weight_dtype,
            )
            if not args.train_text_encoder:
                del text_encoder_one, text_encoder_two
                torch.cuda.empty_cache()
                gc.collect()
                clear_objs_and_retain_memory([text_encoder_one, text_encoder_two])

    # Save the lora layers
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        transformer = unwrap_model(transformer)
        transformer = transformer.to(torch.float32)
        if args.upcast_before_saving:
            transformer.to(torch.float32)
        else:
            transformer = transformer.to(weight_dtype)
        transformer_lora_layers = get_peft_model_state_dict(transformer)

        if args.train_text_encoder: