diff --git a/examples/dreambooth/README_hidream.md b/examples/dreambooth/README_hidream.md
index defa5a3573..a0e8c1feca 100644
--- a/examples/dreambooth/README_hidream.md
+++ b/examples/dreambooth/README_hidream.md
@@ -51,54 +51,41 @@ When running `accelerate config`, if we specify torch compile mode to True there
 
 Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.14.0` installed in your environment.
 
-### Dog toy example
+### 3d icon example
 
-Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
-
-Let's first download it locally:
-
-```python
-from huggingface_hub import snapshot_download
-
-local_dir = "./dog"
-snapshot_download(
-    "diffusers/dog-example",
-    local_dir=local_dir, repo_type="dataset",
-    ignore_patterns=".gitattributes",
-)
-```
+For this example we will use some 3d icon images: https://huggingface.co/datasets/linoyts/3d_icon.
 
 This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
 
 Now, we can launch training using:
 
 > [!NOTE]
 > The following training configuration prioritizes lower memory consumption by using gradient checkpointing,
-> 8-bit Adam optimizer, latent caching, offloading, no validation.
-> Additionally, when provided with 'instance_prompt' only and no 'caption_column' (used for custom prompts for each image)
-> text embeddings are pre-computed to save memory.
-
+> 8-bit Adam optimizer, latent caching, offloading, no validation.
+> all text embeddings are pre-computed to save memory.
 ```bash
 export MODEL_NAME="HiDream-ai/HiDream-I1-Dev"
-export INSTANCE_DIR="dog"
+export INSTANCE_DIR="linoyts/3d_icon"
 export OUTPUT_DIR="trained-hidream-lora"
 
 accelerate launch train_dreambooth_lora_hidream.py \
   --pretrained_model_name_or_path=$MODEL_NAME \
-  --instance_data_dir=$INSTANCE_DIR \
+  --dataset_name=$INSTANCE_DIR \
   --output_dir=$OUTPUT_DIR \
   --mixed_precision="bf16" \
-  --instance_prompt="a photo of sks dog" \
+  --instance_prompt="3d icon" \
+  --caption_column="prompt"\
+  --validation_prompt="a 3dicon, a llama eating ramen" \
   --resolution=1024 \
   --train_batch_size=1 \
   --gradient_accumulation_steps=4 \
   --use_8bit_adam \
-  --rank=16 \
+  --rank=8 \
   --learning_rate=2e-4 \
   --report_to="wandb" \
-  --lr_scheduler="constant" \
-  --lr_warmup_steps=0 \
+  --lr_scheduler="constant_with_warmup" \
+  --lr_warmup_steps=100 \
   --max_train_steps=1000 \
-  --cache_latents \
+  --cache_latents\
   --gradient_checkpointing \
   --validation_epochs=25 \
   --seed="0" \
@@ -128,6 +115,5 @@ We provide several options for optimizing memory optimization:
 * `--offload`: When enabled, we will offload the text encoder and VAE to CPU, when they are not used.
 * `cache_latents`: When enabled, we will pre-compute the latents from the input images with the VAE and remove the VAE from memory once done.
 * `--use_8bit_adam`: When enabled, we will use the 8bit version of AdamW provided by the `bitsandbytes` library.
-* `--instance_prompt` and no `--caption_column`: when only an instance prompt is provided, we will pre-compute the text embeddings and remove the text encoders from memory once done.
 
 Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/) of the `HiDreamImagePipeline` to know more about the model.
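Once training finishes, the LoRA can be loaded back into `HiDreamImagePipeline` for inference. The snippet below is a minimal sketch based on the model-card example embedded in the training script; it assumes the default `OUTPUT_DIR` from the command above and access to the gated `meta-llama/Meta-Llama-3.1-8B-Instruct` checkpoint used as the Llama text encoder:

```python
import torch
from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
from diffusers import HiDreamImagePipeline

# Load the Llama tokenizer/text encoder that HiDream uses as its fourth text encoder.
tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
tokenizer_4.pad_token = tokenizer_4.eos_token
text_encoder_4 = LlamaForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    output_hidden_states=True,
    output_attentions=True,
    torch_dtype=torch.bfloat16,
)

pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Dev",
    tokenizer_4=tokenizer_4,
    text_encoder_4=text_encoder_4,
    torch_dtype=torch.bfloat16,
).to("cuda")

# Load the LoRA weights produced by the training run above (default OUTPUT_DIR assumed).
pipe.load_lora_weights("trained-hidream-lora")

image = pipe("a 3dicon, a llama eating ramen").images[0]
image.save("3d_icon_llama.png")
```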
diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 72e458a72a..26a920906b 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -120,11 +120,7 @@ You should use `{instance_prompt}` to trigger the image generation. ```py >>> import torch >>> from transformers import PreTrainedTokenizerFast, LlamaForCausalLM - >>> from diffusers import UniPCMultistepScheduler, HiDreamImagePipeline - - >>> scheduler = UniPCMultistepScheduler( - ... flow_shift=3.0, prediction_type="flow_prediction", use_flow_sigmas=True - ... ) + >>> from diffusers import HiDreamImagePipeline >>> tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct") >>> text_encoder_4 = LlamaForCausalLM.from_pretrained( @@ -136,7 +132,6 @@ You should use `{instance_prompt}` to trigger the image generation. >>> pipe = HiDreamImagePipeline.from_pretrained( ... "HiDream-ai/HiDream-I1-Full", - ... scheduler=scheduler, ... tokenizer_4=tokenizer_4, ... text_encoder_4=text_encoder_4, ... torch_dtype=torch.bfloat16, @@ -201,6 +196,7 @@ def log_validation( torch_dtype, is_final_validation=False, ): + args.num_validation_images = args.num_validation_images if args.num_validation_images else 1 logger.info( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." @@ -212,28 +208,16 @@ def log_validation( generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() - # pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast - with torch.no_grad(): - ( - prompt_embeds_t5, - negative_prompt_embeds_t5, - prompt_embeds_llama3, - negative_prompt_embeds_llama3, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) = pipeline.encode_prompt( - pipeline_args["prompt"], - ) images = [] for _ in range(args.num_validation_images): with autocast_ctx: image = pipeline( - prompt_embeds_t5=prompt_embeds_t5, - prompt_embeds_llama3=prompt_embeds_llama3, - negative_prompt_embeds_t5=negative_prompt_embeds_t5, - negative_prompt_embeds_llama3=negative_prompt_embeds_llama3, - pooled_prompt_embeds=pooled_prompt_embeds, - negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + prompt_embeds_t5=pipeline_args["prompt_embeds_t5"], + prompt_embeds_llama3=pipeline_args["prompt_embeds_llama3"], + negative_prompt_embeds_t5=pipeline_args["negative_prompt_embeds_t5"], + negative_prompt_embeds_llama3=pipeline_args["negative_prompt_embeds_llama3"], + pooled_prompt_embeds=pipeline_args["pooled_prompt_embeds"], + negative_pooled_prompt_embeds=pipeline_args["negative_pooled_prompt_embeds"], generator=generator, ).images[0] images.append(image) @@ -252,9 +236,9 @@ def log_validation( } ) + pipeline.to("cpu") del pipeline - if torch.cuda.is_available(): - torch.cuda.empty_cache() + free_memory() return images @@ -392,6 +376,14 @@ def parse_args(input_args=None): default=None, help="A prompt that is used during validation to verify that the model is learning.", ) + + parser.add_argument( + "--skip_final_inference", + default=False, + action="store_true", + help="Whether to skip the final inference step with loaded lora weights upon training completion. This will run intermediate validation inference if `validation_prompt` is provided. 
Specify to reduce memory.", + ) + parser.add_argument( "--final_validation_prompt", type=str, @@ -1016,6 +1008,7 @@ def main(args): image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" image.save(image_filename) + pipeline.to("cpu") del pipeline free_memory() @@ -1140,7 +1133,7 @@ def main(args): if args.lora_layers is not None: target_modules = [layer.strip() for layer in args.lora_layers.split(",")] else: - target_modules = ["to_k", "to_q", "to_v", "to_out.0"] + target_modules = ["to_k", "to_q", "to_v", "to_out"] # now we will add new LoRA weights the transformer layers transformer_lora_config = LoraConfig( @@ -1314,42 +1307,65 @@ def main(args): ) def compute_text_embeddings(prompt, text_encoding_pipeline): - if args.offload: - text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device) with torch.no_grad(): - t5_prompt_embeds, _, llama3_prompt_embeds, _, pooled_prompt_embeds, _ = ( - text_encoding_pipeline.encode_prompt(prompt=prompt, max_sequence_length=args.max_sequence_length) - ) - if args.offload: # back to cpu - text_encoding_pipeline = text_encoding_pipeline.to("cpu") - return t5_prompt_embeds, llama3_prompt_embeds, pooled_prompt_embeds + ( + t5_prompt_embeds, + negative_prompt_embeds_t5, + llama3_prompt_embeds, + negative_prompt_embeds_llama3, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = text_encoding_pipeline.encode_prompt(prompt=prompt, max_sequence_length=args.max_sequence_length) + return ( + t5_prompt_embeds, + llama3_prompt_embeds, + pooled_prompt_embeds, + negative_prompt_embeds_t5, + negative_prompt_embeds_llama3, + negative_pooled_prompt_embeds, + ) # If no type of tuning is done on the text_encoder and custom instance prompts are NOT # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid # the redundant encoding. if not train_dataset.custom_instance_prompts: + if args.offload: + text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device) ( instance_prompt_hidden_states_t5, instance_prompt_hidden_states_llama3, instance_pooled_prompt_embeds, + _, + _, + _, ) = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline) + if args.offload: + text_encoding_pipeline = text_encoding_pipeline.to("cpu") # Handle class prompt for prior-preservation. 
if args.with_prior_preservation: - ( - class_prompt_hidden_states_t5, - class_prompt_hidden_states_llama3, - class_pooled_prompt_embeds, - ) = compute_text_embeddings(args.class_prompt, text_encoding_pipeline) + if args.offload: + text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device) + (class_prompt_hidden_states_t5, class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, _, _, _) = ( + compute_text_embeddings(args.class_prompt, text_encoding_pipeline) + ) + if args.offload: + text_encoding_pipeline = text_encoding_pipeline.to("cpu") - # Clear the memory here - if not train_dataset.custom_instance_prompts: - # delete tokenizers and text ecnoders except for llama (tokenizer & te four) - # as it's needed for inference with pipeline - del text_encoder_one, text_encoder_two, text_encoder_three, tokenizer_one, tokenizer_two, tokenizer_three - if not args.validation_prompt: - del tokenizer_four, text_encoder_four - free_memory() + validation_embeddings = {} + if args.validation_prompt is not None: + if args.offload: + text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device) + ( + validation_embeddings["prompt_embeds_t5"], + validation_embeddings["prompt_embeds_llama3"], + validation_embeddings["pooled_prompt_embeds"], + validation_embeddings["negative_prompt_embeds_t5"], + validation_embeddings["negative_prompt_embeds_llama3"], + validation_embeddings["negative_pooled_prompt_embeds"], + ) = compute_text_embeddings(args.validation_prompt, text_encoding_pipeline) + if args.offload: + text_encoding_pipeline = text_encoding_pipeline.to("cpu") # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), # pack the statically computed variables appropriately here. This is so that we don't @@ -1367,20 +1383,52 @@ def main(args): vae_config_scaling_factor = vae.config.scaling_factor vae_config_shift_factor = vae.config.shift_factor - if args.cache_latents: + + # if cache_latents is set to True, we encode images to latents and store them. + # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided + # we encode them in advance as well. 
+ precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts + if precompute_latents: + t5_prompt_cache = [] + llama3_prompt_cache = [] + pooled_prompt_cache = [] latents_cache = [] if args.offload: vae = vae.to(accelerator.device) for batch in tqdm(train_dataloader, desc="Caching latents"): with torch.no_grad(): - batch["pixel_values"] = batch["pixel_values"].to( - accelerator.device, non_blocking=True, dtype=vae.dtype - ) - latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) + if args.cache_latents: + batch["pixel_values"] = batch["pixel_values"].to( + accelerator.device, non_blocking=True, dtype=vae.dtype + ) + latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) + if train_dataset.custom_instance_prompts: + text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device) + t5_prompt_embeds, llama3_prompt_embeds, pooled_prompt_embeds, _, _, _ = compute_text_embeddings( + batch["prompts"], text_encoding_pipeline + ) + t5_prompt_cache.append(t5_prompt_embeds) + llama3_prompt_cache.append(llama3_prompt_embeds) + pooled_prompt_cache.append(pooled_prompt_embeds) - if args.validation_prompt is None: + # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624 + if args.offload or args.cache_latents: + vae = vae.to("cpu") + if args.cache_latents: del vae - free_memory() + # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624 + text_encoding_pipeline = text_encoding_pipeline.to("cpu") + del ( + text_encoder_one, + text_encoder_two, + text_encoder_three, + text_encoder_four, + tokenizer_two, + tokenizer_three, + tokenizer_four, + text_encoding_pipeline, + ) + free_memory() # Scheduler and math around the number of training steps. 
overrode_max_train_steps = False @@ -1487,9 +1535,9 @@ def main(args): with accelerator.accumulate(models_to_accumulate): # encode batch prompts when custom prompts are provided for each image - if train_dataset.custom_instance_prompts: - t5_prompt_embeds, llama3_prompt_embeds, pooled_prompt_embeds = compute_text_embeddings( - prompts, text_encoding_pipeline - ) + t5_prompt_embeds = t5_prompt_cache[step] + llama3_prompt_embeds = llama3_prompt_cache[step] + pooled_prompt_embeds = pooled_prompt_cache[step] else: t5_prompt_embeds = t5_prompt_embeds.repeat(len(prompts), 1, 1) llama3_prompt_embeds = llama3_prompt_embeds.repeat(1, len(prompts), 1, 1) @@ -1619,26 +1667,30 @@ def main(args): # create pipeline pipeline = HiDreamImagePipeline.from_pretrained( args.pretrained_model_name_or_path, - tokenizer_4=tokenizer_four, - text_encoder_4=text_encoder_four, + tokenizer=None, + text_encoder=None, + tokenizer_2=None, + text_encoder_2=None, + tokenizer_3=None, + text_encoder_3=None, + tokenizer_4=None, + text_encoder_4=None, transformer=accelerator.unwrap_model(transformer), revision=args.revision, variant=args.variant, torch_dtype=weight_dtype, ) - pipeline_args = {"prompt": args.validation_prompt} images = log_validation( pipeline=pipeline, args=args, accelerator=accelerator, - pipeline_args=pipeline_args, + pipeline_args=validation_embeddings, torch_dtype=weight_dtype, epoch=epoch, ) - free_memory() - - images = None del pipeline + images = None + free_memory() # Save the lora layers accelerator.wait_for_everyone() @@ -1655,50 +1707,49 @@ def main(args): transformer_lora_layers=transformer_lora_layers, ) - # Final inference - # Load previous pipeline - tokenizer_4 = AutoTokenizer.from_pretrained(args.pretrained_tokenizer_4_name_or_path) - tokenizer_4.pad_token = tokenizer_4.eos_token - text_encoder_4 = LlamaForCausalLM.from_pretrained( - args.pretrained_text_encoder_4_name_or_path, - output_hidden_states=True, - output_attentions=True, - torch_dtype=torch.bfloat16, - ) - pipeline = HiDreamImagePipeline.from_pretrained( - args.pretrained_model_name_or_path, - tokenizer_4=tokenizer_4, - text_encoder_4=text_encoder_4, - revision=args.revision, - variant=args.variant, - torch_dtype=weight_dtype, - ) - # load attention processors - pipeline.load_lora_weights(args.output_dir) - - # run inference images = [] - if (args.validation_prompt and args.num_validation_images > 0) or (args.final_validation_prompt): - prompt_to_use = args.validation_prompt if args.validation_prompt else args.final_validation_prompt - args.num_validation_images = args.num_validation_images if args.num_validation_images else 1 - pipeline_args = {"prompt": prompt_to_use, "num_images_per_prompt": args.num_validation_images} + run_validation = (args.validation_prompt and args.num_validation_images > 0) or (args.final_validation_prompt) + should_run_final_inference = not args.skip_final_inference and run_validation + if should_run_final_inference: + # Final inference + # Load previous pipeline + pipeline = HiDreamImagePipeline.from_pretrained( + args.pretrained_model_name_or_path, + tokenizer=None, + text_encoder=None, + tokenizer_2=None, + text_encoder_2=None, + tokenizer_3=None, + text_encoder_3=None, + tokenizer_4=None, + text_encoder_4=None, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + # load attention processors + pipeline.load_lora_weights(args.output_dir) + + # run inference images = log_validation( pipeline=pipeline, args=args, accelerator=accelerator, - pipeline_args=pipeline_args, + 
pipeline_args=validation_embeddings, epoch=epoch, is_final_validation=True, torch_dtype=weight_dtype, ) + del pipeline + free_memory() - validation_prpmpt = args.validation_prompt if args.validation_prompt else args.final_validation_prompt + validation_prompt = args.validation_prompt if args.validation_prompt else args.final_validation_prompt save_model_card( (args.hub_model_id or Path(args.output_dir).name) if not args.push_to_hub else repo_id, images=images, base_model=args.pretrained_model_name_or_path, instance_prompt=args.instance_prompt, - validation_prompt=validation_prpmpt, + validation_prompt=validation_prompt, repo_folder=args.output_dir, ) @@ -1711,7 +1762,6 @@ def main(args): ) images = None - del pipeline accelerator.end_training()
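The training-loop change above swaps per-step prompt encoding for a one-time caching pass: when custom captions are used, the T5, Llama, and pooled embeddings are computed alongside the latents before training and then looked up by step index, so the text encoders can be dropped before the loop starts. A minimal, self-contained sketch of that pattern (the `encode` stub, tensor shapes, and prompts are illustrative, not the script's real API):

```python
import torch

def encode(prompts):
    # Stand-in for text_encoding_pipeline.encode_prompt(); returns dummy embeddings.
    return torch.randn(len(prompts), 128, 4096)

batches = [["a 3d icon of a dog"], ["a 3d icon of a cat"]]

# Caching pass: run every batch through the (large) text encoders once; afterwards
# the encoders can be moved to CPU and deleted before training starts.
prompt_cache = []
with torch.no_grad():
    for batch_prompts in batches:
        prompt_cache.append(encode(batch_prompts))

# Training loop: reuse the cached embeddings instead of re-encoding every step.
for step, batch_prompts in enumerate(batches):
    prompt_embeds = prompt_cache[step]
    # ... compute the training loss with prompt_embeds ...
    print(step, prompt_embeds.shape)
```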
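Several hunks also move a pipeline or VAE back to the CPU before deleting it and then call `free_memory()`, following the linked issue #11376. A short sketch of that offload-before-delete pattern, with the helper written out explicitly here as an assumption rather than imported from diffusers:

```python
import gc
import torch

def free_memory():
    # Rough equivalent of the helper the script calls after dropping large modules.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

model = torch.nn.Linear(4096, 4096)
if torch.cuda.is_available():
    model = model.to("cuda")

# Move the module back to CPU before deleting the last reference so that its
# CUDA allocations can actually be released by empty_cache().
model = model.to("cpu")
del model
free_memory()
```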