
[HiDream LoRA] optimizations + small updates (#11381)

* 1. add pre-computation of prompt embeddings when custom prompts are used as well (illustrated in the sketch below)
2. save the model card even if the model is not pushed to the Hub
3. remove scheduler initialization from the code example - not necessary anymore (it's now in the base model's config)
4. add `skip_final_inference` - allows running with validation while skipping the final loading of the pipeline with the LoRA weights, to reduce memory requirements
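As a rough illustration of point 1, here is a minimal, hypothetical sketch (not the training script itself) of pre-computing prompt embeddings once with `HiDreamImagePipeline.encode_prompt` and reusing them for generation so the text encoders can be released before denoising. It assumes access to the gated Llama checkpoint and a CUDA device; the prompt is a placeholder.

```python
# Minimal sketch: pre-compute prompt embeddings once, free the text encoder,
# then generate from the cached embeddings.
import torch
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast

from diffusers import HiDreamImagePipeline

tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
text_encoder_4 = LlamaForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    output_hidden_states=True,
    output_attentions=True,
    torch_dtype=torch.bfloat16,
)
pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",
    tokenizer_4=tokenizer_4,
    text_encoder_4=text_encoder_4,
    torch_dtype=torch.bfloat16,
).to("cuda")

prompt = "a 3dicon, a llama eating ramen"  # placeholder prompt

# Encode once, outside any generation or training loop.
with torch.no_grad():
    (
        prompt_embeds_t5,
        negative_prompt_embeds_t5,
        prompt_embeds_llama3,
        negative_prompt_embeds_llama3,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = pipe.encode_prompt(prompt)

# The Llama text encoder is no longer needed; drop it to lower peak memory.
pipe.text_encoder_4 = None
del text_encoder_4
torch.cuda.empty_cache()

# Reuse the cached embeddings for every image (or every validation step).
image = pipe(
    prompt_embeds_t5=prompt_embeds_t5,
    prompt_embeds_llama3=prompt_embeds_llama3,
    negative_prompt_embeds_t5=negative_prompt_embeds_t5,
    negative_prompt_embeds_llama3=negative_prompt_embeds_llama3,
    pooled_prompt_embeds=pooled_prompt_embeds,
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
).images[0]
image.save("sample.png")
```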

* pre-encode the validation prompt as well

* Update examples/dreambooth/train_dreambooth_lora_hidream.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update examples/dreambooth/train_dreambooth_lora_hidream.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update examples/dreambooth/train_dreambooth_lora_hidream.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* pre-encode the validation prompt as well

* Apply style fixes

* empty commit

* change default trained modules

* empty commit

* address comments + change encoding of the validation prompt (before, it was only pre-encoded if custom prompts were provided, but it should be pre-encoded either way)

* Apply style fixes

* empty commit

* fix validation_embeddings definition

* fix final inference condition

* fix pipeline deletion in last inference

* Apply style fixes

* empty commit

* layers

* remove README remarks about only pre-computing when an instance prompt is provided and change the example to 3D icons

* smol fix

* empty commit

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Linoy Tsaban
2025-04-24 07:48:19 +03:00
committed by GitHub
parent b4be42282d
commit edd7880418
2 changed files with 158 additions and 122 deletions


@@ -51,54 +51,41 @@ When running `accelerate config`, if we specify torch compile mode to True there
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.14.0` installed in your environment.
### Dog toy example
### 3d icon example
Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
Let's first download it locally:
```python
from huggingface_hub import snapshot_download
local_dir = "./dog"
snapshot_download(
"diffusers/dog-example",
local_dir=local_dir, repo_type="dataset",
ignore_patterns=".gitattributes",
)
```
For this example we will use some 3d icon images: https://huggingface.co/datasets/linoyts/3d_icon.
This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
Now, we can launch training using:
> [!NOTE]
> The following training configuration prioritizes lower memory consumption by using gradient checkpointing,
> 8-bit Adam optimizer, latent caching, offloading, no validation.
> Additionally, when provided with 'instance_prompt' only and no 'caption_column' (used for custom prompts for each image)
> text embeddings are pre-computed to save memory.
> 8-bit Adam optimizer, latent caching, offloading, no validation.
> all text embeddings are pre-computed to save memory.
```bash
export MODEL_NAME="HiDream-ai/HiDream-I1-Dev"
export INSTANCE_DIR="dog"
export INSTANCE_DIR="linoyts/3d_icon"
export OUTPUT_DIR="trained-hidream-lora"
accelerate launch train_dreambooth_lora_hidream.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--dataset_name=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision="bf16" \
--instance_prompt="a photo of sks dog" \
--instance_prompt="3d icon" \
--caption_column="prompt"\
--validation_prompt="a 3dicon, a llama eating ramen" \
--resolution=1024 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--use_8bit_adam \
--rank=16 \
--rank=8 \
--learning_rate=2e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--lr_scheduler="constant_with_warmup" \
--lr_warmup_steps=100 \
--max_train_steps=1000 \
--cache_latents \
--cache_latents\
--gradient_checkpointing \
--validation_epochs=25 \
--seed="0" \
@@ -128,6 +115,5 @@ We provide several options for memory optimization:
* `--offload`: When enabled, we will offload the text encoder and VAE to the CPU when they are not used.
* `--cache_latents`: When enabled, we will pre-compute the latents from the input images with the VAE and remove the VAE from memory once done (see the sketch below).
* `--use_8bit_adam`: When enabled, we will use the 8bit version of AdamW provided by the `bitsandbytes` library.
* `--instance_prompt` and no `--caption_column`: when only an instance prompt is provided, we will pre-compute the text embeddings and remove the text encoders from memory once done.
Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/) of the `HiDreamImagePipeline` to know more about the model.
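For a concrete picture of what `--cache_latents` does, below is a rough sketch; the `cache_latents` helper is illustrative and not part of the script. Every training image is encoded with the VAE once, the latent distributions are stored, and the VAE is then removed from accelerator memory for the rest of training.

```python
# Rough sketch of latent caching, assuming an AutoencoderKL `vae` and a
# dataloader yielding batches with a "pixel_values" tensor.
import torch
from tqdm import tqdm


def cache_latents(vae, train_dataloader, device="cuda"):
    """Encode all training images once and return the cached latent distributions."""
    latents_cache = []
    vae = vae.to(device)
    for batch in tqdm(train_dataloader, desc="Caching latents"):
        with torch.no_grad():
            pixel_values = batch["pixel_values"].to(device, non_blocking=True, dtype=vae.dtype)
            # Keep the full distribution; a latent is sampled from it at each training step.
            latents_cache.append(vae.encode(pixel_values).latent_dist)
    # Remember the config values needed to scale/shift latents later,
    # then move the VAE off-device and delete it to free memory.
    scaling_factor, shift_factor = vae.config.scaling_factor, vae.config.shift_factor
    vae.to("cpu")
    del vae
    torch.cuda.empty_cache()
    return latents_cache, scaling_factor, shift_factor
```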

examples/dreambooth/train_dreambooth_lora_hidream.py

@@ -120,11 +120,7 @@ You should use `{instance_prompt}` to trigger the image generation.
```py
>>> import torch
>>> from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
>>> from diffusers import UniPCMultistepScheduler, HiDreamImagePipeline
>>> scheduler = UniPCMultistepScheduler(
... flow_shift=3.0, prediction_type="flow_prediction", use_flow_sigmas=True
... )
>>> from diffusers import HiDreamImagePipeline
>>> tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
>>> text_encoder_4 = LlamaForCausalLM.from_pretrained(
@@ -136,7 +132,6 @@ You should use `{instance_prompt}` to trigger the image generation.
>>> pipe = HiDreamImagePipeline.from_pretrained(
... "HiDream-ai/HiDream-I1-Full",
... scheduler=scheduler,
... tokenizer_4=tokenizer_4,
... text_encoder_4=text_encoder_4,
... torch_dtype=torch.bfloat16,
@@ -201,6 +196,7 @@ def log_validation(
torch_dtype,
is_final_validation=False,
):
args.num_validation_images = args.num_validation_images if args.num_validation_images else 1
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
@@ -212,28 +208,16 @@ def log_validation(
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
# pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
with torch.no_grad():
(
prompt_embeds_t5,
negative_prompt_embeds_t5,
prompt_embeds_llama3,
negative_prompt_embeds_llama3,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = pipeline.encode_prompt(
pipeline_args["prompt"],
)
images = []
for _ in range(args.num_validation_images):
with autocast_ctx:
image = pipeline(
prompt_embeds_t5=prompt_embeds_t5,
prompt_embeds_llama3=prompt_embeds_llama3,
negative_prompt_embeds_t5=negative_prompt_embeds_t5,
negative_prompt_embeds_llama3=negative_prompt_embeds_llama3,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
prompt_embeds_t5=pipeline_args["prompt_embeds_t5"],
prompt_embeds_llama3=pipeline_args["prompt_embeds_llama3"],
negative_prompt_embeds_t5=pipeline_args["negative_prompt_embeds_t5"],
negative_prompt_embeds_llama3=pipeline_args["negative_prompt_embeds_llama3"],
pooled_prompt_embeds=pipeline_args["pooled_prompt_embeds"],
negative_pooled_prompt_embeds=pipeline_args["negative_pooled_prompt_embeds"],
generator=generator,
).images[0]
images.append(image)
@@ -252,9 +236,9 @@ def log_validation(
}
)
pipeline.to("cpu")
del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
free_memory()
return images
@@ -392,6 +376,14 @@ def parse_args(input_args=None):
default=None,
help="A prompt that is used during validation to verify that the model is learning.",
)
parser.add_argument(
"--skip_final_inference",
default=False,
action="store_true",
help="Whether to skip the final inference step with loaded lora weights upon training completion. This will run intermediate validation inference if `validation_prompt` is provided. Specify to reduce memory.",
)
parser.add_argument(
"--final_validation_prompt",
type=str,
@@ -1016,6 +1008,7 @@ def main(args):
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
image.save(image_filename)
pipeline.to("cpu")
del pipeline
free_memory()
@@ -1140,7 +1133,7 @@ def main(args):
if args.lora_layers is not None:
target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
else:
target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
target_modules = ["to_k", "to_q", "to_v", "to_out"]
# now we will add new LoRA weights to the transformer layers
transformer_lora_config = LoraConfig(
@@ -1314,42 +1307,65 @@ def main(args):
)
def compute_text_embeddings(prompt, text_encoding_pipeline):
if args.offload:
text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
with torch.no_grad():
t5_prompt_embeds, _, llama3_prompt_embeds, _, pooled_prompt_embeds, _ = (
text_encoding_pipeline.encode_prompt(prompt=prompt, max_sequence_length=args.max_sequence_length)
)
if args.offload: # back to cpu
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
return t5_prompt_embeds, llama3_prompt_embeds, pooled_prompt_embeds
(
t5_prompt_embeds,
negative_prompt_embeds_t5,
llama3_prompt_embeds,
negative_prompt_embeds_llama3,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = text_encoding_pipeline.encode_prompt(prompt=prompt, max_sequence_length=args.max_sequence_length)
return (
t5_prompt_embeds,
llama3_prompt_embeds,
pooled_prompt_embeds,
negative_prompt_embeds_t5,
negative_prompt_embeds_llama3,
negative_pooled_prompt_embeds,
)
# If no type of tuning is done on the text_encoder and custom instance prompts are NOT
# provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
# the redundant encoding.
if not train_dataset.custom_instance_prompts:
if args.offload:
text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
(
instance_prompt_hidden_states_t5,
instance_prompt_hidden_states_llama3,
instance_pooled_prompt_embeds,
_,
_,
_,
) = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline)
if args.offload:
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
# Handle class prompt for prior-preservation.
if args.with_prior_preservation:
(
class_prompt_hidden_states_t5,
class_prompt_hidden_states_llama3,
class_pooled_prompt_embeds,
) = compute_text_embeddings(args.class_prompt, text_encoding_pipeline)
if args.offload:
text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
(class_prompt_hidden_states_t5, class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, _, _, _) = (
compute_text_embeddings(args.class_prompt, text_encoding_pipeline)
)
if args.offload:
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
# Clear the memory here
if not train_dataset.custom_instance_prompts:
# delete tokenizers and text encoders except for llama (tokenizer & text encoder four)
# as it's needed for inference with pipeline
del text_encoder_one, text_encoder_two, text_encoder_three, tokenizer_one, tokenizer_two, tokenizer_three
if not args.validation_prompt:
del tokenizer_four, text_encoder_four
free_memory()
validation_embeddings = {}
if args.validation_prompt is not None:
if args.offload:
text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
(
validation_embeddings["prompt_embeds_t5"],
validation_embeddings["prompt_embeds_llama3"],
validation_embeddings["pooled_prompt_embeds"],
validation_embeddings["negative_prompt_embeds_t5"],
validation_embeddings["negative_prompt_embeds_llama3"],
validation_embeddings["negative_pooled_prompt_embeds"],
) = compute_text_embeddings(args.validation_prompt, text_encoding_pipeline)
if args.offload:
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
@@ -1367,20 +1383,52 @@ def main(args):
vae_config_scaling_factor = vae.config.scaling_factor
vae_config_shift_factor = vae.config.shift_factor
if args.cache_latents:
# if cache_latents is set to True, we encode images to latents and store them.
# Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
# we encode them in advance as well.
precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
if precompute_latents:
t5_prompt_cache = []
llama3_prompt_cache = []
pooled_prompt_cache = []
latents_cache = []
if args.offload:
vae = vae.to(accelerator.device)
for batch in tqdm(train_dataloader, desc="Caching latents"):
with torch.no_grad():
batch["pixel_values"] = batch["pixel_values"].to(
accelerator.device, non_blocking=True, dtype=vae.dtype
)
latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
if args.cache_latents:
batch["pixel_values"] = batch["pixel_values"].to(
accelerator.device, non_blocking=True, dtype=vae.dtype
)
latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
if train_dataset.custom_instance_prompts:
text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
t5_prompt_embeds, llama3_prompt_embeds, pooled_prompt_embeds, _, _, _ = compute_text_embeddings(
batch["prompts"], text_encoding_pipeline
)
t5_prompt_cache.append(t5_prompt_embeds)
llama3_prompt_cache.append(llama3_prompt_embeds)
pooled_prompt_cache.append(pooled_prompt_embeds)
if args.validation_prompt is None:
# move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624
if args.offload or args.cache_latents:
vae = vae.to("cpu")
if args.cache_latents:
del vae
free_memory()
# move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
del (
text_encoder_one,
text_encoder_two,
text_encoder_three,
text_encoder_four,
tokenizer_two,
tokenizer_three,
tokenizer_four,
text_encoding_pipeline,
)
free_memory()
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
@@ -1487,9 +1535,9 @@ def main(args):
with accelerator.accumulate(models_to_accumulate):
# encode batch prompts when custom prompts are provided for each image -
if train_dataset.custom_instance_prompts:
t5_prompt_embeds, llama3_prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(
prompts, text_encoding_pipeline
)
t5_prompt_embeds = t5_prompt_cache[step]
llama3_prompt_embeds = llama3_prompt_cache[step]
pooled_prompt_embeds = pooled_prompt_cache[step]
else:
t5_prompt_embeds = t5_prompt_embeds.repeat(len(prompts), 1, 1)
llama3_prompt_embeds = llama3_prompt_embeds.repeat(1, len(prompts), 1, 1)
@@ -1619,26 +1667,30 @@ def main(args):
# create pipeline
pipeline = HiDreamImagePipeline.from_pretrained(
args.pretrained_model_name_or_path,
tokenizer_4=tokenizer_four,
text_encoder_4=text_encoder_four,
tokenizer=None,
text_encoder=None,
tokenizer_2=None,
text_encoder_2=None,
tokenizer_3=None,
text_encoder_3=None,
tokenizer_4=None,
text_encoder_4=None,
transformer=accelerator.unwrap_model(transformer),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline_args = {"prompt": args.validation_prompt}
images = log_validation(
pipeline=pipeline,
args=args,
accelerator=accelerator,
pipeline_args=pipeline_args,
pipeline_args=validation_embeddings,
torch_dtype=weight_dtype,
epoch=epoch,
)
free_memory()
images = None
del pipeline
images = None
free_memory()
# Save the lora layers
accelerator.wait_for_everyone()
@@ -1655,50 +1707,49 @@ def main(args):
transformer_lora_layers=transformer_lora_layers,
)
# Final inference
# Load previous pipeline
tokenizer_4 = AutoTokenizer.from_pretrained(args.pretrained_tokenizer_4_name_or_path)
tokenizer_4.pad_token = tokenizer_4.eos_token
text_encoder_4 = LlamaForCausalLM.from_pretrained(
args.pretrained_text_encoder_4_name_or_path,
output_hidden_states=True,
output_attentions=True,
torch_dtype=torch.bfloat16,
)
pipeline = HiDreamImagePipeline.from_pretrained(
args.pretrained_model_name_or_path,
tokenizer_4=tokenizer_4,
text_encoder_4=text_encoder_4,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
# load attention processors
pipeline.load_lora_weights(args.output_dir)
# run inference
images = []
if (args.validation_prompt and args.num_validation_images > 0) or (args.final_validation_prompt):
prompt_to_use = args.validation_prompt if args.validation_prompt else args.final_validation_prompt
args.num_validation_images = args.num_validation_images if args.num_validation_images else 1
pipeline_args = {"prompt": prompt_to_use, "num_images_per_prompt": args.num_validation_images}
run_validation = (args.validation_prompt and args.num_validation_images > 0) or (args.final_validation_prompt)
should_run_final_inference = not args.skip_final_inference and run_validation
if should_run_final_inference:
# Final inference
# Load previous pipeline
pipeline = HiDreamImagePipeline.from_pretrained(
args.pretrained_model_name_or_path,
tokenizer=None,
text_encoder=None,
tokenizer_2=None,
text_encoder_2=None,
tokenizer_3=None,
text_encoder_3=None,
tokenizer_4=None,
text_encoder_4=None,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
# load attention processors
pipeline.load_lora_weights(args.output_dir)
# run inference
images = log_validation(
pipeline=pipeline,
args=args,
accelerator=accelerator,
pipeline_args=pipeline_args,
pipeline_args=validation_embeddings,
epoch=epoch,
is_final_validation=True,
torch_dtype=weight_dtype,
)
del pipeline
free_memory()
validation_prpmpt = args.validation_prompt if args.validation_prompt else args.final_validation_prompt
validation_prompt = args.validation_prompt if args.validation_prompt else args.final_validation_prompt
save_model_card(
(args.hub_model_id or Path(args.output_dir).name) if not args.push_to_hub else repo_id,
images=images,
base_model=args.pretrained_model_name_or_path,
instance_prompt=args.instance_prompt,
validation_prompt=validation_prpmpt,
validation_prompt=validation_prompt,
repo_folder=args.output_dir,
)
@@ -1711,7 +1762,6 @@ def main(args):
)
images = None
del pipeline
accelerator.end_training()