From 413ca29b718592b89a8cd06c81d4e4373175d2ee Mon Sep 17 00:00:00 2001
From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
Date: Mon, 12 Aug 2024 09:28:03 +0300
Subject: [PATCH] [Flux Dreambooth LoRA] - te bug fixes & updates (#9139)

* add requirements + fix link to bghira's guide

* text encoder training fixes

* text encoder training fixes

* text encoder training fixes

* text encoder training fixes

* style

* add tests

* fix encode_prompt call

* style

* unpack_latents test

* fix lora saving

* remove default val for max_sequence_length in encode_prompt

* remove default val for max_sequence_length in encode_prompt

* style

* testing

* style

* testing

* testing

* style

* fix sizing issue

* style

* revert scaling

* style

* style

* scaling test

* style

* scaling test

* remove model pred operation left from pre-conditioning

* remove model pred operation left from pre-conditioning

* fix trainable params

* remove te2 from casting

* transformer to accelerator

* remove prints

* empty commit
---
 examples/dreambooth/README_flux.md            |   8 +-
 examples/dreambooth/requirements_flux.txt     |   8 +
 examples/dreambooth/test_dreambooth_flux.py   | 203 ++++++++++++++++++
 .../dreambooth/test_dreambooth_lora_flux.py   | 165 ++++++++++++++
 examples/dreambooth/train_dreambooth_flux.py  |  14 +-
 .../dreambooth/train_dreambooth_lora_flux.py  | 182 +++++++++-------
 6 files changed, 488 insertions(+), 92 deletions(-)
 create mode 100644 examples/dreambooth/requirements_flux.txt
 create mode 100644 examples/dreambooth/test_dreambooth_flux.py
 create mode 100644 examples/dreambooth/test_dreambooth_lora_flux.py

diff --git a/examples/dreambooth/README_flux.md b/examples/dreambooth/README_flux.md
index 4c0ba7fbaa..fab382c089 100644
--- a/examples/dreambooth/README_flux.md
+++ b/examples/dreambooth/README_flux.md
@@ -8,7 +8,7 @@ The `train_dreambooth_flux.py` script shows how to implement the training procedure
 >
 > Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements -
 > a LoRA with a rank of 16 (w/ all components trained) can exceed 40GB of VRAM for training.
-> For more tips & guidance on training on a resource-constrained device please visit [`@bghira`'s guide](documentation/quickstart/FLUX.md)
+> For more tips & guidance on training on a resource-constrained device please visit [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX.md)

 > [!NOTE]
@@ -96,7 +96,7 @@ accelerate launch train_dreambooth_flux.py \
   --pretrained_model_name_or_path=$MODEL_NAME \
   --instance_data_dir=$INSTANCE_DIR \
   --output_dir=$OUTPUT_DIR \
-  --mixed_precision="fp16" \
+  --mixed_precision="bf16" \
   --instance_prompt="a photo of sks dog" \
   --resolution=1024 \
   --train_batch_size=1 \
@@ -140,7 +140,7 @@ accelerate launch train_dreambooth_lora_flux.py \
   --pretrained_model_name_or_path=$MODEL_NAME \
   --instance_data_dir=$INSTANCE_DIR \
   --output_dir=$OUTPUT_DIR \
-  --mixed_precision="fp16" \
+  --mixed_precision="bf16" \
   --instance_prompt="a photo of sks dog" \
   --resolution=512 \
   --train_batch_size=1 \
@@ -175,7 +175,7 @@ accelerate launch train_dreambooth_lora_flux.py \
   --pretrained_model_name_or_path=$MODEL_NAME \
   --instance_data_dir=$INSTANCE_DIR \
   --output_dir=$OUTPUT_DIR \
-  --mixed_precision="fp16" \
+  --mixed_precision="bf16" \
   --train_text_encoder\
   --instance_prompt="a photo of sks dog" \
   --resolution=512 \
diff --git a/examples/dreambooth/requirements_flux.txt b/examples/dreambooth/requirements_flux.txt
new file mode 100644
index 0000000000..dbc124ff65
--- /dev/null
+++ b/examples/dreambooth/requirements_flux.txt
@@ -0,0 +1,8 @@
+accelerate>=0.31.0
+torchvision
+transformers>=4.41.2
+ftfy
+tensorboard
+Jinja2
+peft>=0.11.1
+sentencepiece
\ No newline at end of file
diff --git a/examples/dreambooth/test_dreambooth_flux.py b/examples/dreambooth/test_dreambooth_flux.py
new file mode 100644
index 0000000000..2d5703d2a2
--- /dev/null
+++ b/examples/dreambooth/test_dreambooth_flux.py
@@ -0,0 +1,203 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import shutil
+import sys
+import tempfile
+
+from diffusers import DiffusionPipeline, FluxTransformer2DModel
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class DreamBoothFlux(ExamplesTestsAccelerate):
+    instance_data_dir = "docs/source/en/imgs"
+    instance_prompt = "photo"
+    pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
+    script_path = "examples/dreambooth/train_dreambooth_flux.py"
+
+    def test_dreambooth(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+                --instance_data_dir {self.instance_data_dir}
+                --instance_prompt {self.instance_prompt}
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "transformer", "diffusion_pytorch_model.safetensors")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
+
+    def test_dreambooth_checkpointing(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Run training script with checkpointing
+            # max_train_steps == 4, checkpointing_steps == 2
+            # Should create checkpoints at steps 2, 4
+
+            initial_run_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+                --instance_data_dir {self.instance_data_dir}
+                --instance_prompt {self.instance_prompt}
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 4
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                --checkpointing_steps=2
+                --seed=0
+                """.split()
+
+            run_command(self._launch_args + initial_run_args)
+
+            # check can run the original fully trained output pipeline
+            pipe = DiffusionPipeline.from_pretrained(tmpdir)
+            pipe(self.instance_prompt, num_inference_steps=1)
+
+            # check checkpoint directories exist
+            self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
+            self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
+
+            # check can run an intermediate checkpoint
+            transformer = FluxTransformer2DModel.from_pretrained(tmpdir, subfolder="checkpoint-2/transformer")
+            pipe = DiffusionPipeline.from_pretrained(self.pretrained_model_name_or_path, transformer=transformer)
+            pipe(self.instance_prompt, num_inference_steps=1)
+
+            # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
+            shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
+
+            # Run training script for 7 total steps resuming from checkpoint 4
+
+            resume_run_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+                --instance_data_dir {self.instance_data_dir}
+                --instance_prompt {self.instance_prompt}
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 6
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                --checkpointing_steps=2
+                --resume_from_checkpoint=checkpoint-4
+                --seed=0
+                """.split()
+
+            run_command(self._launch_args + resume_run_args)
+
+            # check can run new fully trained pipeline
+            pipe = DiffusionPipeline.from_pretrained(tmpdir)
+            pipe(self.instance_prompt, num_inference_steps=1)
+
+            # check old checkpoints do not exist
+            self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
+
+            # check new checkpoints exist
+            self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
+            self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6")))
+
+    def test_dreambooth_checkpointing_checkpoints_total_limit(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+                --instance_data_dir={self.instance_data_dir}
+                --output_dir={tmpdir}
+                --instance_prompt={self.instance_prompt}
+                --resolution=64
+                --train_batch_size=1
+                --gradient_accumulation_steps=1
+                --max_train_steps=6
+                --checkpoints_total_limit=2
+                --checkpointing_steps=2
+                """.split()
+
+            run_command(self._launch_args + test_args)
+
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                {"checkpoint-4", "checkpoint-6"},
+            )
+
+    def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+                --instance_data_dir={self.instance_data_dir}
+                --output_dir={tmpdir}
+                --instance_prompt={self.instance_prompt}
+                --resolution=64
+                --train_batch_size=1
+                --gradient_accumulation_steps=1
+                --max_train_steps=4
+                --checkpointing_steps=2
+                """.split()
+
+            run_command(self._launch_args + test_args)
+
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                {"checkpoint-2", "checkpoint-4"},
+            )
+
+            resume_run_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+                --instance_data_dir={self.instance_data_dir}
+                --output_dir={tmpdir}
+                --instance_prompt={self.instance_prompt}
+                --resolution=64
+                --train_batch_size=1
+                --gradient_accumulation_steps=1
+                --max_train_steps=8
+                --checkpointing_steps=2
+                --resume_from_checkpoint=checkpoint-4
+                --checkpoints_total_limit=2
+                """.split()
+
+            run_command(self._launch_args + resume_run_args)
+
+            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
diff --git a/examples/dreambooth/test_dreambooth_lora_flux.py b/examples/dreambooth/test_dreambooth_lora_flux.py
new file mode 100644
index 0000000000..b77f84447a
--- /dev/null
+++ b/examples/dreambooth/test_dreambooth_lora_flux.py
@@ -0,0 +1,165 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+import safetensors
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class DreamBoothLoRAFlux(ExamplesTestsAccelerate):
+    instance_data_dir = "docs/source/en/imgs"
+    instance_prompt = "photo"
+    pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
+    script_path = "examples/dreambooth/train_dreambooth_lora_flux.py"
+
+    def test_dreambooth_lora_flux(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+                --instance_data_dir {self.instance_data_dir}
+                --instance_prompt {self.instance_prompt}
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+            # make sure the state_dict has the correct naming in the parameters.
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+            is_lora = all("lora" in k for k in lora_state_dict.keys())
+            self.assertTrue(is_lora)
+
+            # when not training the text encoder, all the parameters in the state dict should start
+            # with `"transformer"` in their names.
+            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
+            self.assertTrue(starts_with_transformer)
+
+    def test_dreambooth_lora_text_encoder_flux(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+                --instance_data_dir {self.instance_data_dir}
+                --instance_prompt {self.instance_prompt}
+                --resolution 64
+                --train_batch_size 1
+                --train_text_encoder
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+            # make sure the state_dict has the correct naming in the parameters.
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+            is_lora = all("lora" in k for k in lora_state_dict.keys())
+            self.assertTrue(is_lora)
+
+            starts_with_expected_prefix = all(
+                (key.startswith("transformer") or key.startswith("text_encoder")) for key in lora_state_dict.keys()
+            )
+            self.assertTrue(starts_with_expected_prefix)
+
+    def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+                --instance_data_dir={self.instance_data_dir}
+                --output_dir={tmpdir}
+                --instance_prompt={self.instance_prompt}
+                --resolution=64
+                --train_batch_size=1
+                --gradient_accumulation_steps=1
+                --max_train_steps=6
+                --checkpoints_total_limit=2
+                --checkpointing_steps=2
+                """.split()
+
+            run_command(self._launch_args + test_args)
+
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                {"checkpoint-4", "checkpoint-6"},
+            )
+
+    def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+                --instance_data_dir={self.instance_data_dir}
+                --output_dir={tmpdir}
+                --instance_prompt={self.instance_prompt}
+                --resolution=64
+                --train_batch_size=1
+                --gradient_accumulation_steps=1
+                --max_train_steps=4
+                --checkpointing_steps=2
+                """.split()
+
+            run_command(self._launch_args + test_args)
+
+            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
+
+            resume_run_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+                --instance_data_dir={self.instance_data_dir}
+                --output_dir={tmpdir}
+                --instance_prompt={self.instance_prompt}
+                --resolution=64
+                --train_batch_size=1
+                --gradient_accumulation_steps=1
+                --max_train_steps=8
+                --checkpointing_steps=2
+                --resume_from_checkpoint=checkpoint-4
+                --checkpoints_total_limit=2
+                """.split()
+
+            run_command(self._launch_args + resume_run_args)
+
+            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py
index 66e7a4e97f..df4a44abde 100644
--- a/examples/dreambooth/train_dreambooth_flux.py
+++ b/examples/dreambooth/train_dreambooth_flux.py
@@ -1505,6 +1505,9 @@ def main(args):
                     model_input = vae.encode(pixel_values).latent_dist.sample()
                 model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)
+
+                vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
+
                 latent_image_ids = FluxPipeline._prepare_latent_image_ids(
                     model_input.shape[0],
                     model_input.shape[2],
@@ -1583,16 +1586,11 @@ def main(args):

                 model_pred = FluxPipeline._unpack_latents(
                     model_pred,
-                    height=int(model_input.shape[2]) * 8,
-                    width=int(model_input.shape[3]) * 8,
-                    vae_scale_factor=2
-                    ** (
-                        len(vae.config.block_out_channels)
-                    ),  # should this be 2 ** (len(vae.config.block_out_channels))?
+                    height=int(model_input.shape[2]),
+                    width=int(model_input.shape[3]),
+                    vae_scale_factor=vae_scale_factor,
                 )

-                model_pred = model_pred * (-sigmas) + noisy_model_input
-
                 # these weighting schemes use a uniform timestep sampling
                 # and instead post-weight the loss
                 weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py
index e32004a8d8..3629fcca4d 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux.py
@@ -319,7 +319,7 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--max_sequence_length",
         type=int,
-        default=77,
+        default=512,
         help="Maximum sequence length to use with the T5 text encoder",
     )
     parser.add_argument(
@@ -864,7 +864,7 @@ class PromptDataset(Dataset):
         return example


-def tokenize_prompt(tokenizer, prompt, max_sequence_length=512):
+def tokenize_prompt(tokenizer, prompt, max_sequence_length):
     text_inputs = tokenizer(
         prompt,
         padding="max_length",
@@ -885,20 +885,26 @@ def _encode_prompt_with_t5(
     prompt=None,
     num_images_per_prompt=1,
     device=None,
+    text_input_ids=None,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
     batch_size = len(prompt)

-    text_inputs = tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=max_sequence_length,
-        truncation=True,
-        return_length=False,
-        return_overflowing_tokens=False,
-        return_tensors="pt",
-    )
-    text_input_ids = text_inputs.input_ids
+    if tokenizer is not None:
+        text_inputs = tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=max_sequence_length,
+            truncation=True,
+            return_length=False,
+            return_overflowing_tokens=False,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+    else:
+        if text_input_ids is None:
+            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
+
     prompt_embeds = text_encoder(text_input_ids.to(device))[0]

     dtype = text_encoder.dtype
@@ -918,22 +924,28 @@ def _encode_prompt_with_clip(
     tokenizer,
     prompt: str,
     device=None,
+    text_input_ids=None,
     num_images_per_prompt: int = 1,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
     batch_size = len(prompt)

-    text_inputs = tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=77,
-        truncation=True,
-        return_overflowing_tokens=False,
-        return_length=False,
-        return_tensors="pt",
-    )
+    if tokenizer is not None:
+        text_inputs = tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=77,
+            truncation=True,
+            return_overflowing_tokens=False,
+            return_length=False,
+            return_tensors="pt",
+        )
+
+        text_input_ids = text_inputs.input_ids
+    else:
+        if text_input_ids is None:
+            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")

-    text_input_ids = text_inputs.input_ids
     prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)

     # Use pooled output of CLIPTextModel
@@ -954,6 +966,7 @@ def encode_prompt(
     max_sequence_length,
     device=None,
     num_images_per_prompt: int = 1,
+    text_input_ids_list=None,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
     batch_size = len(prompt)
@@ -965,6 +978,7 @@ def encode_prompt(
         prompt=prompt,
         device=device if device is not None else text_encoders[0].device,
         num_images_per_prompt=num_images_per_prompt,
+        text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
     )

     prompt_embeds = _encode_prompt_with_t5(
@@ -974,6 +988,7 @@ def encode_prompt(
         prompt=prompt,
         num_images_per_prompt=num_images_per_prompt,
         device=device if device is not None else text_encoders[1].device,
+        text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
     )

     text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
@@ -1127,14 +1142,11 @@ def main(args):
         args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
     )

+    # We only train the additional adapter LoRA layers
    transformer.requires_grad_(False)
     vae.requires_grad_(False)
-    if args.train_text_encoder:
-        text_encoder_one.requires_grad_(True)
-        text_encoder_two.requires_grad_(False)
-    else:
-        text_encoder_one.requires_grad_(False)
-        text_encoder_two.requires_grad_(False)
+    text_encoder_one.requires_grad_(False)
+    text_encoder_two.requires_grad_(False)

     # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
     # as these weights are only used for inference, keeping weights in full precision is not required.
@@ -1151,9 +1163,9 @@ def main(args):
     )

     vae.to(accelerator.device, dtype=weight_dtype)
-    if not args.train_text_encoder:
-        text_encoder_one.to(accelerator.device, dtype=weight_dtype)
-        text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+    transformer.to(accelerator.device, dtype=weight_dtype)
+    text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+    text_encoder_two.to(accelerator.device, dtype=weight_dtype)

     if args.gradient_checkpointing:
         transformer.enable_gradient_checkpointing()
@@ -1168,6 +1180,14 @@ def main(args):
         target_modules=["to_k", "to_q", "to_v", "to_out.0"],
     )
     transformer.add_adapter(transformer_lora_config)
+    if args.train_text_encoder:
+        text_lora_config = LoraConfig(
+            r=args.rank,
+            lora_alpha=args.rank,
+            init_lora_weights="gaussian",
+            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+        )
+        text_encoder_one.add_adapter(text_lora_config)

     def unwrap_model(model):
         model = accelerator.unwrap_model(model)
@@ -1257,15 +1277,16 @@ def main(args):
        if args.mixed_precision == "fp16":
            models = [transformer]
            if args.train_text_encoder:
-                models.extend([text_encoder_one, text_encoder_two])
+                models.extend([text_encoder_one])
            # only upcast trainable parameters (LoRA) into fp32
            cast_training_params(models, dtype=torch.float32)

-    # Optimization parameters
-    transformer_parameters_with_lr = {"params": transformer.parameters(), "lr": args.learning_rate}
+    transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
     if args.train_text_encoder:
         text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
+
+    # Optimization parameters
+    transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
     if args.train_text_encoder:
         # different learning rate for text encoder and unet
         text_parameters_one_with_lr = {
@@ -1420,14 +1441,18 @@ def main(args):
             prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
             pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0)
             text_ids = torch.cat([text_ids, class_text_ids], dim=0)
-    # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the
-    # batch prompts on all training steps
+    # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts)
+    # we need to tokenize and encode the batch prompts on all training steps
     else:
         tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt, max_sequence_length=77)
-        tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt, max_sequence_length=512)
+        tokens_two = tokenize_prompt(
+            tokenizer_two, args.instance_prompt, max_sequence_length=args.max_sequence_length
+        )

         if args.with_prior_preservation:
             class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt, max_sequence_length=77)
-            class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt, max_sequence_length=512)
+            class_tokens_two = tokenize_prompt(
+                tokenizer_two, args.class_prompt, max_sequence_length=args.max_sequence_length
+            )
             tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
             tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
@@ -1545,6 +1570,8 @@ def main(args):
         transformer.train()
         if args.train_text_encoder:
             text_encoder_one.train()
+            # set top parameter requires_grad = True so that gradient checkpointing works
+            accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)

         for step, batch in enumerate(train_dataloader):
             models_to_accumulate = [transformer]
@@ -1562,12 +1589,33 @@ def main(args):
                        )
                    else:
                        tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77)
-                        tokens_two = tokenize_prompt(tokenizer_two, prompts, max_sequence_length=512)
+                        tokens_two = tokenize_prompt(
+                            tokenizer_two, prompts, max_sequence_length=args.max_sequence_length
+                        )
+                        prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
+                            text_encoders=[text_encoder_one, text_encoder_two],
+                            tokenizers=[None, None],
+                            text_input_ids_list=[tokens_one, tokens_two],
+                            max_sequence_length=args.max_sequence_length,
+                            prompt=prompts,
+                        )
+                else:
+                    if args.train_text_encoder:
+                        prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
+                            text_encoders=[text_encoder_one, text_encoder_two],
+                            tokenizers=[None, None],
+                            text_input_ids_list=[tokens_one, tokens_two],
+                            max_sequence_length=args.max_sequence_length,
+                            prompt=args.instance_prompt,
+                        )

                 # Convert images to latent space
                 model_input = vae.encode(pixel_values).latent_dist.sample()
                 model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)
+
+                vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
+
                 latent_image_ids = FluxPipeline._prepare_latent_image_ids(
                     model_input.shape[0],
                     model_input.shape[2],
@@ -1575,7 +1623,6 @@ def main(args):
                     accelerator.device,
                     weight_dtype,
                 )
-
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(model_input)
                 bsz = model_input.shape[0]
@@ -1613,49 +1660,24 @@ def main(args):
                     guidance = None

                 # Predict the noise residual
-                if not args.train_text_encoder:
-                    model_pred = transformer(
-                        hidden_states=packed_noisy_model_input,
-                        # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
-                        timestep=timesteps / 1000,
-                        guidance=guidance,
-                        pooled_projections=pooled_prompt_embeds,
-                        encoder_hidden_states=prompt_embeds,
-                        txt_ids=text_ids,
-                        img_ids=latent_image_ids,
-                        return_dict=False,
-                    )[0]
-                else:
-                    prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
-                        text_encoders=[text_encoder_one, text_encoder_two],
-                        tokenizers=None,
-                        prompt=None,
-                        text_input_ids_list=[tokens_one, tokens_two],
-                    )
-                    model_pred = transformer(
-                        hidden_states=packed_noisy_model_input,
-                        # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
-                        timestep=timesteps / 1000,
-                        guidance=guidance,
-                        pooled_projections=pooled_prompt_embeds,
-                        encoder_hidden_states=prompt_embeds,
-                        txt_ids=text_ids,
-                        img_ids=latent_image_ids,
-                        return_dict=False,
-                    )[0]
-
+                model_pred = transformer(
+                    hidden_states=packed_noisy_model_input,
+                    # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
+                    timestep=timesteps / 1000,
+                    guidance=guidance,
+                    pooled_projections=pooled_prompt_embeds,
+                    encoder_hidden_states=prompt_embeds,
+                    txt_ids=text_ids,
+                    img_ids=latent_image_ids,
+                    return_dict=False,
+                )[0]
                 model_pred = FluxPipeline._unpack_latents(
                     model_pred,
-                    height=int(model_input.shape[2]) * 8,
-                    width=int(model_input.shape[3]) * 8,
-                    vae_scale_factor=2
-                    ** (
-                        len(vae.config.block_out_channels)
-                    ),  # should this be 2 ** (len(vae.config.block_out_channels))?
+                    height=int(model_input.shape[2] * vae_scale_factor / 2),
+                    width=int(model_input.shape[3] * vae_scale_factor / 2),
+                    vae_scale_factor=vae_scale_factor,
                 )

-                model_pred = model_pred * (-sigmas) + noisy_model_input
-
                 # these weighting schemes use a uniform timestep sampling
                 # and instead post-weight the loss
                 weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
@@ -1783,7 +1805,7 @@ def main(args):
         FluxPipeline.save_lora_weights(
             save_directory=args.output_dir,
             transformer_lora_layers=transformer_lora_layers,
-            text_encoder_one_lora_layers=text_encoder_lora_layers,
+            text_encoder_lora_layers=text_encoder_lora_layers,
         )

     # Final inference
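Note on the sizing fix above: `FluxPipeline._unpack_latents` first divides its `height`/`width` arguments by `vae_scale_factor` and then doubles the spatial size when unfolding the 2x2 latent patches, so the training script has to pass `latent_size * vae_scale_factor / 2` to recover the original latent shape (the old `* 8` produced mismatched shapes). A minimal round-trip sketch, not part of the patch; the shapes and the factor of 2**4 are assumptions for illustration, and `_pack_latents`/`_unpack_latents` are private diffusers helpers whose signatures may change:

```python
import torch
from diffusers import FluxPipeline

# Assumed setup: a Flux VAE whose config.block_out_channels has 4 entries,
# so vae_scale_factor = 2 ** 4 = 16, as computed in the patched scripts.
vae_scale_factor = 2**4
model_input = torch.randn(1, 16, 64, 64)  # (B, C, H, W) latents, e.g. for a 512px image

# Pack 2x2 latent patches into a token sequence, as the training scripts do.
packed = FluxPipeline._pack_latents(
    model_input,
    batch_size=model_input.shape[0],
    num_channels_latents=model_input.shape[1],
    height=model_input.shape[2],
    width=model_input.shape[3],
)

# Unpack with the corrected height/width arguments from the patch.
unpacked = FluxPipeline._unpack_latents(
    packed,
    height=int(model_input.shape[2] * vae_scale_factor / 2),
    width=int(model_input.shape[3] * vae_scale_factor / 2),
    vae_scale_factor=vae_scale_factor,
)
assert unpacked.shape == model_input.shape  # round trip preserves the latent shape
```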
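When `--train_text_encoder` is set, the patched script tokenizes the prompts once and reuses the cached ids on every step, passing `tokenizers=[None, None]` so `encode_prompt` skips re-tokenization and encodes from `text_input_ids_list` instead. A hedged usage sketch of that path, assuming the script's `tokenize_prompt`/`encode_prompt` helpers and loaded `tokenizer_*`/`text_encoder_*` objects are in scope:

```python
# Tokenize once: CLIP is capped at 77 tokens, T5 uses --max_sequence_length
# (which this patch defaults to 512).
tokens_one = tokenize_prompt(tokenizer_one, "a photo of sks dog", max_sequence_length=77)
tokens_two = tokenize_prompt(tokenizer_two, "a photo of sks dog", max_sequence_length=512)

# Encode on each training step from the cached ids; no tokenizer is needed here.
prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
    text_encoders=[text_encoder_one, text_encoder_two],
    tokenizers=[None, None],
    text_input_ids_list=[tokens_one, tokens_two],
    max_sequence_length=512,
    prompt="a photo of sks dog",
)
```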
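After training, `FluxPipeline.save_lora_weights` (with the corrected `text_encoder_lora_layers` keyword) writes `pytorch_lora_weights.safetensors` to the output directory. A hedged inference sketch for loading those weights; the output path, prompt, and sampling settings are placeholders:

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")
pipe.load_lora_weights("path/to/output_dir")  # directory containing pytorch_lora_weights.safetensors

image = pipe("a photo of sks dog", num_inference_steps=28, guidance_scale=3.5).images[0]
image.save("sks_dog.png")
```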