From dd285099ebe556da550dc0b7c2130cc829ce6395 Mon Sep 17 00:00:00 2001
From: kaixuanliu
Date: Wed, 25 Jun 2025 16:32:17 +0800
Subject: [PATCH 01/15] adjust to get CI test cases passed on XPU (#11759)

* adjust to get CI test cases passed on XPU

Signed-off-by: Liu, Kaixuan

* fix format issue

Signed-off-by: Liu, Kaixuan

* Apply style fixes

---------

Signed-off-by: Liu, Kaixuan
Co-authored-by: github-actions[bot]
Co-authored-by: Aryan
---
 .../kandinsky2_2/test_kandinsky_controlnet.py |  3 +-
 .../test_ledits_pp_stable_diffusion.py        | 33 +++-
 tests/pipelines/test_pipelines_common.py      |  6 +--
 tests/quantization/gguf/test_gguf.py          | 50 +++++++++----------
 4 files changed, 60 insertions(+), 32 deletions(-)

diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py
index 84085f9d7d..b2d6f0fc05 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py
@@ -289,6 +289,5 @@ class KandinskyV22ControlnetPipelineIntegrationTests(unittest.TestCase):
         image = output.images[0]
 
         assert image.shape == (512, 512, 3)
-
         max_diff = numpy_cosine_similarity_distance(expected_image.flatten(), image.flatten())
-        assert max_diff < 1e-4
+        assert max_diff < 2e-4
diff --git a/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py b/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py
index 342561d4f5..ab0221dc81 100644
--- a/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py
+++ b/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py
@@ -29,6 +29,7 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.utils.testing_utils import (
+    Expectations,
     backend_empty_cache,
     enable_full_determinism,
     floats_tensor,
@@ -244,7 +245,35 @@ class LEditsPPPipelineStableDiffusionSlowTests(unittest.TestCase):
         output_slice = reconstruction[150:153, 140:143, -1]
         output_slice = output_slice.flatten()
 
-        expected_slice = np.array(
-            [0.9453125, 0.93310547, 0.84521484, 0.94628906, 0.9111328, 0.80859375, 0.93847656, 0.9042969, 0.8144531]
+        expected_slices = Expectations(
+            {
+                ("xpu", 3): np.array(
+                    [
+                        0.9511719,
+                        0.94140625,
+                        0.87597656,
+                        0.9472656,
+                        0.9296875,
+                        0.8378906,
+                        0.94433594,
+                        0.91503906,
+                        0.8491211,
+                    ]
+                ),
+                ("cuda", 7): np.array(
+                    [
+                        0.9453125,
+                        0.93310547,
+                        0.84521484,
+                        0.94628906,
+                        0.9111328,
+                        0.80859375,
+                        0.93847656,
+                        0.9042969,
+                        0.8144531,
+                    ]
+                ),
+            }
         )
+        expected_slice = expected_slices.get_expectation()
 
         assert np.abs(output_slice - expected_slice).max() < 1e-2
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 207cff2a3c..4a3a9b1796 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -49,6 +49,7 @@ from diffusers.utils.source_code_parsing_utils import ReturnNameVisitor
 from diffusers.utils.testing_utils import (
     CaptureLogger,
     backend_empty_cache,
+    numpy_cosine_similarity_distance,
     require_accelerate_version_greater,
     require_accelerator,
     require_hf_hub_version_greater,
@@ -1394,9 +1395,8 @@ class PipelineTesterMixin:
         fp16_inputs["generator"] = self.get_generator(0)
 
         output_fp16 = pipe_fp16(**fp16_inputs)[0]
-
-        max_diff = np.abs(to_np(output) - to_np(output_fp16)).max()
-        self.assertLess(max_diff, expected_max_diff, "The outputs of the fp16 and fp32 pipelines are too different.")
+        max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
+        assert max_diff < 2e-4
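The hunks above lean on two helpers from `diffusers.utils.testing_utils`: `numpy_cosine_similarity_distance`, which replaces elementwise max-difference checks, and `Expectations`, which (as the LEdits++ hunk shows) maps `(device_type, version)` keys to per-backend reference slices that `get_expectation()` resolves at runtime. A minimal sketch of why cosine distance is the more backend-tolerant metric; the helper body below is a simplified assumption, not the library implementation:

```py
import numpy as np

def cosine_distance(a: np.ndarray, b: np.ndarray) -> float:
    # 1 - cosine similarity: stays near zero when two outputs differ only by
    # small elementwise noise, so the check tolerates per-device numerics.
    a = a.astype(np.float64).flatten()
    b = b.astype(np.float64).flatten()
    return 1.0 - float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

rng = np.random.default_rng(0)
fp32_out = rng.random(512)                        # stand-in for a fp32 pipeline output
fp16_out = fp32_out + rng.normal(0.0, 1e-5, 512)  # stand-in for its fp16 counterpart
assert cosine_distance(fp32_out, fp16_out) < 2e-4
```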
@unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU") @require_accelerator diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py index ae3900459d..5d1fa4c22e 100644 --- a/tests/quantization/gguf/test_gguf.py +++ b/tests/quantization/gguf/test_gguf.py @@ -286,33 +286,33 @@ class SD35LargeGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase) { ("xpu", 3): np.array( [ - 0.19335938, - 0.3125, - 0.3203125, - 0.1328125, - 0.3046875, - 0.296875, - 0.11914062, - 0.2890625, - 0.2890625, - 0.16796875, - 0.30273438, - 0.33203125, - 0.14648438, - 0.31640625, - 0.33007812, + 0.16210938, + 0.2734375, + 0.27734375, + 0.109375, + 0.27148438, + 0.2578125, + 0.1015625, + 0.2578125, + 0.2578125, + 0.14453125, + 0.26953125, + 0.29492188, 0.12890625, - 0.3046875, - 0.30859375, - 0.17773438, - 0.33789062, - 0.33203125, - 0.16796875, - 0.34570312, - 0.32421875, + 0.28710938, + 0.30078125, + 0.11132812, + 0.27734375, + 0.27929688, 0.15625, - 0.33203125, - 0.31445312, + 0.31054688, + 0.296875, + 0.15234375, + 0.3203125, + 0.29492188, + 0.140625, + 0.3046875, + 0.28515625, ] ), ("cuda", 7): np.array( From 88466358733da21a4ab45d85300ee6960f588e7d Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 26 Jun 2025 00:18:20 +0530 Subject: [PATCH 02/15] fix deprecation in lora after 0.34.0 release (#11802) --- src/diffusers/loaders/lora_base.py | 12 -------- tests/lora/test_deprecated_utilities.py | 39 ------------------------- 2 files changed, 51 deletions(-) delete mode 100644 tests/lora/test_deprecated_utilities.py diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 16f0d48365..e6941a521d 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -1022,15 +1022,3 @@ class LoraBaseMixin: @classmethod def _optionally_disable_offloading(cls, _pipeline): return _func_optionally_disable_offloading(_pipeline=_pipeline) - - @classmethod - def _fetch_state_dict(cls, *args, **kwargs): - deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`." - deprecate("_fetch_state_dict", "0.35.0", deprecation_message) - return _fetch_state_dict(*args, **kwargs) - - @classmethod - def _best_guess_weight_name(cls, *args, **kwargs): - deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`." 
- deprecate("_best_guess_weight_name", "0.35.0", deprecation_message) - return _best_guess_weight_name(*args, **kwargs) diff --git a/tests/lora/test_deprecated_utilities.py b/tests/lora/test_deprecated_utilities.py deleted file mode 100644 index 4275ef8089..0000000000 --- a/tests/lora/test_deprecated_utilities.py +++ /dev/null @@ -1,39 +0,0 @@ -import os -import tempfile -import unittest - -import torch - -from diffusers.loaders.lora_base import LoraBaseMixin - - -class UtilityMethodDeprecationTests(unittest.TestCase): - def test_fetch_state_dict_cls_method_raises_warning(self): - state_dict = torch.nn.Linear(3, 3).state_dict() - with self.assertWarns(FutureWarning) as warning: - _ = LoraBaseMixin._fetch_state_dict( - state_dict, - weight_name=None, - use_safetensors=False, - local_files_only=True, - cache_dir=None, - force_download=False, - proxies=None, - token=None, - revision=None, - subfolder=None, - user_agent=None, - allow_pickle=None, - ) - warning_message = str(warning.warnings[0].message) - assert "Using the `_fetch_state_dict()` method from" in warning_message - - def test_best_guess_weight_name_cls_method_raises_warning(self): - with tempfile.TemporaryDirectory() as tmpdir: - state_dict = torch.nn.Linear(3, 3).state_dict() - torch.save(state_dict, os.path.join(tmpdir, "pytorch_lora_weights.bin")) - - with self.assertWarns(FutureWarning) as warning: - _ = LoraBaseMixin._best_guess_weight_name(pretrained_model_name_or_path_or_dict=tmpdir) - warning_message = str(warning.warnings[0].message) - assert "Using the `_best_guess_weight_name()` method from" in warning_message From 10c36e0b782af500dd2b30e94ad0a4766230eaf7 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 26 Jun 2025 06:56:46 +0530 Subject: [PATCH 03/15] [chore] post release v0.34.0 (#11800) * post release v0.34.0 * code quality --------- Co-authored-by: YiYi Xu --- .../train_dreambooth_lora_flux_advanced.py | 2 +- .../train_dreambooth_lora_sd15_advanced.py | 2 +- .../train_dreambooth_lora_sdxl_advanced.py | 2 +- examples/cogvideo/train_cogvideox_image_to_video_lora.py | 2 +- examples/cogvideo/train_cogvideox_lora.py | 2 +- examples/cogview4-control/train_control_cogview4.py | 2 +- examples/community/marigold_depth_estimation.py | 2 +- .../consistency_distillation/train_lcm_distill_lora_sd_wds.py | 2 +- .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 2 +- .../consistency_distillation/train_lcm_distill_lora_sdxl_wds.py | 2 +- examples/consistency_distillation/train_lcm_distill_sd_wds.py | 2 +- examples/consistency_distillation/train_lcm_distill_sdxl_wds.py | 2 +- examples/controlnet/train_controlnet.py | 2 +- examples/controlnet/train_controlnet_flax.py | 2 +- examples/controlnet/train_controlnet_flux.py | 2 +- examples/controlnet/train_controlnet_sd3.py | 2 +- examples/controlnet/train_controlnet_sdxl.py | 2 +- examples/custom_diffusion/train_custom_diffusion.py | 2 +- examples/dreambooth/train_dreambooth.py | 2 +- examples/dreambooth/train_dreambooth_flax.py | 2 +- examples/dreambooth/train_dreambooth_flux.py | 2 +- examples/dreambooth/train_dreambooth_lora.py | 2 +- examples/dreambooth/train_dreambooth_lora_flux.py | 2 +- examples/dreambooth/train_dreambooth_lora_hidream.py | 2 +- examples/dreambooth/train_dreambooth_lora_lumina2.py | 2 +- examples/dreambooth/train_dreambooth_lora_sana.py | 2 +- examples/dreambooth/train_dreambooth_lora_sd3.py | 2 +- examples/dreambooth/train_dreambooth_lora_sdxl.py | 2 +- examples/dreambooth/train_dreambooth_sd3.py | 2 +- examples/flux-control/train_control_flux.py | 2 
+- examples/flux-control/train_control_lora_flux.py | 2 +- examples/instruct_pix2pix/train_instruct_pix2pix.py | 2 +- examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py | 2 +- .../kandinsky2_2/text_to_image/train_text_to_image_decoder.py | 2 +- .../text_to_image/train_text_to_image_lora_decoder.py | 2 +- .../text_to_image/train_text_to_image_lora_prior.py | 2 +- .../kandinsky2_2/text_to_image/train_text_to_image_prior.py | 2 +- examples/t2i_adapter/train_t2i_adapter_sdxl.py | 2 +- examples/text_to_image/train_text_to_image.py | 2 +- examples/text_to_image/train_text_to_image_flax.py | 2 +- examples/text_to_image/train_text_to_image_lora.py | 2 +- examples/text_to_image/train_text_to_image_lora_sdxl.py | 2 +- examples/text_to_image/train_text_to_image_sdxl.py | 2 +- examples/textual_inversion/textual_inversion.py | 2 +- examples/textual_inversion/textual_inversion_flax.py | 2 +- examples/textual_inversion/textual_inversion_sdxl.py | 2 +- examples/unconditional_image_generation/train_unconditional.py | 2 +- examples/vqgan/train_vqgan.py | 2 +- setup.py | 2 +- src/diffusers/__init__.py | 2 +- 50 files changed, 50 insertions(+), 50 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index 173d3bfd5b..2b892a91ae 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -75,7 +75,7 @@ if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py index 52aee07e81..2c4682d62a 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py @@ -73,7 +73,7 @@ from diffusers.utils.import_utils import is_xformers_available # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 911102c049..7f88d1cbdd 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -80,7 +80,7 @@ if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__) diff --git a/examples/cogvideo/train_cogvideox_image_to_video_lora.py b/examples/cogvideo/train_cogvideox_image_to_video_lora.py index 315c61c60b..47245ed896 100644 --- a/examples/cogvideo/train_cogvideox_image_to_video_lora.py +++ b/examples/cogvideo/train_cogvideox_image_to_video_lora.py @@ -61,7 +61,7 @@ if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__) diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py index a8e73e938c..caa970d4bf 100644 --- a/examples/cogvideo/train_cogvideox_lora.py +++ b/examples/cogvideo/train_cogvideox_lora.py @@ -52,7 +52,7 @@ if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__) diff --git a/examples/cogview4-control/train_control_cogview4.py b/examples/cogview4-control/train_control_cogview4.py index 7d2ce20949..9b2f22452b 100644 --- a/examples/cogview4-control/train_control_cogview4.py +++ b/examples/cogview4-control/train_control_cogview4.py @@ -59,7 +59,7 @@ if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__) diff --git a/examples/community/marigold_depth_estimation.py b/examples/community/marigold_depth_estimation.py index 453735411d..8be773c138 100644 --- a/examples/community/marigold_depth_estimation.py +++ b/examples/community/marigold_depth_estimation.py @@ -43,7 +43,7 @@ from diffusers.utils import BaseOutput, check_min_version # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") class MarigoldDepthOutput(BaseOutput): diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py index b254799756..bedd64da74 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py @@ -73,7 +73,7 @@ if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 554319aef4..113a374c12 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -66,7 +66,7 @@ if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py index 52d4806100..cd50ff176c 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py @@ -79,7 +79,7 @@ if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__) diff --git a/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_sd_wds.py index 3be506352f..e223b71aea 100644 --- a/examples/consistency_distillation/train_lcm_distill_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sd_wds.py @@ -72,7 +72,7 @@ if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__) diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py index 5a28201bf7..20d5c59cc1 100644 --- a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py @@ -78,7 +78,7 @@ if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__) diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py index 69bd39944a..1ddbe5c56a 100644 --- a/examples/controlnet/train_controlnet.py +++ b/examples/controlnet/train_controlnet.py @@ -60,7 +60,7 @@ if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 5561710d6f..90fe426b49 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -60,7 +60,7 @@ if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = logging.getLogger(__name__) diff --git a/examples/controlnet/train_controlnet_flux.py b/examples/controlnet/train_controlnet_flux.py index 94f030fe01..cde1c4d0be 100644 --- a/examples/controlnet/train_controlnet_flux.py +++ b/examples/controlnet/train_controlnet_flux.py @@ -65,7 +65,7 @@ if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__) if is_torch_npu_available(): diff --git a/examples/controlnet/train_controlnet_sd3.py b/examples/controlnet/train_controlnet_sd3.py index ecd7572ca3..746063f9d6 100644 --- a/examples/controlnet/train_controlnet_sd3.py +++ b/examples/controlnet/train_controlnet_sd3.py @@ -61,7 +61,7 @@ if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index 76d232da1c..03296a81f0 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -61,7 +61,7 @@ if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__) if is_torch_npu_available(): diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py index 81992c3dd1..83ea952299 100644 --- a/examples/custom_diffusion/train_custom_diffusion.py +++ b/examples/custom_diffusion/train_custom_diffusion.py @@ -63,7 +63,7 @@ from diffusers.utils.import_utils import is_xformers_available # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index ec0cc686b0..15f59569b8 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -63,7 +63,7 @@ if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_flax.py b/examples/dreambooth/train_dreambooth_flax.py index 4e61a04f24..ccf4626cf8 100644 --- a/examples/dreambooth/train_dreambooth_flax.py +++ b/examples/dreambooth/train_dreambooth_flax.py @@ -35,7 +35,7 @@ from diffusers.utils import check_min_version # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") # Cache compiled models across invocations of this script. cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache")) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index 02b83bb6b1..c575cf654e 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -65,7 +65,7 @@ if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 7c008970bd..d882bac0e7 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -74,7 +74,7 @@ if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 9c529cbb92..e5bade0b7e 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -72,7 +72,7 @@ if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index a1337e8dba..965bb554a8 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -73,7 +73,7 @@ if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.33.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_lora_lumina2.py b/examples/dreambooth/train_dreambooth_lora_lumina2.py index da499bce71..5128e87166 100644 --- a/examples/dreambooth/train_dreambooth_lora_lumina2.py +++ b/examples/dreambooth/train_dreambooth_lora_lumina2.py @@ -72,7 +72,7 @@ if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_lora_sana.py b/examples/dreambooth/train_dreambooth_lora_sana.py index c156523db3..d84a532a15 100644 --- a/examples/dreambooth/train_dreambooth_lora_sana.py +++ b/examples/dreambooth/train_dreambooth_lora_sana.py @@ -72,7 +72,7 @@ if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 05dfe6301f..c049f9b482 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -72,7 +72,7 @@ if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index c3dfc923f0..12f8ab3602 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -79,7 +79,7 @@ if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py index 8d5dee0188..e96e4844cc 100644 --- a/examples/dreambooth/train_dreambooth_sd3.py +++ b/examples/dreambooth/train_dreambooth_sd3.py @@ -63,7 +63,7 @@ if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. 
Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__) diff --git a/examples/flux-control/train_control_flux.py b/examples/flux-control/train_control_flux.py index 3be0182f6d..bce1c6626b 100644 --- a/examples/flux-control/train_control_flux.py +++ b/examples/flux-control/train_control_flux.py @@ -54,7 +54,7 @@ if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__) diff --git a/examples/flux-control/train_control_lora_flux.py b/examples/flux-control/train_control_lora_flux.py index 34755209ce..3c8b75a088 100644 --- a/examples/flux-control/train_control_lora_flux.py +++ b/examples/flux-control/train_control_lora_flux.py @@ -57,7 +57,7 @@ if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 9f536139ab..62ee176101 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -58,7 +58,7 @@ if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py index 1b01f61738..74735c94ec 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py @@ -60,7 +60,7 @@ if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py index f2c5047d75..acc305384b 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py @@ -52,7 +52,7 @@ if is_wandb_available(): # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py index 5b39c25901..15b215ac24 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py @@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py index 8c31f8f03b..904115410b 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py @@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py index 1f16c2d21a..ae5a807eba 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py @@ -51,7 +51,7 @@ if is_wandb_available(): # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/t2i_adapter/train_t2i_adapter_sdxl.py b/examples/t2i_adapter/train_t2i_adapter_sdxl.py index cb8fade444..a5a8c5e2eb 100644 --- a/examples/t2i_adapter/train_t2i_adapter_sdxl.py +++ b/examples/t2i_adapter/train_t2i_adapter_sdxl.py @@ -60,7 +60,7 @@ if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 17f5dc852b..7b5cd63758 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -57,7 +57,7 @@ if is_wandb_available(): # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/text_to_image/train_text_to_image_flax.py b/examples/text_to_image/train_text_to_image_flax.py index d9c1aafe80..1eaed236fe 100644 --- a/examples/text_to_image/train_text_to_image_flax.py +++ b/examples/text_to_image/train_text_to_image_flax.py @@ -49,7 +49,7 @@ from diffusers.utils import check_min_version # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = logging.getLogger(__name__) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 89f867b5ba..01fcb38c74 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -56,7 +56,7 @@ if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index 12afb72b9a..485b283978 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -68,7 +68,7 @@ if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__) if is_torch_npu_available(): diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 65a6131e66..f31971a816 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -55,7 +55,7 @@ from diffusers.utils.torch_utils import is_compiled_module # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__) if is_torch_npu_available(): diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index 6dcc2ff7dc..a415b288d8 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -81,7 +81,7 @@ else: # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__) diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py index 44c46995a1..d26ab492cd 100644 --- a/examples/textual_inversion/textual_inversion_flax.py +++ b/examples/textual_inversion/textual_inversion_flax.py @@ -56,7 +56,7 @@ else: # ------------------------------------------------------------------------------ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = logging.getLogger(__name__) diff --git a/examples/textual_inversion/textual_inversion_sdxl.py b/examples/textual_inversion/textual_inversion_sdxl.py index add15a8583..1cfe7969ec 100644 --- a/examples/textual_inversion/textual_inversion_sdxl.py +++ b/examples/textual_inversion/textual_inversion_sdxl.py @@ -76,7 +76,7 @@ else: # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__) diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index baf2a9d899..892c674575 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/vqgan/train_vqgan.py b/examples/vqgan/train_vqgan.py index a14ca13495..5ba1678d44 100644 --- a/examples/vqgan/train_vqgan.py +++ b/examples/vqgan/train_vqgan.py @@ -50,7 +50,7 @@ if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.34.0.dev0") +check_min_version("0.35.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/setup.py b/setup.py index e8df544e0c..11b64cff2d 100644 --- a/setup.py +++ b/setup.py @@ -269,7 +269,7 @@ version_range_max = max(sys.version_info[1], 10) + 1 setup( name="diffusers", - version="0.34.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) + version="0.35.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) description="State-of-the-art diffusion in PyTorch and JAX.", long_description=open("README.md", "r", encoding="utf-8").read(), long_description_content_type="text/markdown", diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 81051b9f25..3111a72e52 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.34.0.dev0" +__version__ = "0.35.0.dev0" from typing import TYPE_CHECKING From 3649d7b903b244939e684c072e756204c5f23224 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 26 Jun 2025 03:54:24 +0200 Subject: [PATCH 04/15] Follow up for Group Offload to Disk (#11760) * update * update * update --------- Co-authored-by: Sayak Paul --- src/diffusers/hooks/group_offloading.py | 134 ++++++++++++++---------- 1 file changed, 77 insertions(+), 57 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 7186cb181a..45fee35ef3 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -132,9 +132,58 @@ class ModuleGroup: finally: pinned_dict = None + def _transfer_tensor_to_device(self, tensor, source_tensor, current_stream=None): + tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking) + if self.record_stream and current_stream is not None: + tensor.data.record_stream(current_stream) + + def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None): + for group_module in self.modules: + for param in group_module.parameters(): + source = pinned_memory[param] if pinned_memory else param.data + self._transfer_tensor_to_device(param, source, current_stream) + for buffer in group_module.buffers(): + source = pinned_memory[buffer] if pinned_memory else buffer.data + self._transfer_tensor_to_device(buffer, source, current_stream) + + for param in self.parameters: + source = pinned_memory[param] if pinned_memory else param.data + self._transfer_tensor_to_device(param, source, current_stream) + + for buffer in self.buffers: + source = pinned_memory[buffer] if pinned_memory else buffer.data + self._transfer_tensor_to_device(buffer, source, current_stream) + + def _onload_from_disk(self, current_stream): + if self.stream is not None: + loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu") + + for key, tensor_obj in self.key_to_tensor.items(): + self.cpu_param_dict[tensor_obj] = loaded_cpu_tensors[key] + + with self._pinned_memory_tensors() as pinned_memory: + for key, tensor_obj in 
self.key_to_tensor.items(): + self._transfer_tensor_to_device(tensor_obj, pinned_memory[tensor_obj], current_stream) + + self.cpu_param_dict.clear() + + else: + onload_device = ( + self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device + ) + loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device) + for key, tensor_obj in self.key_to_tensor.items(): + tensor_obj.data = loaded_tensors[key] + + def _onload_from_memory(self, current_stream): + if self.stream is not None: + with self._pinned_memory_tensors() as pinned_memory: + self._process_tensors_from_modules(pinned_memory, current_stream) + else: + self._process_tensors_from_modules(None, current_stream) + @torch.compiler.disable() def onload_(self): - r"""Onloads the group of modules to the onload_device.""" torch_accelerator_module = ( getattr(torch, torch.accelerator.current_accelerator().type) if hasattr(torch, "accelerator") @@ -172,67 +221,30 @@ class ModuleGroup: self.stream.synchronize() with context: - if self.stream is not None: - with self._pinned_memory_tensors() as pinned_memory: - for group_module in self.modules: - for param in group_module.parameters(): - param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking) - if self.record_stream: - param.data.record_stream(current_stream) - for buffer in group_module.buffers(): - buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking) - if self.record_stream: - buffer.data.record_stream(current_stream) - - for param in self.parameters: - param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking) - if self.record_stream: - param.data.record_stream(current_stream) - - for buffer in self.buffers: - buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking) - if self.record_stream: - buffer.data.record_stream(current_stream) - + if self.offload_to_disk_path: + self._onload_from_disk(current_stream) else: - for group_module in self.modules: - for param in group_module.parameters(): - param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking) - for buffer in group_module.buffers(): - buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking) + self._onload_from_memory(current_stream) - for param in self.parameters: - param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking) + def _offload_to_disk(self): + # TODO: we can potentially optimize this code path by checking if the _all_ the desired + # safetensor files exist on the disk and if so, skip this step entirely, reducing IO + # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not + # we perform a write. + # Check if the file has been saved in this session or if it already exists on disk. + if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path): + os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True) + tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()} + safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path) - for buffer in self.buffers: - buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking) - if self.record_stream: - buffer.data.record_stream(current_stream) + # The group is now considered offloaded to disk for the rest of the session. 
+        self._is_offloaded_to_disk = True
+
+        # We do this to free up the RAM which is still holding the up tensor data.
+        for tensor_obj in self.tensor_to_key.keys():
+            tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
+
+    def _offload_to_memory(self):
         torch_accelerator_module = (
             getattr(torch, torch.accelerator.current_accelerator().type)
             if hasattr(torch, "accelerator")
@@ -257,6 +269,14 @@ class ModuleGroup:
         for buffer in self.buffers:
             buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
 
+    @torch.compiler.disable()
+    def offload_(self):
+        r"""Offloads the group of modules to the offload_device."""
+        if self.offload_to_disk_path:
+            self._offload_to_disk()
+        else:
+            self._offload_to_memory()
+
 
 class GroupOffloadingHook(ModelHook):
     r"""

From d93381cd417b5205cc1703856f56f15a43b3c8c4 Mon Sep 17 00:00:00 2001
From: Animesh Jain
Date: Wed, 25 Jun 2025 20:11:38 -0700
Subject: [PATCH 05/15] [rfc][compile] compile method for DiffusionPipeline
 (#11705)

* [rfc][compile] compile method for DiffusionPipeline

* Apply suggestions from code review

Co-authored-by: Sayak Paul

* Apply style fixes

* Update docs/source/en/optimization/fp16.md

* check

---------

Co-authored-by: Sayak Paul
Co-authored-by: github-actions[bot]
---
 docs/source/en/optimization/fp16.md           | 38 +++++++++++++++++--
 src/diffusers/models/modeling_utils.py        | 34 +++++++++++++++++
 .../models/transformers/transformer_chroma.py |  1 +
 .../models/transformers/transformer_flux.py   |  1 +
 .../transformers/transformer_hunyuan_video.py |  6 +++
 .../models/transformers/transformer_ltx.py    |  1 +
 .../models/transformers/transformer_wan.py    |  1 +
 .../models/unets/unet_2d_condition.py         |  1 +
 tests/models/test_modeling_common.py          | 21 ++++
 9 files changed, 101 insertions(+), 3 deletions(-)

diff --git a/docs/source/en/optimization/fp16.md b/docs/source/en/optimization/fp16.md
index 2e12bfadcf..45a2282ba1 100644
--- a/docs/source/en/optimization/fp16.md
+++ b/docs/source/en/optimization/fp16.md
@@ -152,9 +152,39 @@ Compilation is slow the first time, but once compiled, it is significantly faste
 
 ### Regional compilation
 
-[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) reduces the cold start compilation time by only compiling a specific repeated region (or block) of the model
instead of the entire model. The compiler reuses the cached and compiled code for the other blocks. -[Accelerate](https://huggingface.co/docs/accelerate/index) provides the [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78) method for automatically compiling the repeated blocks of a `nn.Module` sequentially. The rest of the model is compiled separately. +[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) trims cold-start latency by compiling **only the small, frequently-repeated block(s)** of a model, typically a Transformer layer, enabling reuse of compiled artifacts for every subsequent occurrence. +For many diffusion architectures this delivers the *same* runtime speed-ups as full-graph compilation yet cuts compile time by **8–10 ×**. + +To make this effortless, [`ModelMixin`] exposes [`ModelMixin.compile_repeated_blocks`] API, a helper that wraps `torch.compile` around any sub-modules you designate as repeatable: + +```py +# pip install -U diffusers +import torch +from diffusers import StableDiffusionXLPipeline + +pipe = StableDiffusionXLPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + torch_dtype=torch.float16, +).to("cuda") + +# Compile only the repeated Transformer layers inside the UNet +pipe.unet.compile_repeated_blocks(fullgraph=True) +``` + +To enable a new model with regional compilation, add a `_repeated_blocks` attribute to your model class containing the class names (as strings) of the blocks you want compiled: + + +```py +class MyUNet(ModelMixin): + _repeated_blocks = ("Transformer2DModel",) # ← compiled by default +``` + +For more examples, see the reference [PR](https://github.com/huggingface/diffusers/pull/11705). + +**Relation to Accelerate compile_regions** There is also a separate API in [accelerate](https://huggingface.co/docs/accelerate/index) - [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78). It takes a fully automatic approach: it walks the module, picks candidate blocks, then compiles the remaining graph separately. That hands-off experience is handy for quick experiments, but it also leaves fewer knobs when you want to fine-tune which blocks are compiled or adjust compilation flags. + + ```py # pip install -U accelerate @@ -167,6 +197,8 @@ pipeline = StableDiffusionXLPipeline.from_pretrained( ).to("cuda") pipeline.unet = compile_regions(pipeline.unet, mode="reduce-overhead", fullgraph=True) ``` +`compile_repeated_blocks`, by contrast, is intentionally explicit. You list the repeated blocks once (via `_repeated_blocks`) and the helper compiles exactly those, nothing more. In practice this small dose of control hits a sweet spot for diffusion models: predictable behavior, easy reasoning about cache reuse, and still a one-liner for users. 
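A quick way to sanity-check the cold-start claim in the doc text above is to time the first pipeline call after compiling. The model choice and timing harness here are illustrative assumptions, not part of the patch:

```py
import time

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Regional compilation: only the repeated Transformer layers are compiled.
pipe.unet.compile_repeated_blocks(fullgraph=True)

start = time.perf_counter()
_ = pipe("a photo of an astronaut", num_inference_steps=2)  # first call pays the compile cost
print(f"cold start: {time.perf_counter() - start:.1f}s")
```

Swapping the `compile_repeated_blocks` call for `pipe.unet = torch.compile(pipe.unet, fullgraph=True)` gives the full-graph baseline to compare against.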
+ ### Graph breaks @@ -241,4 +273,4 @@ An input is projected into three subspaces, represented by the projection matric ```py pipeline.fuse_qkv_projections() -``` \ No newline at end of file +``` diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 5fa04fb260..8e1ec5f558 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -266,6 +266,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): _keep_in_fp32_modules = None _skip_layerwise_casting_patterns = None _supports_group_offloading = True + _repeated_blocks = [] def __init__(self): super().__init__() @@ -1404,6 +1405,39 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): else: return super().float(*args) + def compile_repeated_blocks(self, *args, **kwargs): + """ + Compiles *only* the frequently repeated sub-modules of a model (e.g. the Transformer layers) instead of + compiling the entire model. This technique—often called **regional compilation** (see the PyTorch recipe + https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) can reduce end-to-end compile time + substantially, while preserving the runtime speed-ups you would expect from a full `torch.compile`. + + The set of sub-modules to compile is discovered by the presence of **`_repeated_blocks`** attribute in the + model definition. Define this attribute on your model subclass as a list/tuple of class names (strings). Every + module whose class name matches will be compiled. + + Once discovered, each matching sub-module is compiled by calling `submodule.compile(*args, **kwargs)`. Any + positional or keyword arguments you supply to `compile_repeated_blocks` are forwarded verbatim to + `torch.compile`. + """ + repeated_blocks = getattr(self, "_repeated_blocks", None) + + if not repeated_blocks: + raise ValueError( + "`_repeated_blocks` attribute is empty. " + f"Set `_repeated_blocks` for the class `{self.__class__.__name__}` to benefit from faster compilation. " + ) + has_compiled_region = False + for submod in self.modules(): + if submod.__class__.__name__ in repeated_blocks: + submod.compile(*args, **kwargs) + has_compiled_region = True + + if not has_compiled_region: + raise ValueError( + f"Regional compilation failed because {repeated_blocks} classes are not found in the model. 
" + ) + @classmethod def _load_pretrained_model( cls, diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py index d11f6c2a5e..0f6dd677ac 100644 --- a/src/diffusers/models/transformers/transformer_chroma.py +++ b/src/diffusers/models/transformers/transformer_chroma.py @@ -407,6 +407,7 @@ class ChromaTransformer2DModel( _supports_gradient_checkpointing = True _no_split_modules = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"] + _repeated_blocks = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"] _skip_layerwise_casting_patterns = ["pos_embed", "norm"] @register_to_config diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index ab579a0eb5..3af1de2ad0 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -227,6 +227,7 @@ class FluxTransformer2DModel( _supports_gradient_checkpointing = True _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] _skip_layerwise_casting_patterns = ["pos_embed", "norm"] + _repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index c48c586a28..6944a6c536 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -870,6 +870,12 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, "HunyuanVideoPatchEmbed", "HunyuanVideoTokenRefiner", ] + _repeated_blocks = [ + "HunyuanVideoTransformerBlock", + "HunyuanVideoSingleTransformerBlock", + "HunyuanVideoPatchEmbed", + "HunyuanVideoTokenRefiner", + ] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 38b7b6af50..2d06124282 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -328,6 +328,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["norm"] + _repeated_blocks = ["LTXVideoTransformerBlock"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index baa0ede418..0ae7f2c00d 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -345,6 +345,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi _no_split_modules = ["WanTransformerBlock"] _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] _keys_to_ignore_on_load_unexpected = ["norm_added_q"] + _repeated_blocks = ["WanTransformerBlock"] @register_to_config def __init__( diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 0cf5133c54..0f789d3961 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -167,6 +167,7 @@ class UNet2DConditionModel( _supports_gradient_checkpointing = True _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", 
"CrossAttnUpBlock2D"] _skip_layerwise_casting_patterns = ["norm"] + _repeated_blocks = ["BasicTransformerBlock"] @register_to_config def __init__( diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 3a401c46fb..7e1e1483f7 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1935,6 +1935,27 @@ class TorchCompileTesterMixin: _ = model(**inputs_dict) _ = model(**inputs_dict) + def test_torch_compile_repeated_blocks(self): + if self.model_class._repeated_blocks is None: + pytest.skip("Skipping test as the model class doesn't have `_repeated_blocks` set.") + + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict).to(torch_device) + model.compile_repeated_blocks(fullgraph=True) + + recompile_limit = 1 + if self.model_class.__name__ == "UNet2DConditionModel": + recompile_limit = 2 + + with ( + torch._inductor.utils.fresh_inductor_cache(), + torch._dynamo.config.patch(recompile_limit=recompile_limit), + torch.no_grad(), + ): + _ = model(**inputs_dict) + _ = model(**inputs_dict) + def test_compile_with_group_offloading(self): torch._dynamo.config.cache_size_limit = 10000 From a185e1ab91926e6143b9732badf49e124b215ab7 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 26 Jun 2025 10:07:03 +0530 Subject: [PATCH 06/15] [tests] add a test on torch compile for varied resolutions (#11776) * add test for checking compile on different shapes. * update * update * Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/optimization/fp16.md | 22 +++++++++++++++++ tests/models/test_modeling_common.py | 24 ++++++++++++++++--- .../test_models_transformer_flux.py | 24 ++++++++++++------- 3 files changed, 58 insertions(+), 12 deletions(-) diff --git a/docs/source/en/optimization/fp16.md b/docs/source/en/optimization/fp16.md index 45a2282ba1..734f63e68d 100644 --- a/docs/source/en/optimization/fp16.md +++ b/docs/source/en/optimization/fp16.md @@ -150,6 +150,28 @@ pipeline(prompt, num_inference_steps=30).images[0] Compilation is slow the first time, but once compiled, it is significantly faster. Try to only use the compiled pipeline on the same type of inference operations. Calling the compiled pipeline on a different image size retriggers compilation which is slow and inefficient. +### Dynamic shape compilation + +> [!TIP] +> Make sure to always use the nightly version of PyTorch for better support. + +`torch.compile` keeps track of input shapes and conditions, and if these are different, it recompiles the model. For example, if a model is compiled on a 1024x1024 resolution image and used on an image with a different resolution, it triggers recompilation. + +To avoid recompilation, add `dynamic=True` to try and generate a more dynamic kernel to avoid recompilation when conditions change. + +```diff ++ torch.fx.experimental._config.use_duck_shape = False ++ pipeline.unet = torch.compile( + pipeline.unet, fullgraph=True, dynamic=True +) +``` + +Specifying `use_duck_shape=False` instructs the compiler if it should use the same symbolic variable to represent input sizes that are the same. For more details, check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790). + +Not all models may benefit from dynamic compilation out of the box and may require changes. 
Refer to this [PR](https://github.com/huggingface/diffusers/pull/11297/) that improved the [`AuraFlowPipeline`] implementation to benefit from dynamic compilation. + +Feel free to open an issue if dynamic compilation doesn't work as expected for a Diffusers model. + ### Regional compilation diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 7e1e1483f7..dcc7ae16a4 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -76,6 +76,7 @@ from diffusers.utils.testing_utils import ( require_torch_accelerator_with_training, require_torch_gpu, require_torch_multi_accelerator, + require_torch_version_greater, run_test_in_subprocess, slow, torch_all_close, @@ -1907,6 +1908,8 @@ class ModelPushToHubTester(unittest.TestCase): @is_torch_compile @slow class TorchCompileTesterMixin: + different_shapes_for_compilation = None + def setUp(self): # clean up the VRAM before each test super().setUp() @@ -1957,14 +1960,14 @@ class TorchCompileTesterMixin: _ = model(**inputs_dict) def test_compile_with_group_offloading(self): + if not self.model_class._supports_group_offloading: + pytest.skip("Model does not support group offloading.") + torch._dynamo.config.cache_size_limit = 10000 init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict) - if not getattr(model, "_supports_group_offloading", True): - return - model.eval() # TODO: Can test for other group offloading kwargs later if needed. group_offload_kwargs = { @@ -1981,6 +1984,21 @@ class TorchCompileTesterMixin: _ = model(**inputs_dict) _ = model(**inputs_dict) + @require_torch_version_greater("2.7.1") + def test_compile_on_different_shapes(self): + if self.different_shapes_for_compilation is None: + pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.") + torch.fx.experimental._config.use_duck_shape = False + + init_dict, _ = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + model = torch.compile(model, fullgraph=True, dynamic=True) + + for height, width in self.different_shapes_for_compilation: + with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad(): + inputs_dict = self.prepare_dummy_input(height=height, width=width) + _ = model(**inputs_dict) + @slow @require_torch_2 diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index 0a55236ef1..4552b2e1f5 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -91,10 +91,20 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase): @property def dummy_input(self): + return self.prepare_dummy_input() + + @property + def input_shape(self): + return (16, 4) + + @property + def output_shape(self): + return (16, 4) + + def prepare_dummy_input(self, height=4, width=4): batch_size = 1 num_latent_channels = 4 num_image_channels = 3 - height = width = 4 sequence_length = 48 embedding_dim = 32 @@ -114,14 +124,6 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase): "timestep": timestep, } - @property - def input_shape(self): - return (16, 4) - - @property - def output_shape(self): - return (16, 4) - def prepare_init_args_and_inputs_for_common(self): init_dict = { "patch_size": 1, @@ -173,10 +175,14 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase): class 
FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): model_class = FluxTransformer2DModel + different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)] def prepare_init_args_and_inputs_for_common(self): return FluxTransformerTests().prepare_init_args_and_inputs_for_common() + def prepare_dummy_input(self, height, width): + return FluxTransformerTests().prepare_dummy_input(height=height, width=width) + class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase): model_class = FluxTransformer2DModel From 27bf7fcd0e8069c623b564dd7024ea782b69dca8 Mon Sep 17 00:00:00 2001 From: kaixuanliu Date: Thu, 26 Jun 2025 15:49:59 +0800 Subject: [PATCH 07/15] adjust tolerance criteria for `test_float16_inference` in unit test (#11809) Signed-off-by: Liu, Kaixuan --- tests/pipelines/amused/test_amused_inpaint.py | 33 ++++++++++++++++++- tests/pipelines/test_pipelines_common.py | 2 +- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/amused/test_amused_inpaint.py b/tests/pipelines/amused/test_amused_inpaint.py index 62f39de8c3..0b025b8a3f 100644 --- a/tests/pipelines/amused/test_amused_inpaint.py +++ b/tests/pipelines/amused/test_amused_inpaint.py @@ -22,6 +22,7 @@ from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokeni from diffusers import AmusedInpaintPipeline, AmusedScheduler, UVit2DModel, VQModel from diffusers.utils import load_image from diffusers.utils.testing_utils import ( + Expectations, enable_full_determinism, require_torch_accelerator, slow, @@ -246,5 +247,35 @@ class AmusedInpaintPipelineSlowTests(unittest.TestCase): image_slice = image[0, -3:, -3:, -1].flatten() assert image.shape == (1, 512, 512, 3) - expected_slice = np.array([0.0227, 0.0157, 0.0098, 0.0213, 0.0250, 0.0127, 0.0280, 0.0380, 0.0095]) + expected_slices = Expectations( + { + ("xpu", 3): np.array( + [ + 0.0274, + 0.0211, + 0.0154, + 0.0257, + 0.0299, + 0.0170, + 0.0326, + 0.0420, + 0.0150, + ] + ), + ("cuda", 7): np.array( + [ + 0.0227, + 0.0157, + 0.0098, + 0.0213, + 0.0250, + 0.0127, + 0.0280, + 0.0380, + 0.0095, + ] + ), + } + ) + expected_slice = expected_slices.get_expectation() assert np.abs(image_slice - expected_slice).max() < 0.003 diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 4a3a9b1796..69dd79bb56 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -1396,7 +1396,7 @@ class PipelineTesterMixin: output_fp16 = pipe_fp16(**fp16_inputs)[0] max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten()) - assert max_diff < 2e-4 + assert max_diff < 1e-2 @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU") @require_accelerator From eea76892e87c600d98934c987c7e067640a9ce16 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 26 Jun 2025 21:29:59 +0530 Subject: [PATCH 08/15] Flux Kontext (#11812) * support flux kontext * make fix-copies * add example * add tests * update docs * update * add note on integrity checker * make fix-copies issue * add copied froms * make style * update repository ids * more copied froms --- docs/source/en/api/pipelines/flux.md | 41 + src/diffusers/__init__.py | 2 + src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/flux/__init__.py | 2 + .../pipelines/flux/pipeline_flux_kontext.py | 1129 +++++++++++++++++ .../dummy_torch_and_transformers_objects.py | 15 + .../flux/test_pipeline_flux_kontext.py | 177 +++ 7 files 
changed, 1368 insertions(+)
 create mode 100644 src/diffusers/pipelines/flux/pipeline_flux_kontext.py
 create mode 100644 tests/pipelines/flux/test_pipeline_flux_kontext.py

diff --git a/docs/source/en/api/pipelines/flux.md b/docs/source/en/api/pipelines/flux.md
index ef29e77ce2..d8a86a0692 100644
--- a/docs/source/en/api/pipelines/flux.md
+++ b/docs/source/en/api/pipelines/flux.md
@@ -39,6 +39,7 @@ Flux comes in the following variants:
 | Canny Control (LoRA) | [`black-forest-labs/FLUX.1-Canny-dev-lora`](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev-lora) |
 | Depth Control (LoRA) | [`black-forest-labs/FLUX.1-Depth-dev-lora`](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev-lora) |
 | Redux (Adapter) | [`black-forest-labs/FLUX.1-Redux-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev) |
+| Kontext | [`black-forest-labs/FLUX.1-kontext`](https://huggingface.co/black-forest-labs/FLUX.1-kontext) |
 
 All checkpoints have different usage which we detail below.
 
@@ -273,6 +274,46 @@ images = pipe(
 images[0].save("flux-redux.png")
 ```
 
+### Kontext
+
+Flux Kontext is a model that allows in-context control of the image generation process, enabling editing, refinement, relighting, style transfer, character customization, and more.
+
+```python
+import torch
+from diffusers import FluxKontextPipeline
+from diffusers.utils import load_image
+
+pipe = FluxKontextPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
+)
+pipe.to("cuda")
+
+image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png").convert("RGB")
+prompt = "Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors"
+image = pipe(
+    image=image,
+    prompt=prompt,
+    guidance_scale=2.5,
+    generator=torch.Generator().manual_seed(42),
+).images[0]
+image.save("flux-kontext.png")
+```
+
+Flux Kontext comes with an integrity safety checker, which should be run after the image generation step. To run the safety checker, install the official repository from [black-forest-labs/flux](https://github.com/black-forest-labs/flux) and add the following code:
+
+```python
+import numpy as np
+import torch
+
+from flux.safety import PixtralIntegrity
+
+# ... pipeline invocation to generate images
+
+integrity_checker = PixtralIntegrity(torch.device("cuda"))
+image_ = np.array(image) / 255.0
+image_ = 2 * image_ - 1
+image_ = torch.from_numpy(image_).to("cuda", dtype=torch.float32).unsqueeze(0).permute(0, 3, 1, 2)
+if integrity_checker.test_image(image_):
+    raise ValueError("Your image has been flagged. Choose another prompt/image or try again.")
+```
+
 ## Combining Flux Turbo LoRAs with Flux Control, Fill, and Redux
 
 We can combine Flux Turbo LoRAs with Flux Control and other pipelines like Fill and Redux to enable few-step inference. The example below shows how to do that for Flux Control LoRA for depth and turbo LoRA from [`ByteDance/Hyper-SD`](https://hf.co/ByteDance/Hyper-SD).
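Below is a minimal sketch of that combination. The adapter names, weights, and the local depth-map path are illustrative assumptions, and the depth map is assumed to be precomputed with a separate depth estimator.

```py
import torch
from diffusers import FluxControlPipeline
from diffusers.utils import load_image

pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Load the depth Control LoRA and a turbo LoRA, then activate both adapters together
pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora", adapter_name="depth")
pipe.load_lora_weights(
    "ByteDance/Hyper-SD", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors", adapter_name="hyper-sd"
)
pipe.set_adapters(["depth", "hyper-sd"], adapter_weights=[0.85, 0.125])

control_image = load_image("path/to/depth_map.png")  # hypothetical precomputed depth map
image = pipe(
    prompt="A robot made of exotic candies",
    control_image=control_image,
    num_inference_steps=8,  # few-step inference enabled by the turbo LoRA
    guidance_scale=10.0,
).images[0]
image.save("flux-control-turbo.png")
```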
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 3111a72e52..b3f5f6ec9d 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -381,6 +381,7 @@ else: "FluxFillPipeline", "FluxImg2ImgPipeline", "FluxInpaintPipeline", + "FluxKontextPipeline", "FluxPipeline", "FluxPriorReduxPipeline", "HiDreamImagePipeline", @@ -974,6 +975,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: FluxFillPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, + FluxKontextPipeline, FluxPipeline, FluxPriorReduxPipeline, HiDreamImagePipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index b32d55bd51..892c6f5a4c 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -140,6 +140,7 @@ else: "FluxFillPipeline", "FluxPriorReduxPipeline", "ReduxImageEncoder", + "FluxKontextPipeline", ] _import_structure["audioldm"] = ["AudioLDMPipeline"] _import_structure["audioldm2"] = [ @@ -609,6 +610,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: FluxFillPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, + FluxKontextPipeline, FluxPipeline, FluxPriorReduxPipeline, ReduxImageEncoder, diff --git a/src/diffusers/pipelines/flux/__init__.py b/src/diffusers/pipelines/flux/__init__.py index 72e1b578f2..117ce46f20 100644 --- a/src/diffusers/pipelines/flux/__init__.py +++ b/src/diffusers/pipelines/flux/__init__.py @@ -33,6 +33,7 @@ else: _import_structure["pipeline_flux_fill"] = ["FluxFillPipeline"] _import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"] _import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"] + _import_structure["pipeline_flux_kontext"] = ["FluxKontextPipeline"] _import_structure["pipeline_flux_prior_redux"] = ["FluxPriorReduxPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -52,6 +53,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: from .pipeline_flux_fill import FluxFillPipeline from .pipeline_flux_img2img import FluxImg2ImgPipeline from .pipeline_flux_inpaint import FluxInpaintPipeline + from .pipeline_flux_kontext import FluxKontextPipeline from .pipeline_flux_prior_redux import FluxPriorReduxPipeline else: import sys diff --git a/src/diffusers/pipelines/flux/pipeline_flux_kontext.py b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py new file mode 100644 index 0000000000..2427e342a9 --- /dev/null +++ b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py @@ -0,0 +1,1129 @@ +# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextModel,
+    CLIPTokenizer,
+    CLIPVisionModelWithProjection,
+    T5EncoderModel,
+    T5TokenizerFast,
+)
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, FluxTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import (
+    USE_PEFT_BACKEND,
+    is_torch_xla_available,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import FluxPipelineOutput
+
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> import torch
+        >>> from diffusers import FluxKontextPipeline
+        >>> from diffusers.utils import load_image
+
+        >>> pipe = FluxKontextPipeline.from_pretrained(
+        ...     "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
+        ... )
+        >>> pipe.to("cuda")
+
+        >>> image = load_image(
+        ...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
+        ... ).convert("RGB")
+        >>> prompt = "Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors"
+        >>> image = pipe(
+        ...     image=image,
+        ...     prompt=prompt,
+        ...     guidance_scale=2.5,
+        ...     generator=torch.Generator().manual_seed(42),
+        ... ).images[0]
+        >>> image.save("output.png")
+        ```
+"""
+
+PREFERRED_KONTEXT_RESOLUTIONS = [
+    (672, 1568),
+    (688, 1504),
+    (720, 1456),
+    (752, 1392),
+    (800, 1328),
+    (832, 1248),
+    (880, 1184),
+    (944, 1104),
+    (1024, 1024),
+    (1104, 944),
+    (1184, 880),
+    (1248, 832),
+    (1328, 800),
+    (1392, 752),
+    (1456, 720),
+    (1504, 688),
+    (1568, 672),
+]
+
+
+def calculate_shift(
+    image_seq_len,
+    base_seq_len: int = 256,
+    max_seq_len: int = 4096,
+    base_shift: float = 0.5,
+    max_shift: float = 1.15,
+):
+    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+    b = base_shift - m * base_seq_len
+    mu = image_seq_len * m + b
+    return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    sigmas: Optional[List[float]] = None,
+    **kwargs,
+):
+    r"""
+    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+    Args:
+        scheduler (`SchedulerMixin`):
+            The scheduler to get timesteps from.
+        num_inference_steps (`int`):
+            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+            must be `None`.
+        device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+        timesteps (`List[int]`, *optional*):
+            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+            `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class FluxKontextPipeline( + DiffusionPipeline, + FluxLoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, + FluxIPAdapterMixin, +): + r""" + The Flux Kontext pipeline for text-to-image generation. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. 
+ tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae" + _optional_components = ["image_encoder", "feature_extractor"] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, 
using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
+ """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + return image_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt + ): + image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + + for single_ip_adapter_image in ip_adapter_image: + single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1) + image_embeds.append(single_image_embeds[None, :]) + else: + if not isinstance(ip_adapter_image_embeds, list): + ip_adapter_image_embeds = [ip_adapter_image_embeds] + + if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. 
Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + + for single_image_embeds in ip_adapter_image_embeds: + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for single_image_embeds in image_embeds: + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." 
+ ) + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. 
+ """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def prepare_latents( + self, + image: Optional[torch.Tensor], + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + shape = (batch_size, num_channels_latents, height, width) + + image_latents = image_ids = None + if image is not None: + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." 
+                )
+            else:
+                image_latents = torch.cat([image_latents], dim=0)
+
+            image_latent_height, image_latent_width = image_latents.shape[2:]
+            image_latents = self._pack_latents(
+                image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
+            )
+            image_ids = self._prepare_latent_image_ids(
+                batch_size, image_latent_height // 2, image_latent_width // 2, device, dtype
+            )
+            # image ids are the same as latent ids with the first dimension set to 1 instead of 0
+            image_ids[..., 0] = 1
+
+        latent_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
+        if latents is None:
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+        else:
+            latents = latents.to(device=device, dtype=dtype)
+
+        return latents, image_latents, latent_ids, image_ids
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def joint_attention_kwargs(self):
+        return self._joint_attention_kwargs
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
+    @property
+    def current_timestep(self):
+        return self._current_timestep
+
+    @property
+    def interrupt(self):
+        return self._interrupt
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        image: Optional[PipelineImageInput] = None,
+        prompt: Union[str, List[str]] = None,
+        prompt_2: Optional[Union[str, List[str]]] = None,
+        negative_prompt: Union[str, List[str]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
+        true_cfg_scale: float = 1.0,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_inference_steps: int = 28,
+        sigmas: Optional[List[float]] = None,
+        guidance_scale: float = 3.5,
+        num_images_per_prompt: Optional[int] = 1,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        negative_ip_adapter_image: Optional[PipelineImageInput] = None,
+        negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        max_sequence_length: int = 512,
+        max_area: int = 1024**2,
+        _auto_resize: bool = True,
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+                `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+                numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a
+                list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or
+                a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
+                latents as `image`, but if passing latents directly it is not encoded again.
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` will
+                be used instead.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+                not greater than `1`).
+            negative_prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+                `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
+            true_cfg_scale (`float`, *optional*, defaults to 1.0):
+                When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
+            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The height in pixels of the generated image. This is set to 1024 by default for the best results.
+            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+            num_inference_steps (`int`, *optional*, defaults to 28):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+                will be used.
+            guidance_scale (`float`, *optional*, defaults to 3.5):
+                Guidance scale as defined in [Classifier-Free Diffusion
+                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+                `guidance_scale > 1`. A higher guidance scale encourages the model to generate images that are closely
+                linked to the text `prompt`, usually at the expense of lower image quality.
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+                If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
+            negative_ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
+            negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+                input argument.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
+            joint_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int`, defaults to 512):
+                Maximum sequence length to use with the `prompt`.
+            max_area (`int`, defaults to `1024 ** 2`):
+                The maximum area of the generated image in pixels. The height and width will be adjusted to fit this
+                area while maintaining the aspect ratio.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+            is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+            images.
+ """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_height, original_width = height, width + aspect_ratio = width / height + width = round((max_area * aspect_ratio) ** 0.5) + height = round((max_area / aspect_ratio) ** 0.5) + + multiple_of = self.vae_scale_factor * 2 + width = width // multiple_of * multiple_of + height = height // multiple_of * multiple_of + + if height != original_height or width != original_width: + logger.warning( + f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements." + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + if do_true_cfg: + ( + negative_prompt_embeds, + negative_pooled_prompt_embeds, + negative_text_ids, + ) = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 3. Preprocess image + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + img = image[0] if isinstance(image, list) else image + image_height, image_width = self.image_processor.get_default_height_width(img) + aspect_ratio = image_width / image_height + if _auto_resize: + # Kontext is trained on specific resolutions, using one of them is recommended + _, image_width, image_height = min( + (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS + ) + image_width = image_width // multiple_of * multiple_of + image_height = image_height // multiple_of * multiple_of + image = self.image_processor.resize(image, image_height, image_width) + image = self.image_processor.preprocess(image, image_height, image_width) + + # 4. 
Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, image_latents, latent_ids, image_ids = self.prepare_latents( + image, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + if image_ids is not None: + latent_ids = torch.cat([latent_ids, image_ids], dim=0) # dim 0 is sequence dimension + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and ( + negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None + ): + negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters + + elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and ( + negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None + ): + ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + image_embeds = None + negative_image_embeds = None + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None: + negative_image_embeds = self.prepare_ip_adapter_image_embeds( + negative_ip_adapter_image, + negative_ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + + # 6. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + if image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds + + latent_model_input = latents + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred[:, : latents.size(1)] + + if do_true_cfg: + if negative_image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=negative_pooled_prompt_embeds, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 656a8ac6c6..a0c6d84a32 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -692,6 +692,21 @@ class FluxInpaintPipeline(metaclass=DummyObject): requires_backends(cls, ["torch", "transformers"]) +class FluxKontextPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", 
"transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class FluxPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/flux/test_pipeline_flux_kontext.py b/tests/pipelines/flux/test_pipeline_flux_kontext.py new file mode 100644 index 0000000000..7471d78ad5 --- /dev/null +++ b/tests/pipelines/flux/test_pipeline_flux_kontext.py @@ -0,0 +1,177 @@ +import unittest + +import numpy as np +import PIL.Image +import torch +from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel + +from diffusers import ( + AutoencoderKL, + FasterCacheConfig, + FlowMatchEulerDiscreteScheduler, + FluxKontextPipeline, + FluxTransformer2DModel, +) +from diffusers.utils.testing_utils import torch_device + +from ..test_pipelines_common import ( + FasterCacheTesterMixin, + FluxIPAdapterTesterMixin, + PipelineTesterMixin, + PyramidAttentionBroadcastTesterMixin, +) + + +class FluxKontextPipelineFastTests( + unittest.TestCase, + PipelineTesterMixin, + FluxIPAdapterTesterMixin, + PyramidAttentionBroadcastTesterMixin, + FasterCacheTesterMixin, +): + pipeline_class = FluxKontextPipeline + params = frozenset( + ["image", "prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"] + ) + batch_params = frozenset(["image", "prompt"]) + + # there is no xformers processor for Flux + test_xformers_attention = False + test_layerwise_casting = True + test_group_offloading = True + + faster_cache_config = FasterCacheConfig( + spatial_attention_block_skip_range=2, + spatial_attention_timestep_skip_range=(-1, 901), + unconditional_batch_skip_range=2, + attention_weight_callback=lambda _: 0.5, + is_guidance_distilled=True, + ) + + def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1): + torch.manual_seed(0) + transformer = FluxTransformer2DModel( + patch_size=1, + in_channels=4, + num_layers=num_layers, + num_single_layers=num_single_layers, + attention_head_dim=16, + num_attention_heads=2, + joint_attention_dim=32, + pooled_projection_dim=32, + axes_dims_rope=[4, 4, 8], + ) + clip_text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + ) + + torch.manual_seed(0) + text_encoder = CLIPTextModel(clip_text_encoder_config) + + torch.manual_seed(0) + text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + vae = AutoencoderKL( + sample_size=32, + in_channels=3, + out_channels=3, + block_out_channels=(4,), + layers_per_block=1, + latent_channels=1, + norm_num_groups=1, + use_quant_conv=False, + use_post_quant_conv=False, + shift_factor=0.0609, + scaling_factor=1.5035, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return { + "scheduler": scheduler, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "transformer": transformer, + "vae": vae, + "image_encoder": None, + "feature_extractor": None, + } + + 
def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + image = PIL.Image.new("RGB", (32, 32), 0) + inputs = { + "image": image, + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "height": 8, + "width": 8, + "max_area": 8 * 8, + "max_sequence_length": 48, + "output_type": "np", + "_auto_resize": False, + } + return inputs + + def test_flux_different_prompts(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + + inputs = self.get_dummy_inputs(torch_device) + output_same_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + inputs["prompt_2"] = "a different prompt" + output_different_prompts = pipe(**inputs).images[0] + + max_diff = np.abs(output_same_prompt - output_different_prompts).max() + + # Outputs should be different here + # For some reasons, they don't show large differences + assert max_diff > 1e-6 + + def test_flux_image_output_shape(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + height_width_pairs = [(32, 32), (72, 57)] + for height, width in height_width_pairs: + expected_height = height - height % (pipe.vae_scale_factor * 2) + expected_width = width - width % (pipe.vae_scale_factor * 2) + + inputs.update({"height": height, "width": width, "max_area": height * width}) + image = pipe(**inputs).images[0] + output_height, output_width, _ = image.shape + assert (output_height, output_width) == (expected_height, expected_width) + + def test_flux_true_cfg(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + inputs.pop("generator") + + no_true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0] + inputs["negative_prompt"] = "bad quality" + inputs["true_cfg_scale"] = 2.0 + true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0] + assert not np.allclose(no_true_cfg_out, true_cfg_out) From 00f95b9755718aabb65456e791b8408526ae6e76 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 26 Jun 2025 22:01:42 +0530 Subject: [PATCH 09/15] Kontext training (#11813) * support flux kontext * make fix-copies * add example * add tests * update docs * update * add note on integrity checker * initial commit * initial commit * add readme section and fixes in the training script. 
* add test

* rectify ckpt_id

* fix ckpt

* fixes

* change id

* update

* Update examples/dreambooth/train_dreambooth_lora_flux_kontext.py

Co-authored-by: Aryan

* Update examples/dreambooth/README_flux.md

---------

Co-authored-by: Aryan
Co-authored-by: linoytsaban
Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
---
 docs/source/en/api/pipelines/flux.md          |    2 +-
 examples/dreambooth/README_flux.md            |   46 +
 .../test_dreambooth_lora_flux_kontext.py      |  281 +++
 .../train_dreambooth_lora_flux_kontext.py     | 2069 +++++++++++++++++
 src/diffusers/training_utils.py               |   42 +
 5 files changed, 2439 insertions(+), 1 deletion(-)
 create mode 100644 examples/dreambooth/test_dreambooth_lora_flux_kontext.py
 create mode 100644 examples/dreambooth/train_dreambooth_lora_flux_kontext.py

diff --git a/docs/source/en/api/pipelines/flux.md b/docs/source/en/api/pipelines/flux.md
index d8a86a0692..1c41b85ca9 100644
--- a/docs/source/en/api/pipelines/flux.md
+++ b/docs/source/en/api/pipelines/flux.md
@@ -39,7 +39,7 @@ Flux comes in the following variants:
 | Canny Control (LoRA) | [`black-forest-labs/FLUX.1-Canny-dev-lora`](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev-lora) |
 | Depth Control (LoRA) | [`black-forest-labs/FLUX.1-Depth-dev-lora`](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev-lora) |
 | Redux (Adapter) | [`black-forest-labs/FLUX.1-Redux-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev) |
-| Kontext | [`black-forest-labs/FLUX.1-kontext`](https://huggingface.co/black-forest-labs/FLUX.1-kontext) |
+| Kontext | [`black-forest-labs/FLUX.1-Kontext-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev) |
 
 All checkpoints have different usage which we detail below.
 
diff --git a/examples/dreambooth/README_flux.md b/examples/dreambooth/README_flux.md
index a3704f2789..24c71d5c56 100644
--- a/examples/dreambooth/README_flux.md
+++ b/examples/dreambooth/README_flux.md
@@ -260,5 +260,51 @@ to enable `latent_caching` simply pass `--cache_latents`.
 By default, trained transformer layers are saved in the precision dtype in which training was performed. E.g. when training in mixed precision is enabled with `--mixed_precision="bf16"`, final finetuned layers will be saved in `torch.bfloat16` as well.
 This reduces memory requirements significantly w/o a significant quality loss. Note that if you do wish to save the final layers in float32 at the expanse of more memory usage, you can do so by passing `--upcast_before_saving`.
+## Training Kontext
+
+[Kontext](https://bfl.ai/announcements/flux-1-kontext) lets us perform image editing as well as image generation. Although it accepts both image and text inputs, it can also be used for text-to-image (T2I) generation. We
+provide a simple script, [train_dreambooth_lora_flux_kontext.py](./train_dreambooth_lora_flux_kontext.py), for LoRA fine-tuning Kontext on the T2I task. The optimizations discussed above apply to this script, too.
+
+Make sure to follow the [instructions to set up your environment](#running-locally-with-pytorch) before proceeding to the rest of the section.
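+
+As a quick sanity check before fine-tuning, you can exercise the base checkpoint with a minimal inference sketch like the one below. This is a rough sketch, not part of the training script: it assumes access to the gated `black-forest-labs/FLUX.1-Kontext-dev` weights and a CUDA device, and the input image path is a placeholder (any RGB image works).
+
+```py
+import torch
+from diffusers import FluxKontextPipeline
+from diffusers.utils import load_image
+
+pipe = FluxKontextPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
+).to("cuda")
+
+# Pure T2I: the pipeline also runs without an input image.
+t2i = pipe(prompt="a photo of sks dog", guidance_scale=2.5).images[0]
+
+# Editing: condition on an input image plus an instruction.
+init = load_image("path/to/your/image.png")  # placeholder path
+edited = pipe(
+    image=init, prompt="turn the photo into a watercolor painting", guidance_scale=2.5
+).images[0]
+edited.save("edited.png")
+```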
+
+Below is an example training command:
+
+```bash
+accelerate launch train_dreambooth_lora_flux_kontext.py \
+  --pretrained_model_name_or_path=black-forest-labs/FLUX.1-Kontext-dev \
+  --instance_data_dir="dog" \
+  --output_dir="kontext-dog" \
+  --mixed_precision="bf16" \
+  --instance_prompt="a photo of sks dog" \
+  --resolution=1024 \
+  --train_batch_size=1 \
+  --guidance_scale=1 \
+  --gradient_accumulation_steps=4 \
+  --gradient_checkpointing \
+  --optimizer="adamw" \
+  --use_8bit_adam \
+  --cache_latents \
+  --learning_rate=1e-4 \
+  --lr_scheduler="constant" \
+  --lr_warmup_steps=0 \
+  --max_train_steps=500 \
+  --seed="0"
+```
+
+Fine-tuning Kontext on the T2I task can be useful when working with specific styles/subjects where it may not
+perform as expected.
+
+### Misc notes
+
+* By default, we use `mode` as the value of the `--vae_encode_mode` argument. This is because Kontext uses the `mode()` of the distribution predicted by the VAE instead of sampling from it.
+
+### Aspect Ratio Bucketing
+
+We've added aspect ratio bucketing support, which allows training on images with different aspect ratios without cropping them to a single square resolution. This technique helps preserve the original composition of training images and can improve training efficiency.
+
+To enable aspect ratio bucketing, pass the `--aspect_ratio_buckets` argument with a semicolon-separated list of height,width pairs, such as:
+
+`--aspect_ratio_buckets="672,1568;688,1504;720,1456;752,1392;800,1328;832,1248;880,1184;944,1104;1024,1024;1104,944;1184,880;1248,832;1328,800;1392,752;1456,720;1504,688;1568,672"`
+
+Since Flux Kontext fine-tuning is still in an experimental phase, we encourage you to explore different settings and share your insights! 🤗
+
 ## Other notes
 Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️
\ No newline at end of file
diff --git a/examples/dreambooth/test_dreambooth_lora_flux_kontext.py b/examples/dreambooth/test_dreambooth_lora_flux_kontext.py
new file mode 100644
index 0000000000..c12fdd79ee
--- /dev/null
+++ b/examples/dreambooth/test_dreambooth_lora_flux_kontext.py
@@ -0,0 +1,281 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
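+
+# NOTE: these tests launch the training script end to end through the examples
+# test harness (an `accelerate launch` subprocess) and assert on the saved LoRA
+# artifacts (key naming, checkpoint rotation, adapter metadata) rather than on
+# image quality.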
+ +import json +import logging +import os +import sys +import tempfile + +import safetensors + +from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY + + +sys.path.append("..") +from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402 + + +logging.basicConfig(level=logging.DEBUG) + +logger = logging.getLogger() +stream_handler = logging.StreamHandler(sys.stdout) +logger.addHandler(stream_handler) + + +class DreamBoothLoRAFluxKontext(ExamplesTestsAccelerate): + instance_data_dir = "docs/source/en/imgs" + instance_prompt = "photo" + pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-kontext-pipe" + script_path = "examples/dreambooth/train_dreambooth_lora_flux_kontext.py" + transformer_layer_type = "single_transformer_blocks.0.attn.to_k" + + def test_dreambooth_lora_flux_kontext(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # when not training the text encoder, all the parameters in the state dict should start + # with `"transformer"` in their names. + starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) + self.assertTrue(starts_with_transformer) + + def test_dreambooth_lora_text_encoder_flux_kontext(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --train_text_encoder + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. 
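+            # "Correct naming" here means every key contains a peft-style `lora`
+            # marker (e.g. `lora_A` / `lora_B`) and is namespaced by the component
+            # it belongs to: `transformer.` or, since --train_text_encoder is set,
+            # `text_encoder.`.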
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + starts_with_expected_prefix = all( + (key.startswith("transformer") or key.startswith("text_encoder")) for key in lora_state_dict.keys() + ) + self.assertTrue(starts_with_expected_prefix) + + def test_dreambooth_lora_latent_caching(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --cache_latents + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # when not training the text encoder, all the parameters in the state dict should start + # with `"transformer"` in their names. + starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) + self.assertTrue(starts_with_transformer) + + def test_dreambooth_lora_layers(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --cache_latents + --learning_rate 5.0e-04 + --scale_lr + --lora_layers {self.transformer_layer_type} + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # when not training the text encoder, all the parameters in the state dict should start + # with `"transformer"` in their names. 
In this test, we only params of + # transformer.single_transformer_blocks.0.attn.to_k should be in the state dict + starts_with_transformer = all( + key.startswith("transformer.single_transformer_blocks.0.attn.to_k") for key in lora_state_dict.keys() + ) + self.assertTrue(starts_with_transformer) + + def test_dreambooth_lora_flux_kontext_checkpointing_checkpoints_total_limit(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path={self.pretrained_model_name_or_path} + --instance_data_dir={self.instance_data_dir} + --output_dir={tmpdir} + --instance_prompt={self.instance_prompt} + --resolution=64 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --max_train_steps=6 + --checkpoints_total_limit=2 + --checkpointing_steps=2 + """.split() + + run_command(self._launch_args + test_args) + + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-4", "checkpoint-6"}, + ) + + def test_dreambooth_lora_flux_kontext_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path={self.pretrained_model_name_or_path} + --instance_data_dir={self.instance_data_dir} + --output_dir={tmpdir} + --instance_prompt={self.instance_prompt} + --resolution=64 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --max_train_steps=4 + --checkpointing_steps=2 + """.split() + + run_command(self._launch_args + test_args) + + self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"}) + + resume_run_args = f""" + {self.script_path} + --pretrained_model_name_or_path={self.pretrained_model_name_or_path} + --instance_data_dir={self.instance_data_dir} + --output_dir={tmpdir} + --instance_prompt={self.instance_prompt} + --resolution=64 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --max_train_steps=8 + --checkpointing_steps=2 + --resume_from_checkpoint=checkpoint-4 + --checkpoints_total_limit=2 + """.split() + + run_command(self._launch_args + resume_run_args) + + self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"}) + + def test_dreambooth_lora_with_metadata(self): + # Use a `lora_alpha` that is different from `rank`. + lora_alpha = 8 + rank = 4 + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --lora_alpha={lora_alpha} + --rank={rank} + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors") + self.assertTrue(os.path.isfile(state_dict_file)) + + # Check if the metadata was properly serialized. 
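+            # safetensors metadata is a flat str -> str mapping; the trainer stores
+            # the LoRA config as a JSON string under LORA_ADAPTER_METADATA_KEY, so we
+            # parse it back and compare `transformer.r` / `transformer.lora_alpha`
+            # against the CLI values above.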
+ with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f: + metadata = f.metadata() or {} + + metadata.pop("format", None) + raw = metadata.get(LORA_ADAPTER_METADATA_KEY) + if raw: + raw = json.loads(raw) + + loaded_lora_alpha = raw["transformer.lora_alpha"] + self.assertTrue(loaded_lora_alpha == lora_alpha) + loaded_lora_rank = raw["transformer.r"] + self.assertTrue(loaded_lora_rank == rank) diff --git a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py new file mode 100644 index 0000000000..9f97567b06 --- /dev/null +++ b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py @@ -0,0 +1,2069 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import copy +import itertools +import logging +import math +import os +import random +import shutil +import warnings +from contextlib import nullcontext +from pathlib import Path + +import numpy as np +import torch +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed +from huggingface_hub import create_repo, upload_folder +from huggingface_hub.utils import insecure_hashlib +from peft import LoraConfig, set_peft_model_state_dict +from peft.utils import get_peft_model_state_dict +from PIL import Image +from PIL.ImageOps import exif_transpose +from torch.utils.data import Dataset +from torch.utils.data.sampler import BatchSampler +from torchvision import transforms +from torchvision.transforms.functional import crop +from tqdm.auto import tqdm +from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast + +import diffusers +from diffusers import ( + AutoencoderKL, + FlowMatchEulerDiscreteScheduler, + FluxKontextPipeline, + FluxTransformer2DModel, +) +from diffusers.optimization import get_scheduler +from diffusers.training_utils import ( + _collate_lora_metadata, + _set_state_dict_into_text_encoder, + cast_training_params, + compute_density_for_timestep_sampling, + compute_loss_weighting_for_sd3, + find_nearest_bucket, + free_memory, + parse_buckets_string, +) +from diffusers.utils import ( + check_min_version, + convert_unet_state_dict_to_peft, + is_wandb_available, +) +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.import_utils import is_torch_npu_available +from diffusers.utils.torch_utils import is_compiled_module + + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
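+# (check_min_version compares the installed `diffusers` version against the one
+# pinned below and raises an ImportError if it is older.)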
+check_min_version("0.34.0.dev0") + +logger = get_logger(__name__) + +if is_torch_npu_available(): + torch.npu.config.allow_internal_format = False + + +def save_model_card( + repo_id: str, + images=None, + base_model: str = None, + train_text_encoder=False, + instance_prompt=None, + validation_prompt=None, + repo_folder=None, +): + widget_dict = [] + if images is not None: + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + widget_dict.append( + {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}} + ) + + model_description = f""" +# Flux Kontext DreamBooth LoRA - {repo_id} + + + +## Model description + +These are {repo_id} DreamBooth LoRA weights for {base_model}. + +The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux.md). + +Was LoRA for the text encoder enabled? {train_text_encoder}. + +## Trigger words + +You should use `{instance_prompt}` to trigger the image generation. + +## Download model + +[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab. + +## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers) + +```py +from diffusers import FluxKontextPipeline +import torch +pipeline = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to('cuda') +pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors') +image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0] +``` + +For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) + +## License + +Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md). +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="other", + base_model=base_model, + prompt=instance_prompt, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "text-to-image", + "diffusers-training", + "diffusers", + "lora", + "flux", + "flux-kontextflux-diffusers", + "template:sd-lora", + ] + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(repo_folder, "README.md")) + + +def load_text_encoders(class_one, class_two): + text_encoder_one = class_one.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + text_encoder_two = class_two.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant + ) + return text_encoder_one, text_encoder_two + + +def log_validation( + pipeline, + args, + accelerator, + pipeline_args, + epoch, + torch_dtype, + is_final_validation=False, +): + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." 
+ ) + pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None + autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() + + # pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast + with torch.no_grad(): + prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt( + pipeline_args["prompt"], prompt_2=pipeline_args["prompt"] + ) + images = [] + for _ in range(args.num_validation_images): + with autocast_ctx: + image = pipeline( + prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator + ).images[0] + images.append(image) + + for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else "validation" + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + phase_name: [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + ] + } + ) + + del pipeline + free_memory() + + return images + + +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" +): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, revision=revision + ) + model_class = text_encoder_config.architectures[0] + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "T5EncoderModel": + from transformers import T5EncoderModel + + return T5EncoderModel + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--vae_encode_mode", + type=str, + default="mode", + choices=["sample", "mode"], + help="VAE encoding mode.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + help=("A folder containing the training data. 
"), + ) + + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + + parser.add_argument( + "--image_column", + type=str, + default="image", + help="The column of the dataset containing the target image. By " + "default, the standard Image Dataset maps out 'file_name' " + "to 'image'.", + ) + parser.add_argument( + "--caption_column", + type=str, + default=None, + help="The column of the dataset containing the instance prompt for each image", + ) + + parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") + + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--max_sequence_length", + type=int, + default=512, + help="Maximum sequence length to use with with the T5 text encoder", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=50, + help=( + "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--rank", + type=int, + default=4, + help=("The dimension of the LoRA update matrices."), + ) + parser.add_argument( + "--lora_alpha", + type=int, + default=4, + help="LoRA alpha to be used for additional scaling.", + ) + parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers") + + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="flux-dreambooth-lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--aspect_ratio_buckets", + type=str, + default=None, + help=( + "Aspect ratio buckets to use for training. Define as a string of 'h1,w1;h2,w2;...'. " + "e.g. 
'1024,1024;768,1360;1360,768;880,1168;1168,880;1248,832;832,1248'" + "Images will be resized and cropped to fit the nearest bucket. If provided, --resolution is ignored." + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_text_encoder", + action="store_true", + help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + + parser.add_argument( + "--guidance_scale", + type=float, + default=3.5, + help="the FLUX.1 dev variant is a guidance distilled model", + ) + + parser.add_argument( + "--text_encoder_lr", + type=float, + default=5e-6, + help="Text encoder learning rate to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 
+ ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--weighting_scheme", + type=str, + default="none", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + parser.add_argument( + "--optimizer", + type=str, + default="AdamW", + help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'), + ) + + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", + ) + + parser.add_argument( + "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--prodigy_beta3", + type=float, + default=None, + help="coefficients for computing the Prodigy stepsize using running averages. If set to None, " + "uses the value of square root of beta2. Ignored if optimizer is adamW", + ) + parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") + parser.add_argument( + "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" + ) + + parser.add_argument( + "--lora_layers", + type=str, + default=None, + help=( + 'The transformer modules to apply LoRA training on. Please specify the layers in a comma separated. E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only' + ), + ) + + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) + + parser.add_argument( + "--prodigy_use_bias_correction", + type=bool, + default=True, + help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW", + ) + parser.add_argument( + "--prodigy_safeguard_warmup", + type=bool, + default=True, + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. 
" + "Ignored if optimizer is adamW", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--cache_latents", + action="store_true", + default=False, + help="Cache the VAE latents", + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--upcast_before_saving", + action="store_true", + default=False, + help=( + "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). " + "Defaults to precision dtype used for training to save memory" + ), + ) + parser.add_argument( + "--prior_generation_precision", + type=str, + default=None, + choices=["no", "fp32", "fp16", "bf16"], + help=( + "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." 
+ ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if args.dataset_name is None and args.instance_data_dir is None: + raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`") + + if args.dataset_name is not None and args.instance_data_dir is not None: + raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + else: + # logger is not available yet + if args.class_data_dir is not None: + warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + + return args + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + class_prompt, + class_data_root=None, + class_num=None, + repeats=1, + center_crop=False, + buckets=None, + ): + self.center_crop = center_crop + + self.instance_prompt = instance_prompt + self.custom_instance_prompts = None + self.class_prompt = class_prompt + + self.buckets = buckets + + # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory, + # we load the training data using load_dataset + if args.dataset_name is not None: + try: + from datasets import load_dataset + except ImportError: + raise ImportError( + "You are trying to load your data using the datasets library. If you wish to train using custom " + "captions please install the datasets library: `pip install datasets`. If you wish to load a " + "local folder containing images only, specify --instance_data_dir instead." + ) + # Downloading and loading a dataset from the hub. + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + # Preprocessing the datasets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + instance_images = dataset["train"][image_column] + + if args.caption_column is None: + logger.info( + "No caption column provided, defaulting to instance_prompt for all images. 
If your dataset " + "contains captions/prompts for the images, make sure to specify the " + "column as --caption_column" + ) + self.custom_instance_prompts = None + else: + if args.caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + custom_instance_prompts = dataset["train"][args.caption_column] + # create final list of captions according to --repeats + self.custom_instance_prompts = [] + for caption in custom_instance_prompts: + self.custom_instance_prompts.extend(itertools.repeat(caption, repeats)) + else: + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] + self.custom_instance_prompts = None + + self.instance_images = [] + for img in instance_images: + self.instance_images.extend(itertools.repeat(img, repeats)) + + self.pixel_values = [] + for image in self.instance_images: + image = exif_transpose(image) + if not image.mode == "RGB": + image = image.convert("RGB") + + width, height = image.size + + # Find the closest bucket + bucket_idx = find_nearest_bucket(height, width, self.buckets) + target_height, target_width = self.buckets[bucket_idx] + self.size = (target_height, target_width) + + # based on the bucket assignment, define the transformations + train_resize = transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR) + train_crop = transforms.CenterCrop(self.size) if center_crop else transforms.RandomCrop(self.size) + train_flip = transforms.RandomHorizontalFlip(p=1.0) + train_transforms = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + image = train_resize(image) + if args.center_crop: + image = train_crop(image) + else: + y1, x1, h, w = train_crop.get_params(image, self.size) + image = crop(image, y1, x1, h, w) + if args.random_flip and random.random() < 0.5: + image = train_flip(image) + image = train_transforms(image) + self.pixel_values.append((image, bucket_idx)) + + self.num_instance_images = len(self.instance_images) + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + if class_num is not None: + self.num_class_images = min(len(self.class_images_path), class_num) + else: + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(self.size) if center_crop else transforms.RandomCrop(self.size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images] + example["instance_images"] = instance_image + example["bucket_idx"] = bucket_idx + + if self.custom_instance_prompts: + caption = self.custom_instance_prompts[index % self.num_instance_images] + if caption: + example["instance_prompt"] = caption + else: + example["instance_prompt"] 
= self.instance_prompt + + else: # custom prompts were provided, but length does not match size of image dataset + example["instance_prompt"] = self.instance_prompt + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + class_image = exif_transpose(class_image) + + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_prompt"] = self.class_prompt + + return example + + +def collate_fn(examples, with_prior_preservation=False): + pixel_values = [example["instance_images"] for example in examples] + prompts = [example["instance_prompt"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if with_prior_preservation: + pixel_values += [example["class_images"] for example in examples] + prompts += [example["class_prompt"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + batch = {"pixel_values": pixel_values, "prompts": prompts} + return batch + + +class BucketBatchSampler(BatchSampler): + def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False): + if not isinstance(batch_size, int) or batch_size <= 0: + raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size)) + if not isinstance(drop_last, bool): + raise ValueError("drop_last should be a boolean value, but got drop_last={}".format(drop_last)) + + self.dataset = dataset + self.batch_size = batch_size + self.drop_last = drop_last + + # Group indices by bucket + self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))] + for idx, (_, bucket_idx) in enumerate(self.dataset.pixel_values): + self.bucket_indices[bucket_idx].append(idx) + + self.sampler_len = 0 + self.batches = [] + + # Pre-generate batches for each bucket + for indices_in_bucket in self.bucket_indices: + # Shuffle indices within the bucket + random.shuffle(indices_in_bucket) + # Create batches + for i in range(0, len(indices_in_bucket), self.batch_size): + batch = indices_in_bucket[i : i + self.batch_size] + if len(batch) < self.batch_size and self.drop_last: + continue # Skip partial batch if drop_last is True + self.batches.append(batch) + self.sampler_len += 1 # Count the number of batches + + def __iter__(self): + # Shuffle the order of the batches each epoch + random.shuffle(self.batches) + for batch in self.batches: + yield batch + + def __len__(self): + return self.sampler_len + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." 
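+    # Only used when --with_prior_preservation has to synthesize class images:
+    # each item pairs the class prompt with its index so the generated files can
+    # be given unique names (see the sampling loop in main()).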
+ + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def tokenize_prompt(tokenizer, prompt, max_sequence_length): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + return text_input_ids + + +def _encode_prompt_with_t5( + text_encoder, + tokenizer, + max_sequence_length=512, + prompt=None, + num_images_per_prompt=1, + device=None, + text_input_ids=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if tokenizer is not None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + else: + if text_input_ids is None: + raise ValueError("text_input_ids must be provided when the tokenizer is not specified") + + prompt_embeds = text_encoder(text_input_ids.to(device))[0] + + if hasattr(text_encoder, "module"): + dtype = text_encoder.module.dtype + else: + dtype = text_encoder.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + +def _encode_prompt_with_clip( + text_encoder, + tokenizer, + prompt: str, + device=None, + text_input_ids=None, + num_images_per_prompt: int = 1, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if tokenizer is not None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=77, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + else: + if text_input_ids is None: + raise ValueError("text_input_ids must be provided when the tokenizer is not specified") + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False) + + if hasattr(text_encoder, "module"): + dtype = text_encoder.module.dtype + else: + dtype = text_encoder.dtype + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + +def encode_prompt( + text_encoders, + tokenizers, + prompt: str, + max_sequence_length, + device=None, + num_images_per_prompt: int = 1, + text_input_ids_list=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + + if hasattr(text_encoders[0], "module"): + dtype = text_encoders[0].module.dtype + else: + dtype = text_encoders[0].dtype + + pooled_prompt_embeds = _encode_prompt_with_clip( + text_encoder=text_encoders[0], + tokenizer=tokenizers[0], + prompt=prompt, + 
device=device if device is not None else text_encoders[0].device, + num_images_per_prompt=num_images_per_prompt, + text_input_ids=text_input_ids_list[0] if text_input_ids_list else None, + ) + + prompt_embeds = _encode_prompt_with_t5( + text_encoder=text_encoders[1], + tokenizer=tokenizers[1], + max_sequence_length=max_sequence_length, + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + device=device if device is not None else text_encoders[1].device, + text_input_ids=text_input_ids_list[1] if text_input_ids_list else None, + ) + + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + kwargs_handlers=[kwargs], + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Generate class images if prior preservation is enabled. 
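+    # If --class_data_dir holds fewer than --num_class_images files, the base
+    # pipeline is loaded once to generate the shortfall from --class_prompt and
+    # is freed immediately afterwards; these images become the regularization
+    # targets for the prior-preservation loss.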
+ if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available() + torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32 + if args.prior_generation_precision == "fp32": + torch_dtype = torch.float32 + elif args.prior_generation_precision == "fp16": + torch_dtype = torch.float16 + elif args.prior_generation_precision == "bf16": + torch_dtype = torch.bfloat16 + + transformer = FluxTransformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + revision=args.revision, + variant=args.variant, + ) + pipeline = FluxKontextPipeline.from_pretrained( + args.pretrained_model_name_or_path, + transformer=transformer, + torch_dtype=torch_dtype, + revision=args.revision, + variant=args.variant, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + free_memory() + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + ).repo_id + + # Load the tokenizers + tokenizer_one = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + ) + tokenizer_two = T5TokenizerFast.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer_2", + revision=args.revision, + ) + + # import correct text encoder classes + text_encoder_cls_one = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision + ) + text_encoder_cls_two = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" + ) + + # Load scheduler and models + noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + args.pretrained_model_name_or_path, subfolder="scheduler" + ) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, + ) + transformer = FluxTransformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant + ) + + # We only train the additional adapter 
LoRA layers + transformer.requires_grad_(False) + vae.requires_grad_(False) + text_encoder_one.requires_grad_(False) + text_encoder_two.requires_grad_(False) + + # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + vae.to(accelerator.device, dtype=weight_dtype) + transformer.to(accelerator.device, dtype=weight_dtype) + text_encoder_one.to(accelerator.device, dtype=weight_dtype) + text_encoder_two.to(accelerator.device, dtype=weight_dtype) + + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + if args.train_text_encoder: + text_encoder_one.gradient_checkpointing_enable() + + if args.lora_layers is not None: + target_modules = [layer.strip() for layer in args.lora_layers.split(",")] + else: + target_modules = [ + "attn.to_k", + "attn.to_q", + "attn.to_v", + "attn.to_out.0", + "attn.add_k_proj", + "attn.add_q_proj", + "attn.add_v_proj", + "attn.to_add_out", + "ff.net.0.proj", + "ff.net.2", + "ff_context.net.0.proj", + "ff_context.net.2", + ] + + # now we will add new LoRA weights the transformer layers + transformer_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + init_lora_weights="gaussian", + target_modules=target_modules, + ) + transformer.add_adapter(transformer_lora_config) + if args.train_text_encoder: + text_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + init_lora_weights="gaussian", + target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], + ) + text_encoder_one.add_adapter(text_lora_config) + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + transformer_lora_layers_to_save = None + text_encoder_one_lora_layers_to_save = None + modules_to_save = {} + for model in models: + if isinstance(model, type(unwrap_model(transformer))): + transformer_lora_layers_to_save = get_peft_model_state_dict(model) + modules_to_save["transformer"] = model + elif isinstance(model, type(unwrap_model(text_encoder_one))): + text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) + modules_to_save["text_encoder"] = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + FluxKontextPipeline.save_lora_weights( + output_dir, + transformer_lora_layers=transformer_lora_layers_to_save, + text_encoder_lora_layers=text_encoder_one_lora_layers_to_save, + **_collate_lora_metadata(modules_to_save), + ) + + def load_model_hook(models, input_dir): + transformer_ = None + text_encoder_one_ = None + + while len(models) > 0: + model = 
models.pop()
+
+            if isinstance(model, type(unwrap_model(transformer))):
+                transformer_ = model
+            elif isinstance(model, type(unwrap_model(text_encoder_one))):
+                text_encoder_one_ = model
+            else:
+                raise ValueError(f"unexpected save model: {model.__class__}")
+
+        lora_state_dict = FluxKontextPipeline.lora_state_dict(input_dir)
+
+        transformer_state_dict = {
+            f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+        }
+        transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
+        incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )
+        if args.train_text_encoder:
+            # Do we need to call `scale_lora_layers()` here?
+            _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
+
+        # Make sure the trainable params are in float32. This is again needed since the base models
+        # are in `weight_dtype`. More details:
+        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+        if args.mixed_precision == "fp16":
+            models = [transformer_]
+            if args.train_text_encoder:
+                models.extend([text_encoder_one_])
+            # only upcast trainable parameters (LoRA) into fp32
+            cast_training_params(models)
+
+    accelerator.register_save_state_pre_hook(save_model_hook)
+    accelerator.register_load_state_pre_hook(load_model_hook)
+
+    # Enable TF32 for faster training on Ampere GPUs,
+    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+    if args.allow_tf32 and torch.cuda.is_available():
+        torch.backends.cuda.matmul.allow_tf32 = True
+
+    if args.scale_lr:
+        args.learning_rate = (
+            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+        )
+
+    # Make sure the trainable params are in float32.
+    if args.mixed_precision == "fp16":
+        models = [transformer]
+        if args.train_text_encoder:
+            models.extend([text_encoder_one])
+        # only upcast trainable parameters (LoRA) into fp32
+        cast_training_params(models, dtype=torch.float32)
+
+    transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
+    if args.train_text_encoder:
+        text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
+
+    # Optimization parameters
+    transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
+    if args.train_text_encoder:
+        # different learning rate for text encoder and transformer
+        text_parameters_one_with_lr = {
+            "params": text_lora_parameters_one,
+            "weight_decay": args.adam_weight_decay_text_encoder,
+            "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
+        }
+        params_to_optimize = [transformer_parameters_with_lr, text_parameters_one_with_lr]
+    else:
+        params_to_optimize = [transformer_parameters_with_lr]
+
+    # Optimizer creation
+    if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
+        logger.warning(
+            f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW, prodigy]. "
+ "Defaulting to adamW" + ) + args.optimizer = "adamw" + + if args.use_8bit_adam and not args.optimizer.lower() == "adamw": + logger.warning( + f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " + f"set to {args.optimizer.lower()}" + ) + + if args.optimizer.lower() == "adamw": + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + if args.optimizer.lower() == "prodigy": + try: + import prodigyopt + except ImportError: + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + + optimizer_class = prodigyopt.Prodigy + + if args.learning_rate <= 0.1: + logger.warning( + "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" + ) + if args.train_text_encoder and args.text_encoder_lr: + logger.warning( + f"Learning rates were provided both for the transformer and the text encoder- e.g. text_encoder_lr:" + f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. " + f"When using prodigy only learning_rate is used as the initial learning rate." + ) + # changes the learning rate of text_encoder_parameters_one to be + # --learning_rate + params_to_optimize[1]["lr"] = args.learning_rate + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + beta3=args.prodigy_beta3, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=args.prodigy_decouple, + use_bias_correction=args.prodigy_use_bias_correction, + safeguard_warmup=args.prodigy_safeguard_warmup, + ) + + if args.aspect_ratio_buckets is not None: + buckets = parse_buckets_string(args.aspect_ratio_buckets) + else: + buckets = [(args.resolution, args.resolution)] + logger.info(f"Using parsed aspect ratio buckets: {buckets}") + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_prompt=args.class_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_num=args.num_class_images, + buckets=buckets, + repeats=args.repeats, + center_crop=args.center_crop, + ) + batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=False) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_sampler=batch_sampler, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=args.dataloader_num_workers, + ) + + if not args.train_text_encoder: + tokenizers = [tokenizer_one, tokenizer_two] + text_encoders = [text_encoder_one, text_encoder_two] + + def compute_text_embeddings(prompt, text_encoders, tokenizers): + with torch.no_grad(): + prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( + text_encoders, tokenizers, prompt, args.max_sequence_length + ) + prompt_embeds = prompt_embeds.to(accelerator.device) + pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) + text_ids = text_ids.to(accelerator.device) + return prompt_embeds, pooled_prompt_embeds, text_ids + + # If no type of tuning is done on the text_encoder and 
custom instance prompts are NOT + # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid + # the redundant encoding. + if not args.train_text_encoder and not train_dataset.custom_instance_prompts: + instance_prompt_hidden_states, instance_pooled_prompt_embeds, instance_text_ids = compute_text_embeddings( + args.instance_prompt, text_encoders, tokenizers + ) + + # Handle class prompt for prior-preservation. + if args.with_prior_preservation: + if not args.train_text_encoder: + class_prompt_hidden_states, class_pooled_prompt_embeds, class_text_ids = compute_text_embeddings( + args.class_prompt, text_encoders, tokenizers + ) + + # Clear the memory here + if not args.train_text_encoder and not train_dataset.custom_instance_prompts: + del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two + free_memory() + + # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), + # pack the statically computed variables appropriately here. This is so that we don't + # have to pass them to the dataloader. + + if not train_dataset.custom_instance_prompts: + if not args.train_text_encoder: + prompt_embeds = instance_prompt_hidden_states + pooled_prompt_embeds = instance_pooled_prompt_embeds + text_ids = instance_text_ids + if args.with_prior_preservation: + prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) + pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0) + text_ids = torch.cat([text_ids, class_text_ids], dim=0) + # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) + # we need to tokenize and encode the batch prompts on all training steps + else: + tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt, max_sequence_length=77) + tokens_two = tokenize_prompt( + tokenizer_two, args.instance_prompt, max_sequence_length=args.max_sequence_length + ) + if args.with_prior_preservation: + class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt, max_sequence_length=77) + class_tokens_two = tokenize_prompt( + tokenizer_two, args.class_prompt, max_sequence_length=args.max_sequence_length + ) + tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) + tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) + + vae_config_shift_factor = vae.config.shift_factor + vae_config_scaling_factor = vae.config.scaling_factor + vae_config_block_out_channels = vae.config.block_out_channels + if args.cache_latents: + latents_cache = [] + for batch in tqdm(train_dataloader, desc="Caching latents"): + with torch.no_grad(): + batch["pixel_values"] = batch["pixel_values"].to( + accelerator.device, non_blocking=True, dtype=weight_dtype + ) + latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) + + if args.validation_prompt is None: + del vae + free_memory() + + # Scheduler and math around the number of training steps. + # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation. 
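+    # Worked example with hypothetical numbers: with 100 batches, 2 processes and
+    # gradient_accumulation_steps=4, each process sees ceil(100 / 2) = 50 batches and performs
+    # ceil(50 / 4) = 13 optimizer steps per epoch. The scheduler is stepped on every process,
+    # which is why the accelerator.num_processes factors appear below.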
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
+    if args.max_train_steps is None:
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
+
+    lr_scheduler = get_scheduler(
+        args.lr_scheduler,
+        optimizer=optimizer,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
+        num_cycles=args.lr_num_cycles,
+        power=args.lr_power,
+    )
+
+    # Prepare everything with our `accelerator`.
+    if args.train_text_encoder:
+        (
+            transformer,
+            text_encoder_one,
+            optimizer,
+            train_dataloader,
+            lr_scheduler,
+        ) = accelerator.prepare(
+            transformer,
+            text_encoder_one,
+            optimizer,
+            train_dataloader,
+            lr_scheduler,
+        )
+    else:
+        transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            transformer, optimizer, train_dataloader, lr_scheduler
+        )
+
+    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if args.max_train_steps is None:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
+    # Afterwards we recalculate our number of training epochs
+    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+    # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initialize automatically on the main process.
+    if accelerator.is_main_process:
+        tracker_name = "dreambooth-flux-dev-lora"
+        accelerator.init_trackers(tracker_name, config=vars(args))
+
+    # Train!
+    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+    logger.info("***** Running training *****")
+    logger.info(f"  Num examples = {len(train_dataset)}")
+    logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
+    logger.info(f"  Num Epochs = {args.num_train_epochs}")
+    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
+    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+    logger.info(f"  Total optimization steps = {args.max_train_steps}")
+    global_step = 0
+    first_epoch = 0
+
+    # Potentially load in the weights and states from a previous save
+    if args.resume_from_checkpoint:
+        if args.resume_from_checkpoint != "latest":
+            path = os.path.basename(args.resume_from_checkpoint)
+        else:
+            # Get the most recent checkpoint
+            dirs = os.listdir(args.output_dir)
+            dirs = [d for d in dirs if d.startswith("checkpoint")]
+            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+            path = dirs[-1] if len(dirs) > 0 else None
+
+        if path is None:
+            accelerator.print(
+                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+            )
+            args.resume_from_checkpoint = None
+            initial_global_step = 0
+        else:
+            accelerator.print(f"Resuming from checkpoint {path}")
+            accelerator.load_state(os.path.join(args.output_dir, path))
+            global_step = int(path.split("-")[1])
+
+            initial_global_step = global_step
+            first_epoch = global_step // num_update_steps_per_epoch
+
+    else:
+        initial_global_step = 0
+
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
+        # Only show the progress bar once on each machine.
+        disable=not accelerator.is_local_main_process,
+    )
+
+    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
+        sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
+        schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
+        timesteps = timesteps.to(accelerator.device)
+        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+        sigma = sigmas[step_indices].flatten()
+        while len(sigma.shape) < n_dim:
+            sigma = sigma.unsqueeze(-1)
+        return sigma
+
+    for epoch in range(first_epoch, args.num_train_epochs):
+        transformer.train()
+        if args.train_text_encoder:
+            text_encoder_one.train()
+            # set top parameter requires_grad = True so that gradient checkpointing works
+            unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
+
+        for step, batch in enumerate(train_dataloader):
+            models_to_accumulate = [transformer]
+            if args.train_text_encoder:
+                models_to_accumulate.extend([text_encoder_one])
+            with accelerator.accumulate(models_to_accumulate):
+                prompts = batch["prompts"]
+
+                # encode batch prompts when custom prompts are provided for each image
+                if train_dataset.custom_instance_prompts:
+                    if not args.train_text_encoder:
+                        prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
+                            prompts, text_encoders, tokenizers
+                        )
+                    else:
+                        tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77)
+                        tokens_two = tokenize_prompt(
+                            tokenizer_two, prompts, max_sequence_length=args.max_sequence_length
+                        )
+                        prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
+                            text_encoders=[text_encoder_one, text_encoder_two],
+                            tokenizers=[None, None],
+                            text_input_ids_list=[tokens_one, tokens_two],
+                            max_sequence_length=args.max_sequence_length,
+                            device=accelerator.device,
+                            prompt=prompts,
+                        )
+                else:
+                    elems_to_repeat = len(prompts)
+                    if args.train_text_encoder:
+                        prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
+                            text_encoders=[text_encoder_one, text_encoder_two],
+                            tokenizers=[None, None],
+                            text_input_ids_list=[
+                                tokens_one.repeat(elems_to_repeat, 1),
+                                tokens_two.repeat(elems_to_repeat, 1),
+                            ],
+                            max_sequence_length=args.max_sequence_length,
+ device=accelerator.device, + prompt=args.instance_prompt, + ) + + # Convert images to latent space + if args.cache_latents: + if args.vae_encode_mode == "sample": + model_input = latents_cache[step].sample() + else: + model_input = latents_cache[step].mode() + else: + pixel_values = batch["pixel_values"].to(dtype=vae.dtype) + if args.vae_encode_mode == "sample": + model_input = vae.encode(pixel_values).latent_dist.sample() + else: + model_input = vae.encode(pixel_values).latent_dist.mode() + model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor + model_input = model_input.to(dtype=weight_dtype) + + vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1) + + latent_image_ids = FluxKontextPipeline._prepare_latent_image_ids( + model_input.shape[0], + model_input.shape[2] // 2, + model_input.shape[3] // 2, + accelerator.device, + weight_dtype, + ) + # Sample noise that we'll add to the latents + noise = torch.randn_like(model_input) + bsz = model_input.shape[0] + + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() + timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) + + # Add noise according to flow matching. + # zt = (1 - texp) * x + texp * z1 + sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) + noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise + + packed_noisy_model_input = FluxKontextPipeline._pack_latents( + noisy_model_input, + batch_size=model_input.shape[0], + num_channels_latents=model_input.shape[1], + height=model_input.shape[2], + width=model_input.shape[3], + ) + + # handle guidance + if unwrap_model(transformer).config.guidance_embeds: + guidance = torch.tensor([args.guidance_scale], device=accelerator.device) + guidance = guidance.expand(model_input.shape[0]) + else: + guidance = None + + # Predict the noise residual + model_pred = transformer( + hidden_states=packed_noisy_model_input, + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + timestep=timesteps / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + model_pred = FluxKontextPipeline._unpack_latents( + model_pred, + height=model_input.shape[2] * vae_scale_factor, + width=model_input.shape[3] * vae_scale_factor, + vae_scale_factor=vae_scale_factor, + ) + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # flow matching loss + target = noise - model_input + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. 
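+                    # (the collate_fn builds batches with the instance examples in the first half
+                    # and the class examples in the second half, so chunking along dim=0 separates them)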
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute prior loss + prior_loss = torch.mean( + (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( + target_prior.shape[0], -1 + ), + 1, + ) + prior_loss = prior_loss.mean() + + # Compute regular loss. + loss = torch.mean( + (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + 1, + ) + loss = loss.mean() + + if args.with_prior_preservation: + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = ( + itertools.chain(transformer.parameters(), text_encoder_one.parameters()) + if args.train_text_encoder + else transformer.parameters() + ) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + # create pipeline + if not args.train_text_encoder: + text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) + text_encoder_one.to(weight_dtype) + text_encoder_two.to(weight_dtype) + pipeline = FluxKontextPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + text_encoder=unwrap_model(text_encoder_one), + text_encoder_2=unwrap_model(text_encoder_two), + transformer=unwrap_model(transformer), + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipeline_args = {"prompt": args.validation_prompt} + images = log_validation( + pipeline=pipeline, + args=args, + accelerator=accelerator, + pipeline_args=pipeline_args, + epoch=epoch, + torch_dtype=weight_dtype, + ) + if not args.train_text_encoder: + del text_encoder_one, text_encoder_two + 
free_memory() + + images = None + free_memory() + + # Save the lora layers + accelerator.wait_for_everyone() + if accelerator.is_main_process: + modules_to_save = {} + transformer = unwrap_model(transformer) + if args.upcast_before_saving: + transformer.to(torch.float32) + else: + transformer = transformer.to(weight_dtype) + transformer_lora_layers = get_peft_model_state_dict(transformer) + modules_to_save["transformer"] = transformer + + if args.train_text_encoder: + text_encoder_one = unwrap_model(text_encoder_one) + text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32)) + modules_to_save["text_encoder"] = text_encoder_one + else: + text_encoder_lora_layers = None + + FluxKontextPipeline.save_lora_weights( + save_directory=args.output_dir, + transformer_lora_layers=transformer_lora_layers, + text_encoder_lora_layers=text_encoder_lora_layers, + **_collate_lora_metadata(modules_to_save), + ) + + # Final inference + # Load previous pipeline + transformer = FluxTransformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant + ) + pipeline = FluxKontextPipeline.from_pretrained( + args.pretrained_model_name_or_path, + transformer=transformer, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + # load attention processors + pipeline.load_lora_weights(args.output_dir) + + # run inference + images = [] + if args.validation_prompt and args.num_validation_images > 0: + pipeline_args = {"prompt": args.validation_prompt} + images = log_validation( + pipeline=pipeline, + args=args, + accelerator=accelerator, + pipeline_args=pipeline_args, + epoch=epoch, + is_final_validation=True, + torch_dtype=weight_dtype, + ) + del pipeline + free_memory() + + if args.push_to_hub: + save_model_card( + repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + train_text_encoder=args.train_text_encoder, + instance_prompt=args.instance_prompt, + validation_prompt=args.validation_prompt, + repo_folder=args.output_dir, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + images = None + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index bc30411d87..755ff81883 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -3,6 +3,8 @@ import copy import gc import math import random +import re +import warnings from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import numpy as np @@ -316,6 +318,46 @@ def free_memory(): torch.xpu.empty_cache() +def parse_buckets_string(buckets_str): + """Parses a string defining buckets into a list of (height, width) tuples.""" + if not buckets_str: + raise ValueError("Bucket string cannot be empty.") + + bucket_pairs = buckets_str.strip().split(";") + parsed_buckets = [] + for pair_str in bucket_pairs: + match = re.match(r"^\s*(\d+)\s*,\s*(\d+)\s*$", pair_str) + if not match: + raise ValueError(f"Invalid bucket format: '{pair_str}'. Expected 'height,width'.") + try: + height = int(match.group(1)) + width = int(match.group(2)) + if height <= 0 or width <= 0: + raise ValueError("Bucket dimensions must be positive integers.") + if height % 8 != 0 or width % 8 != 0: + warnings.warn(f"Bucket dimension ({height},{width}) not divisible by 8. 
This might cause issues.")
+            parsed_buckets.append((height, width))
+        except ValueError as e:
+            raise ValueError(f"Invalid integer in bucket pair '{pair_str}': {e}") from e
+
+    if not parsed_buckets:
+        raise ValueError("No valid buckets found in the provided string.")
+
+    return parsed_buckets
+
+
+def find_nearest_bucket(h, w, bucket_options):
+    """Finds the closest bucket to the given height and width."""
+    min_metric = float("inf")
+    best_bucket_idx = None
+    for bucket_idx, (bucket_h, bucket_w) in enumerate(bucket_options):
+        metric = abs(h * bucket_w - w * bucket_h)
+        if metric <= min_metric:
+            min_metric = metric
+            best_bucket_idx = bucket_idx
+    return best_bucket_idx
+
+
 # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
 class EMAModel:
     """

From d7dd924ece56cddf261cd8b9dd901cbfa594c62c Mon Sep 17 00:00:00 2001
From: Aryan
Date: Fri, 27 Jun 2025 04:33:44 +0530
Subject: [PATCH 10/15] Kontext fixes (#11815)

fix
---
 docs/source/en/api/pipelines/flux.md                  | 4 ++--
 src/diffusers/pipelines/flux/pipeline_flux_kontext.py | 6 ++++--
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/docs/source/en/api/pipelines/flux.md b/docs/source/en/api/pipelines/flux.md
index 1c41b85ca9..ca39d71814 100644
--- a/docs/source/en/api/pipelines/flux.md
+++ b/docs/source/en/api/pipelines/flux.md
@@ -302,11 +302,11 @@ image.save("flux-kontext.png")
 Flux Kontext comes with an integrity safety checker, which should be run after the image generation step. To run the safety checker, install the official repository from [black-forest-labs/flux](https://github.com/black-forest-labs/flux) and add the following code:
 
 ```python
-from flux.safety import PixtralIntegrity
+from flux.content_filters import PixtralContentFilter
 
 # ... pipeline invocation to generate images
 
-integrity_checker = PixtralIntegrity(torch.device("cuda"))
+integrity_checker = PixtralContentFilter(torch.device("cuda"))
 image_ = np.array(image) / 255.0
 image_ = 2 * image_ - 1
 image_ = torch.from_numpy(image_).to("cuda", dtype=torch.float32).unsqueeze(0).permute(0, 3, 1, 2)
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_kontext.py b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py
index 2427e342a9..e90aa6204a 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_kontext.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py
@@ -65,8 +65,10 @@ EXAMPLE_DOC_STRING = """
         ... )
         >>> pipe.to("cuda")
 
-        >>> image = load_image("inputs/yarn-art-pikachu.png").convert("RGB")
-        >>> prompt = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
+        >>> image = load_image(
+        ...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
+        ... ).convert("RGB")
+        >>> prompt = "Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors"
         >>> image = pipe(
         ...     image=image,
         ...     
prompt=prompt, From 21543de571f77cf14b850710529891dbc50dc83c Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 27 Jun 2025 15:57:55 +0530 Subject: [PATCH 11/15] remove syncs before denoising in Kontext (#11818) --- src/diffusers/pipelines/flux/pipeline_flux.py | 2 ++ src/diffusers/pipelines/flux/pipeline_flux_kontext.py | 3 +++ 2 files changed, 5 insertions(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index bdee2ead48..4c83ae7405 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -898,6 +898,8 @@ class FluxPipeline( ) # 6. Denoising loop + # We set the index here to remove DtoH sync, helpful especially during compilation. + # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 self.scheduler.set_begin_index(0) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): diff --git a/src/diffusers/pipelines/flux/pipeline_flux_kontext.py b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py index e90aa6204a..07b9b895a4 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_kontext.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py @@ -1043,6 +1043,9 @@ class FluxKontextPipeline( ) # 6. Denoising loop + # We set the index here to remove DtoH sync, helpful especially during compilation. + # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 + self.scheduler.set_begin_index(0) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: From e8e44a510c152fff17e3f1bba036d635776b5b9f Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 27 Jun 2025 16:33:43 +0530 Subject: [PATCH 12/15] [CI] disable onnx, mps, flax from the CI (#11803) * disable onnx, mps, flax * remove --- .github/workflows/build_docker_images.yml | 4 - .github/workflows/nightly_tests.yml | 104 +--------------------- .github/workflows/pr_tests.yml | 14 --- .github/workflows/push_tests.yml | 96 -------------------- .github/workflows/push_tests_fast.yml | 28 ------ .github/workflows/push_tests_mps.yml | 7 +- .github/workflows/release_tests_fast.yml | 95 -------------------- 7 files changed, 3 insertions(+), 345 deletions(-) diff --git a/.github/workflows/build_docker_images.yml b/.github/workflows/build_docker_images.yml index 838f241ddc..583853c6d6 100644 --- a/.github/workflows/build_docker_images.yml +++ b/.github/workflows/build_docker_images.yml @@ -75,10 +75,6 @@ jobs: - diffusers-pytorch-cuda - diffusers-pytorch-xformers-cuda - diffusers-pytorch-minimum-cuda - - diffusers-flax-cpu - - diffusers-flax-tpu - - diffusers-onnxruntime-cpu - - diffusers-onnxruntime-cuda - diffusers-doc-builder steps: diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml index 5476616704..16e1a70b84 100644 --- a/.github/workflows/nightly_tests.yml +++ b/.github/workflows/nightly_tests.yml @@ -321,55 +321,6 @@ jobs: name: torch_minimum_version_cuda_test_reports path: reports - run_nightly_onnx_tests: - name: Nightly ONNXRuntime CUDA tests on Ubuntu - runs-on: - group: aws-g4dn-2xlarge - container: - image: diffusers/diffusers-onnxruntime-cuda - options: --gpus 0 --shm-size "16gb" --ipc host - - steps: - - name: Checkout diffusers - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - name: NVIDIA-SMI - run: nvidia-smi - - - name: Install dependencies - run: | - python -m venv /opt/venv && export 
PATH="/opt/venv/bin:$PATH" - python -m uv pip install -e [quality,test] - pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git - python -m uv pip install pytest-reportlog - - name: Environment - run: python utils/print_env.py - - - name: Run Nightly ONNXRuntime CUDA tests - env: - HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }} - run: | - python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \ - -s -v -k "Onnx" \ - --make-reports=tests_onnx_cuda \ - --report-log=tests_onnx_cuda.log \ - tests/ - - - name: Failure short reports - if: ${{ failure() }} - run: | - cat reports/tests_onnx_cuda_stats.txt - cat reports/tests_onnx_cuda_failures_short.txt - - - name: Test suite reports artifacts - if: ${{ always() }} - uses: actions/upload-artifact@v4 - with: - name: tests_onnx_cuda_reports - path: reports - run_nightly_quantization_tests: name: Torch quantization nightly tests strategy: @@ -485,57 +436,6 @@ jobs: name: torch_cuda_pipeline_level_quant_reports path: reports - run_flax_tpu_tests: - name: Nightly Flax TPU Tests - runs-on: - group: gcp-ct5lp-hightpu-8t - if: github.event_name == 'schedule' - - container: - image: diffusers/diffusers-flax-tpu - options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache - defaults: - run: - shell: bash - steps: - - name: Checkout diffusers - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - name: Install dependencies - run: | - python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" - python -m uv pip install -e [quality,test] - pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git - python -m uv pip install pytest-reportlog - - - name: Environment - run: python utils/print_env.py - - - name: Run nightly Flax TPU tests - env: - HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }} - run: | - python -m pytest -n 0 \ - -s -v -k "Flax" \ - --make-reports=tests_flax_tpu \ - --report-log=tests_flax_tpu.log \ - tests/ - - - name: Failure short reports - if: ${{ failure() }} - run: | - cat reports/tests_flax_tpu_stats.txt - cat reports/tests_flax_tpu_failures_short.txt - - - name: Test suite reports artifacts - if: ${{ always() }} - uses: actions/upload-artifact@v4 - with: - name: flax_tpu_test_reports - path: reports - generate_consolidated_report: name: Generate Consolidated Test Report needs: [ @@ -545,9 +445,9 @@ jobs: run_big_gpu_torch_tests, run_nightly_quantization_tests, run_nightly_pipeline_level_quantization_tests, - run_nightly_onnx_tests, + # run_nightly_onnx_tests, torch_minimum_version_cuda_tests, - run_flax_tpu_tests + # run_flax_tpu_tests ] if: always() runs-on: diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml index a0bf1e79e8..34a344528e 100644 --- a/.github/workflows/pr_tests.yml +++ b/.github/workflows/pr_tests.yml @@ -87,11 +87,6 @@ jobs: runner: aws-general-8-plus image: diffusers/diffusers-pytorch-cpu report: torch_cpu_models_schedulers - - name: Fast Flax CPU tests - framework: flax - runner: aws-general-8-plus - image: diffusers/diffusers-flax-cpu - report: flax_cpu - name: PyTorch Example CPU tests framework: pytorch_examples runner: aws-general-8-plus @@ -147,15 +142,6 @@ jobs: --make-reports=tests_${{ matrix.config.report }} \ tests/models tests/schedulers tests/others - - name: Run fast Flax TPU tests - if: ${{ matrix.config.framework == 'flax' }} - run: | - python -m venv /opt/venv && export 
PATH="/opt/venv/bin:$PATH" - python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \ - -s -v -k "Flax" \ - --make-reports=tests_${{ matrix.config.report }} \ - tests - - name: Run example PyTorch CPU tests if: ${{ matrix.config.framework == 'pytorch_examples' }} run: | diff --git a/.github/workflows/push_tests.yml b/.github/workflows/push_tests.yml index 7cab08b44f..007770c8ed 100644 --- a/.github/workflows/push_tests.yml +++ b/.github/workflows/push_tests.yml @@ -159,102 +159,6 @@ jobs: name: torch_cuda_test_reports_${{ matrix.module }} path: reports - flax_tpu_tests: - name: Flax TPU Tests - runs-on: - group: gcp-ct5lp-hightpu-8t - container: - image: diffusers/diffusers-flax-tpu - options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache - defaults: - run: - shell: bash - steps: - - name: Checkout diffusers - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - name: Install dependencies - run: | - python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" - python -m uv pip install -e [quality,test] - pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git - - - name: Environment - run: | - python utils/print_env.py - - - name: Run Flax TPU tests - env: - HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }} - run: | - python -m pytest -n 0 \ - -s -v -k "Flax" \ - --make-reports=tests_flax_tpu \ - tests/ - - - name: Failure short reports - if: ${{ failure() }} - run: | - cat reports/tests_flax_tpu_stats.txt - cat reports/tests_flax_tpu_failures_short.txt - - - name: Test suite reports artifacts - if: ${{ always() }} - uses: actions/upload-artifact@v4 - with: - name: flax_tpu_test_reports - path: reports - - onnx_cuda_tests: - name: ONNX CUDA Tests - runs-on: - group: aws-g4dn-2xlarge - container: - image: diffusers/diffusers-onnxruntime-cuda - options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --gpus 0 - defaults: - run: - shell: bash - steps: - - name: Checkout diffusers - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - name: Install dependencies - run: | - python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" - python -m uv pip install -e [quality,test] - pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git - - - name: Environment - run: | - python utils/print_env.py - - - name: Run ONNXRuntime CUDA tests - env: - HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }} - run: | - python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \ - -s -v -k "Onnx" \ - --make-reports=tests_onnx_cuda \ - tests/ - - - name: Failure short reports - if: ${{ failure() }} - run: | - cat reports/tests_onnx_cuda_stats.txt - cat reports/tests_onnx_cuda_failures_short.txt - - - name: Test suite reports artifacts - if: ${{ always() }} - uses: actions/upload-artifact@v4 - with: - name: onnx_cuda_test_reports - path: reports - run_torch_compile_tests: name: PyTorch Compile CUDA tests diff --git a/.github/workflows/push_tests_fast.yml b/.github/workflows/push_tests_fast.yml index e8a73446de..e274cb0218 100644 --- a/.github/workflows/push_tests_fast.yml +++ b/.github/workflows/push_tests_fast.yml @@ -33,16 +33,6 @@ jobs: runner: aws-general-8-plus image: diffusers/diffusers-pytorch-cpu report: torch_cpu - - name: Fast Flax CPU tests on Ubuntu - framework: flax - runner: aws-general-8-plus - image: diffusers/diffusers-flax-cpu - report: flax_cpu - - name: Fast 
ONNXRuntime CPU tests on Ubuntu - framework: onnxruntime - runner: aws-general-8-plus - image: diffusers/diffusers-onnxruntime-cpu - report: onnx_cpu - name: PyTorch Example CPU tests on Ubuntu framework: pytorch_examples runner: aws-general-8-plus @@ -87,24 +77,6 @@ jobs: --make-reports=tests_${{ matrix.config.report }} \ tests/ - - name: Run fast Flax TPU tests - if: ${{ matrix.config.framework == 'flax' }} - run: | - python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" - python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \ - -s -v -k "Flax" \ - --make-reports=tests_${{ matrix.config.report }} \ - tests/ - - - name: Run fast ONNXRuntime CPU tests - if: ${{ matrix.config.framework == 'onnxruntime' }} - run: | - python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" - python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \ - -s -v -k "Onnx" \ - --make-reports=tests_${{ matrix.config.report }} \ - tests/ - - name: Run example PyTorch CPU tests if: ${{ matrix.config.framework == 'pytorch_examples' }} run: | diff --git a/.github/workflows/push_tests_mps.yml b/.github/workflows/push_tests_mps.yml index 5fd3b78be7..eb6c0da225 100644 --- a/.github/workflows/push_tests_mps.yml +++ b/.github/workflows/push_tests_mps.yml @@ -1,12 +1,7 @@ name: Fast mps tests on main on: - push: - branches: - - main - paths: - - "src/diffusers/**.py" - - "tests/**.py" + workflow_dispatch: env: DIFFUSERS_IS_CI: yes diff --git a/.github/workflows/release_tests_fast.yml b/.github/workflows/release_tests_fast.yml index a464381ba4..e5d3282049 100644 --- a/.github/workflows/release_tests_fast.yml +++ b/.github/workflows/release_tests_fast.yml @@ -213,101 +213,6 @@ jobs: with: name: torch_minimum_version_cuda_test_reports path: reports - - flax_tpu_tests: - name: Flax TPU Tests - runs-on: docker-tpu - container: - image: diffusers/diffusers-flax-tpu - options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --privileged - defaults: - run: - shell: bash - steps: - - name: Checkout diffusers - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - name: Install dependencies - run: | - python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" - python -m uv pip install -e [quality,test] - pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git - - - name: Environment - run: | - python utils/print_env.py - - - name: Run slow Flax TPU tests - env: - HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }} - run: | - python -m pytest -n 0 \ - -s -v -k "Flax" \ - --make-reports=tests_flax_tpu \ - tests/ - - - name: Failure short reports - if: ${{ failure() }} - run: | - cat reports/tests_flax_tpu_stats.txt - cat reports/tests_flax_tpu_failures_short.txt - - - name: Test suite reports artifacts - if: ${{ always() }} - uses: actions/upload-artifact@v4 - with: - name: flax_tpu_test_reports - path: reports - - onnx_cuda_tests: - name: ONNX CUDA Tests - runs-on: - group: aws-g4dn-2xlarge - container: - image: diffusers/diffusers-onnxruntime-cuda - options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --gpus 0 - defaults: - run: - shell: bash - steps: - - name: Checkout diffusers - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - name: Install dependencies - run: | - python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" - python -m uv pip install -e [quality,test] - pip uninstall accelerate -y && python -m uv pip install -U 
accelerate@git+https://github.com/huggingface/accelerate.git - - - name: Environment - run: | - python utils/print_env.py - - - name: Run slow ONNXRuntime CUDA tests - env: - HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }} - run: | - python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \ - -s -v -k "Onnx" \ - --make-reports=tests_onnx_cuda \ - tests/ - - - name: Failure short reports - if: ${{ failure() }} - run: | - cat reports/tests_onnx_cuda_stats.txt - cat reports/tests_onnx_cuda_failures_short.txt - - - name: Test suite reports artifacts - if: ${{ always() }} - uses: actions/upload-artifact@v4 - with: - name: onnx_cuda_test_reports - path: reports run_torch_compile_tests: name: PyTorch Compile CUDA tests From cdaf84a708eadf17d731657f4be3fa39d09a12c0 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 27 Jun 2025 18:31:57 +0530 Subject: [PATCH 13/15] TorchAO compile + offloading tests (#11697) * update * update * update * update * update * user property instead --- tests/quantization/bnb/test_4bit.py | 26 ++++++---- tests/quantization/bnb/test_mixed_int8.py | 18 ++++--- .../quantization/test_torch_compile_utils.py | 13 +++-- tests/quantization/torchao/test_torchao.py | 51 +++++++++++++++++++ 4 files changed, 85 insertions(+), 23 deletions(-) diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index bdb8920a39..c5497d1c8d 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -866,15 +866,17 @@ class ExtendedSerializationTest(BaseBnb4BitSerializationTests): @require_torch_version_greater("2.7.1") class Bnb4BitCompileTests(QuantCompileTests): - quantization_config = PipelineQuantizationConfig( - quant_backend="bitsandbytes_8bit", - quant_kwargs={ - "load_in_4bit": True, - "bnb_4bit_quant_type": "nf4", - "bnb_4bit_compute_dtype": torch.bfloat16, - }, - components_to_quantize=["transformer", "text_encoder_2"], - ) + @property + def quantization_config(self): + return PipelineQuantizationConfig( + quant_backend="bitsandbytes_8bit", + quant_kwargs={ + "load_in_4bit": True, + "bnb_4bit_quant_type": "nf4", + "bnb_4bit_compute_dtype": torch.bfloat16, + }, + components_to_quantize=["transformer", "text_encoder_2"], + ) def test_torch_compile(self): torch._dynamo.config.capture_dynamic_output_shape_ops = True @@ -883,5 +885,7 @@ class Bnb4BitCompileTests(QuantCompileTests): def test_torch_compile_with_cpu_offload(self): super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config) - def test_torch_compile_with_group_offload(self): - super()._test_torch_compile_with_group_offload(quantization_config=self.quantization_config) + def test_torch_compile_with_group_offload_leaf(self): + super()._test_torch_compile_with_group_offload_leaf( + quantization_config=self.quantization_config, use_stream=True + ) diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py index d048b0b7db..383cdd6849 100644 --- a/tests/quantization/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -831,11 +831,13 @@ class BaseBnb8bitSerializationTests(Base8bitTests): @require_torch_version_greater_equal("2.6.0") class Bnb8BitCompileTests(QuantCompileTests): - quantization_config = PipelineQuantizationConfig( - quant_backend="bitsandbytes_8bit", - quant_kwargs={"load_in_8bit": True}, - components_to_quantize=["transformer", "text_encoder_2"], - ) + @property + def quantization_config(self): + return PipelineQuantizationConfig( + 
quant_backend="bitsandbytes_8bit", + quant_kwargs={"load_in_8bit": True}, + components_to_quantize=["transformer", "text_encoder_2"], + ) def test_torch_compile(self): torch._dynamo.config.capture_dynamic_output_shape_ops = True @@ -847,7 +849,7 @@ class Bnb8BitCompileTests(QuantCompileTests): ) @pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.") - def test_torch_compile_with_group_offload(self): - super()._test_torch_compile_with_group_offload( - quantization_config=self.quantization_config, torch_dtype=torch.float16 + def test_torch_compile_with_group_offload_leaf(self): + super()._test_torch_compile_with_group_offload_leaf( + quantization_config=self.quantization_config, torch_dtype=torch.float16, use_stream=True ) diff --git a/tests/quantization/test_torch_compile_utils.py b/tests/quantization/test_torch_compile_utils.py index ba870ba733..99bb8980ef 100644 --- a/tests/quantization/test_torch_compile_utils.py +++ b/tests/quantization/test_torch_compile_utils.py @@ -24,7 +24,11 @@ from diffusers.utils.testing_utils import backend_empty_cache, require_torch_gpu @require_torch_gpu @slow class QuantCompileTests(unittest.TestCase): - quantization_config = None + @property + def quantization_config(self): + raise NotImplementedError( + "This property should be implemented in the subclass to return the appropriate quantization config." + ) def setUp(self): super().setUp() @@ -64,7 +68,9 @@ class QuantCompileTests(unittest.TestCase): # small resolutions to ensure speedy execution. pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256) - def _test_torch_compile_with_group_offload(self, quantization_config, torch_dtype=torch.bfloat16): + def _test_torch_compile_with_group_offload_leaf( + self, quantization_config, torch_dtype=torch.bfloat16, *, use_stream: bool = False + ): torch._dynamo.config.cache_size_limit = 10000 pipe = self._init_pipeline(quantization_config, torch_dtype) @@ -72,8 +78,7 @@ class QuantCompileTests(unittest.TestCase): "onload_device": torch.device("cuda"), "offload_device": torch.device("cpu"), "offload_type": "leaf_level", - "use_stream": True, - "non_blocking": True, + "use_stream": use_stream, } pipe.transformer.enable_group_offload(**group_offload_kwargs) pipe.transformer.compile() diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index 0741c7f87c..c4cfc8eb87 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -19,6 +19,7 @@ import unittest from typing import List import numpy as np +from parameterized import parameterized from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel from diffusers import ( @@ -29,6 +30,7 @@ from diffusers import ( TorchAoConfig, ) from diffusers.models.attention_processor import Attention +from diffusers.quantizers import PipelineQuantizationConfig from diffusers.utils.testing_utils import ( backend_empty_cache, backend_synchronize, @@ -44,6 +46,8 @@ from diffusers.utils.testing_utils import ( torch_device, ) +from ..test_torch_compile_utils import QuantCompileTests + enable_full_determinism() @@ -625,6 +629,53 @@ class TorchAoSerializationTest(unittest.TestCase): self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device) +@require_torchao_version_greater_or_equal("0.7.0") +class TorchAoCompileTest(QuantCompileTests): + @property + def quantization_config(self): + return 
PipelineQuantizationConfig( + quant_mapping={ + "transformer": TorchAoConfig(quant_type="int8_weight_only"), + }, + ) + + def test_torch_compile(self): + super()._test_torch_compile(quantization_config=self.quantization_config) + + @unittest.skip( + "Changing the device of AQT tensor with module._apply (called from doing module.to() in accelerate) does not work " + "when compiling." + ) + def test_torch_compile_with_cpu_offload(self): + # RuntimeError: _apply(): Couldn't swap Linear.weight + super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config) + + @unittest.skip( + """ + For `use_stream=False`: + - Changing the device of AQT tensor, with `param.data = param.data.to(device)` as done in group offloading implementation + is unsupported in TorchAO. When compiling, FakeTensor device mismatch causes failure. + For `use_stream=True`: + Using non-default stream requires ability to pin tensors. AQT does not seem to support this yet in TorchAO. + """ + ) + @parameterized.expand([False, True]) + def test_torch_compile_with_group_offload_leaf(self): + # For use_stream=False: + # If we run group offloading without compilation, we will see: + # RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "cuda:0". This is no longer allowed; the devices must match. + # When running with compilation, the error ends up being different: + # Dynamo failed to run FX node with fake tensors: call_function (*(FakeTensor(..., device='cuda:0', size=(s0, 256), dtype=torch.bfloat16), AffineQuantizedTensor(tensor_impl=PlainAQTTensorImpl(data=FakeTensor(..., size=(1536, 256), dtype=torch.int8)... , scale=FakeTensor(..., size=(1536,), dtype=torch.bfloat16)... , zero_point=FakeTensor(..., size=(1536,), dtype=torch.int64)... , _layout=PlainLayout()), block_size=(1, 256), shape=torch.Size([1536, 256]), device=cpu, dtype=torch.bfloat16, requires_grad=False), Parameter(FakeTensor(..., device='cuda:0', size=(1536,), dtype=torch.bfloat16, + # requires_grad=True))), **{}): got RuntimeError('Unhandled FakeTensor Device Propagation for aten.mm.default, found two different devices cuda:0, cpu') + # Looks like something that will have to be looked into upstream. + # for linear layers, weight.tensor_impl shows cuda... 
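+            # quant_mapping pins an explicit quantization config per component; only
+            # the transformer is quantized here, with int8 weight-only torchao quantization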
but: + # weight.tensor_impl.{data,scale,zero_point}.device will be cpu + + # For use_stream=True: + # NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=, types=(,), arg_types=(,), kwarg_types={} + super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config) + + # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners @require_torch @require_torch_accelerator From 76ec3d1fee95f30196452027018cb40b977d376c Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 27 Jun 2025 23:20:53 +0530 Subject: [PATCH 14/15] Support dynamically loading/unloading loras with group offloading (#11804) * update * add test * address review comments * update * fixes * change decorator order to fix tests * try fix * fight tests --- src/diffusers/hooks/group_offloading.py | 291 +++++++++++------------ src/diffusers/loaders/lora_base.py | 47 ++-- src/diffusers/loaders/peft.py | 11 +- src/diffusers/loaders/unet.py | 19 +- tests/lora/test_lora_layers_cogvideox.py | 9 + tests/lora/test_lora_layers_cogview4.py | 16 +- tests/lora/utils.py | 71 ++++++ 7 files changed, 289 insertions(+), 175 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 45fee35ef3..ac25dd061b 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -14,6 +14,8 @@ import os from contextlib import contextmanager, nullcontext +from dataclasses import dataclass +from enum import Enum from typing import Dict, List, Optional, Set, Tuple, Union import safetensors.torch @@ -46,6 +48,24 @@ _SUPPORTED_PYTORCH_LAYERS = ( # fmt: on +class GroupOffloadingType(str, Enum): + BLOCK_LEVEL = "block_level" + LEAF_LEVEL = "leaf_level" + + +@dataclass +class GroupOffloadingConfig: + onload_device: torch.device + offload_device: torch.device + offload_type: GroupOffloadingType + non_blocking: bool + record_stream: bool + low_cpu_mem_usage: bool + num_blocks_per_group: Optional[int] = None + offload_to_disk_path: Optional[str] = None + stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None + + class ModuleGroup: def __init__( self, @@ -288,9 +308,12 @@ class GroupOffloadingHook(ModelHook): _is_stateful = False - def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None) -> None: + def __init__( + self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, *, config: GroupOffloadingConfig + ) -> None: self.group = group self.next_group = next_group + self.config = config def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: if self.group.offload_leader == module: @@ -436,7 +459,7 @@ def apply_group_offloading( module: torch.nn.Module, onload_device: torch.device, offload_device: torch.device = torch.device("cpu"), - offload_type: str = "block_level", + offload_type: Union[str, GroupOffloadingType] = "block_level", num_blocks_per_group: Optional[int] = None, non_blocking: bool = False, use_stream: bool = False, @@ -478,7 +501,7 @@ def apply_group_offloading( The device to which the group of modules are onloaded. offload_device (`torch.device`, defaults to `torch.device("cpu")`): The device to which the group of modules are offloaded. This should typically be the CPU. Default is CPU. - offload_type (`str`, defaults to "block_level"): + offload_type (`str` or `GroupOffloadingType`, defaults to "block_level"): The type of offloading to be applied. Can be one of "block_level" or "leaf_level". 
Default is "block_level". offload_to_disk_path (`str`, *optional*, defaults to `None`): @@ -521,6 +544,8 @@ def apply_group_offloading( ``` """ + offload_type = GroupOffloadingType(offload_type) + stream = None if use_stream: if torch.cuda.is_available(): @@ -532,84 +557,45 @@ def apply_group_offloading( if not use_stream and record_stream: raise ValueError("`record_stream` cannot be True when `use_stream=False`.") + if offload_type == GroupOffloadingType.BLOCK_LEVEL and num_blocks_per_group is None: + raise ValueError("`num_blocks_per_group` must be provided when using `offload_type='block_level'.") _raise_error_if_accelerate_model_or_sequential_hook_present(module) - if offload_type == "block_level": - if num_blocks_per_group is None: - raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.") + config = GroupOffloadingConfig( + onload_device=onload_device, + offload_device=offload_device, + offload_type=offload_type, + num_blocks_per_group=num_blocks_per_group, + non_blocking=non_blocking, + stream=stream, + record_stream=record_stream, + low_cpu_mem_usage=low_cpu_mem_usage, + offload_to_disk_path=offload_to_disk_path, + ) + _apply_group_offloading(module, config) - _apply_group_offloading_block_level( - module=module, - num_blocks_per_group=num_blocks_per_group, - offload_device=offload_device, - onload_device=onload_device, - offload_to_disk_path=offload_to_disk_path, - non_blocking=non_blocking, - stream=stream, - record_stream=record_stream, - low_cpu_mem_usage=low_cpu_mem_usage, - ) - elif offload_type == "leaf_level": - _apply_group_offloading_leaf_level( - module=module, - offload_device=offload_device, - onload_device=onload_device, - offload_to_disk_path=offload_to_disk_path, - non_blocking=non_blocking, - stream=stream, - record_stream=record_stream, - low_cpu_mem_usage=low_cpu_mem_usage, - ) + +def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConfig) -> None: + if config.offload_type == GroupOffloadingType.BLOCK_LEVEL: + _apply_group_offloading_block_level(module, config) + elif config.offload_type == GroupOffloadingType.LEAF_LEVEL: + _apply_group_offloading_leaf_level(module, config) else: - raise ValueError(f"Unsupported offload_type: {offload_type}") + assert False -def _apply_group_offloading_block_level( - module: torch.nn.Module, - num_blocks_per_group: int, - offload_device: torch.device, - onload_device: torch.device, - non_blocking: bool, - stream: Union[torch.cuda.Stream, torch.Stream, None] = None, - record_stream: Optional[bool] = False, - low_cpu_mem_usage: bool = False, - offload_to_disk_path: Optional[str] = None, -) -> None: +def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None: r""" This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks. - - Args: - module (`torch.nn.Module`): - The module to which group offloading is applied. - offload_device (`torch.device`): - The device to which the group of modules are offloaded. This should typically be the CPU. - offload_to_disk_path (`str`, *optional*, defaults to `None`): - The path to the directory where parameters will be offloaded. Setting this option can be useful in limited - RAM environment settings where a reasonable speed-memory trade-off is desired. 
- onload_device (`torch.device`): - The device to which the group of modules are onloaded. - non_blocking (`bool`): - If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation - and data transfer. - stream (`torch.cuda.Stream`or `torch.Stream`, *optional*): - If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful - for overlapping computation and data transfer. - record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor - as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the - [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more - details. - low_cpu_mem_usage (`bool`, defaults to `False`): - If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This - option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when - the CPU memory is a bottleneck but may counteract the benefits of using streams. """ - if stream is not None and num_blocks_per_group != 1: + + if config.stream is not None and config.num_blocks_per_group != 1: logger.warning( - f"Using streams is only supported for num_blocks_per_group=1. Got {num_blocks_per_group=}. Setting it to 1." + f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1." ) - num_blocks_per_group = 1 + config.num_blocks_per_group = 1 # Create module groups for ModuleList and Sequential blocks modules_with_group_offloading = set() @@ -621,19 +607,19 @@ def _apply_group_offloading_block_level( modules_with_group_offloading.add(name) continue - for i in range(0, len(submodule), num_blocks_per_group): - current_modules = submodule[i : i + num_blocks_per_group] + for i in range(0, len(submodule), config.num_blocks_per_group): + current_modules = submodule[i : i + config.num_blocks_per_group] group = ModuleGroup( modules=current_modules, - offload_device=offload_device, - onload_device=onload_device, - offload_to_disk_path=offload_to_disk_path, + offload_device=config.offload_device, + onload_device=config.onload_device, + offload_to_disk_path=config.offload_to_disk_path, offload_leader=current_modules[-1], onload_leader=current_modules[0], - non_blocking=non_blocking, - stream=stream, - record_stream=record_stream, - low_cpu_mem_usage=low_cpu_mem_usage, + non_blocking=config.non_blocking, + stream=config.stream, + record_stream=config.record_stream, + low_cpu_mem_usage=config.low_cpu_mem_usage, onload_self=True, ) matched_module_groups.append(group) @@ -643,7 +629,7 @@ def _apply_group_offloading_block_level( # Apply group offloading hooks to the module groups for i, group in enumerate(matched_module_groups): for group_module in group.modules: - _apply_group_offloading_hook(group_module, group, None) + _apply_group_offloading_hook(group_module, group, None, config=config) # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately # when the forward pass of this module is called. 
This is because the top-level module is not @@ -658,9 +644,9 @@ def _apply_group_offloading_block_level( unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules] unmatched_group = ModuleGroup( modules=unmatched_modules, - offload_device=offload_device, - onload_device=onload_device, - offload_to_disk_path=offload_to_disk_path, + offload_device=config.offload_device, + onload_device=config.onload_device, + offload_to_disk_path=config.offload_to_disk_path, offload_leader=module, onload_leader=module, parameters=parameters, @@ -670,54 +656,19 @@ def _apply_group_offloading_block_level( record_stream=False, onload_self=True, ) - if stream is None: - _apply_group_offloading_hook(module, unmatched_group, None) + if config.stream is None: + _apply_group_offloading_hook(module, unmatched_group, None, config=config) else: - _apply_lazy_group_offloading_hook(module, unmatched_group, None) + _apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config) -def _apply_group_offloading_leaf_level( - module: torch.nn.Module, - offload_device: torch.device, - onload_device: torch.device, - non_blocking: bool, - stream: Union[torch.cuda.Stream, torch.Stream, None] = None, - record_stream: Optional[bool] = False, - low_cpu_mem_usage: bool = False, - offload_to_disk_path: Optional[str] = None, -) -> None: +def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None: r""" This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory requirements. However, it can be slower compared to other offloading methods due to the excessive number of device synchronizations. When using devices that support streams to overlap data transfer and computation, this method can reduce memory usage without any performance degradation. - - Args: - module (`torch.nn.Module`): - The module to which group offloading is applied. - offload_device (`torch.device`): - The device to which the group of modules are offloaded. This should typically be the CPU. - onload_device (`torch.device`): - The device to which the group of modules are onloaded. - offload_to_disk_path (`str`, *optional*, defaults to `None`): - The path to the directory where parameters will be offloaded. Setting this option can be useful in limited - RAM environment settings where a reasonable speed-memory trade-off is desired. - non_blocking (`bool`): - If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation - and data transfer. - stream (`torch.cuda.Stream` or `torch.Stream`, *optional*): - If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful - for overlapping computation and data transfer. - record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor - as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the - [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more - details. - low_cpu_mem_usage (`bool`, defaults to `False`): - If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This - option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when - the CPU memory is a bottleneck but may counteract the benefits of using streams. 
""" - # Create module groups for leaf modules and apply group offloading hooks modules_with_group_offloading = set() for name, submodule in module.named_modules(): @@ -725,18 +676,18 @@ def _apply_group_offloading_leaf_level( continue group = ModuleGroup( modules=[submodule], - offload_device=offload_device, - onload_device=onload_device, - offload_to_disk_path=offload_to_disk_path, + offload_device=config.offload_device, + onload_device=config.onload_device, + offload_to_disk_path=config.offload_to_disk_path, offload_leader=submodule, onload_leader=submodule, - non_blocking=non_blocking, - stream=stream, - record_stream=record_stream, - low_cpu_mem_usage=low_cpu_mem_usage, + non_blocking=config.non_blocking, + stream=config.stream, + record_stream=config.record_stream, + low_cpu_mem_usage=config.low_cpu_mem_usage, onload_self=True, ) - _apply_group_offloading_hook(submodule, group, None) + _apply_group_offloading_hook(submodule, group, None, config=config) modules_with_group_offloading.add(name) # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass @@ -767,33 +718,32 @@ def _apply_group_offloading_leaf_level( parameters = parent_to_parameters.get(name, []) buffers = parent_to_buffers.get(name, []) parent_module = module_dict[name] - assert getattr(parent_module, "_diffusers_hook", None) is None group = ModuleGroup( modules=[], - offload_device=offload_device, - onload_device=onload_device, + offload_device=config.offload_device, + onload_device=config.onload_device, offload_leader=parent_module, onload_leader=parent_module, - offload_to_disk_path=offload_to_disk_path, + offload_to_disk_path=config.offload_to_disk_path, parameters=parameters, buffers=buffers, - non_blocking=non_blocking, - stream=stream, - record_stream=record_stream, - low_cpu_mem_usage=low_cpu_mem_usage, + non_blocking=config.non_blocking, + stream=config.stream, + record_stream=config.record_stream, + low_cpu_mem_usage=config.low_cpu_mem_usage, onload_self=True, ) - _apply_group_offloading_hook(parent_module, group, None) + _apply_group_offloading_hook(parent_module, group, None, config=config) - if stream is not None: + if config.stream is not None: # When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer # and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that will find the # execution order and apply prefetching in the correct order. 
unmatched_group = ModuleGroup( modules=[], - offload_device=offload_device, - onload_device=onload_device, - offload_to_disk_path=offload_to_disk_path, + offload_device=config.offload_device, + onload_device=config.onload_device, + offload_to_disk_path=config.offload_to_disk_path, offload_leader=module, onload_leader=module, parameters=None, @@ -801,23 +751,25 @@ def _apply_group_offloading_leaf_level( non_blocking=False, stream=None, record_stream=False, - low_cpu_mem_usage=low_cpu_mem_usage, + low_cpu_mem_usage=config.low_cpu_mem_usage, onload_self=True, ) - _apply_lazy_group_offloading_hook(module, unmatched_group, None) + _apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config) def _apply_group_offloading_hook( module: torch.nn.Module, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, + *, + config: GroupOffloadingConfig, ) -> None: registry = HookRegistry.check_if_exists_or_initialize(module) # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent # is the current module. In such cases, we don't want to overwrite the existing group offloading hook. if registry.get_hook(_GROUP_OFFLOADING) is None: - hook = GroupOffloadingHook(group, next_group) + hook = GroupOffloadingHook(group, next_group, config=config) registry.register_hook(hook, _GROUP_OFFLOADING) @@ -825,13 +777,15 @@ def _apply_lazy_group_offloading_hook( module: torch.nn.Module, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, + *, + config: GroupOffloadingConfig, ) -> None: registry = HookRegistry.check_if_exists_or_initialize(module) # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent # is the current module. In such cases, we don't want to overwrite the existing group offloading hook. 
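        # this group intentionally owns no modules, parameters, or buffers; it only
        # gives the lazy prefetch hook below a top-level anchor from which to record
        # the layer execution order during the first forward pass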
if registry.get_hook(_GROUP_OFFLOADING) is None: - hook = GroupOffloadingHook(group, next_group) + hook = GroupOffloadingHook(group, next_group, config=config) registry.register_hook(hook, _GROUP_OFFLOADING) lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook() @@ -898,15 +852,48 @@ def _raise_error_if_accelerate_model_or_sequential_hook_present(module: torch.nn ) -def _is_group_offload_enabled(module: torch.nn.Module) -> bool: +def _get_top_level_group_offload_hook(module: torch.nn.Module) -> Optional[GroupOffloadingHook]: for submodule in module.modules(): - if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None: - return True - return False + if hasattr(submodule, "_diffusers_hook"): + group_offloading_hook = submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) + if group_offloading_hook is not None: + return group_offloading_hook + return None + + +def _is_group_offload_enabled(module: torch.nn.Module) -> bool: + top_level_group_offload_hook = _get_top_level_group_offload_hook(module) + return top_level_group_offload_hook is not None def _get_group_onload_device(module: torch.nn.Module) -> torch.device: - for submodule in module.modules(): - if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None: - return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device + top_level_group_offload_hook = _get_top_level_group_offload_hook(module) + if top_level_group_offload_hook is not None: + return top_level_group_offload_hook.config.onload_device raise ValueError("Group offloading is not enabled for the provided module.") + + +def _maybe_remove_and_reapply_group_offloading(module: torch.nn.Module) -> None: + r""" + Removes the group offloading hook from the module and re-applies it. This is useful when the module has been + modified in-place and the group offloading hook references-to-tensors needs to be updated. The in-place + modification can happen in a number of ways, for example, fusing QKV or unloading/loading LoRAs on-the-fly. + + In this implementation, we make an assumption that group offloading has only been applied at the top-level module, + and therefore all submodules have the same onload and offload devices. If this assumption is not true, say in the + case where user has applied group offloading at multiple levels, this function will not work as expected. + + There is some performance penalty associated with doing this when non-default streams are used, because we need to + retrace the execution order of the layers with `LazyPrefetchGroupOffloadingHook`. 
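        # the config now travels with the hook, so offloading can later be removed and
        # re-applied with identical settings (see _maybe_remove_and_reapply_group_offloading)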
+ """ + top_level_group_offload_hook = _get_top_level_group_offload_hook(module) + + if top_level_group_offload_hook is None: + return + + registry = HookRegistry.check_if_exists_or_initialize(module) + registry.remove_hook(_GROUP_OFFLOADING, recurse=True) + registry.remove_hook(_LAYER_EXECUTION_TRACKER, recurse=True) + registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=True) + + _apply_group_offloading(module, top_level_group_offload_hook.config) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index e6941a521d..562a21dbbb 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -25,6 +25,7 @@ import torch.nn as nn from huggingface_hub import model_info from huggingface_hub.constants import HF_HUB_OFFLINE +from ..hooks.group_offloading import _is_group_offload_enabled, _maybe_remove_and_reapply_group_offloading from ..models.modeling_utils import ModelMixin, load_state_dict from ..utils import ( USE_PEFT_BACKEND, @@ -391,7 +392,9 @@ def _load_lora_into_text_encoder( adapter_name = get_adapter_name(text_encoder) # if prefix is not None and not state_dict: @@ -433,30 +440,36 @@ def _func_optionally_disable_offloading(_pipeline): Returns: tuple: - A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True. + A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` or `is_group_offload` is True. """ is_model_cpu_offload = False is_sequential_cpu_offload = False + is_group_offload = False if _pipeline is not None and _pipeline.hf_device_map is None: for _, component in _pipeline.components.items(): - if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): - if not is_model_cpu_offload: - is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload) - if not is_sequential_cpu_offload: - is_sequential_cpu_offload = ( - isinstance(component._hf_hook, AlignDevicesHook) - or hasattr(component._hf_hook, "hooks") - and isinstance(component._hf_hook.hooks[0], AlignDevicesHook) - ) + if not isinstance(component, nn.Module): + continue + is_group_offload = is_group_offload or _is_group_offload_enabled(component) + if not hasattr(component, "_hf_hook"): + continue + is_model_cpu_offload = is_model_cpu_offload or isinstance(component._hf_hook, CpuOffload) + is_sequential_cpu_offload = is_sequential_cpu_offload or ( + isinstance(component._hf_hook, AlignDevicesHook) + or hasattr(component._hf_hook, "hooks") + and isinstance(component._hf_hook.hooks[0], AlignDevicesHook) + ) - logger.info( - "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." - ) - if is_sequential_cpu_offload or is_model_cpu_offload: - remove_hook_from_module(component, recurse=is_sequential_cpu_offload) + if is_sequential_cpu_offload or is_model_cpu_offload: + logger.info( + "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." 
+ ) + for _, component in _pipeline.components.items(): + if not isinstance(component, nn.Module) or not hasattr(component, "_hf_hook"): + continue + remove_hook_from_module(component, recurse=is_sequential_cpu_offload) - return (is_model_cpu_offload, is_sequential_cpu_offload) + return (is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload) class LoraBaseMixin: diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 3436230713..3670243de8 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -22,6 +22,7 @@ from typing import Dict, List, Literal, Optional, Union import safetensors import torch +from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading from ..utils import ( MIN_PEFT_VERSION, USE_PEFT_BACKEND, @@ -256,7 +257,9 @@ class PeftAdapterMixin: # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks # otherwise loading LoRA weights will lead to an error. - is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline) + is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading( + _pipeline + ) peft_kwargs = {} if is_peft_version(">=", "0.13.1"): peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage @@ -347,6 +350,10 @@ class PeftAdapterMixin: _pipeline.enable_model_cpu_offload() elif is_sequential_cpu_offload: _pipeline.enable_sequential_cpu_offload() + elif is_group_offload: + for component in _pipeline.components.values(): + if isinstance(component, torch.nn.Module): + _maybe_remove_and_reapply_group_offloading(component) # Unsafe code /> if prefix is not None and not state_dict: @@ -687,6 +694,8 @@ class PeftAdapterMixin: if hasattr(self, "peft_config"): del self.peft_config + _maybe_remove_and_reapply_group_offloading(self) + def disable_lora(self): """ Disables the active LoRA layers of the underlying model. diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 68be841191..c9b6a7d7d8 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -22,6 +22,7 @@ import torch import torch.nn.functional as F from huggingface_hub.utils import validate_hf_hub_args +from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading from ..models.embeddings import ( ImageProjection, IPAdapterFaceIDImageProjection, @@ -203,6 +204,7 @@ class UNet2DConditionLoadersMixin: is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys()) is_model_cpu_offload = False is_sequential_cpu_offload = False + is_group_offload = False if is_lora: deprecation_message = "Using the `load_attn_procs()` method has been deprecated and will be removed in a future version. Please use `load_lora_adapter()`." @@ -211,7 +213,7 @@ class UNet2DConditionLoadersMixin: if is_custom_diffusion: attn_processors = self._process_custom_diffusion(state_dict=state_dict) elif is_lora: - is_model_cpu_offload, is_sequential_cpu_offload = self._process_lora( + is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._process_lora( state_dict=state_dict, unet_identifier_key=self.unet_name, network_alphas=network_alphas, @@ -230,7 +232,9 @@ class UNet2DConditionLoadersMixin: # For LoRA, the UNet is already offloaded at this stage as it is handled inside `_process_lora`. 
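+            # strip the accelerate hooks from every component; the caller re-enables
+            # model/sequential CPU offloading once the LoRA weights have been loaded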
if is_custom_diffusion and _pipeline is not None: - is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline=_pipeline) + is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading( + _pipeline=_pipeline + ) # only custom diffusion needs to set attn processors self.set_attn_processor(attn_processors) @@ -241,6 +245,10 @@ class UNet2DConditionLoadersMixin: _pipeline.enable_model_cpu_offload() elif is_sequential_cpu_offload: _pipeline.enable_sequential_cpu_offload() + elif is_group_offload: + for component in _pipeline.components.values(): + if isinstance(component, torch.nn.Module): + _maybe_remove_and_reapply_group_offloading(component) # Unsafe code /> def _process_custom_diffusion(self, state_dict): @@ -307,6 +315,7 @@ class UNet2DConditionLoadersMixin: is_model_cpu_offload = False is_sequential_cpu_offload = False + is_group_offload = False state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict if len(state_dict_to_be_used) > 0: @@ -356,7 +365,9 @@ class UNet2DConditionLoadersMixin: # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks # otherwise loading LoRA weights will lead to an error - is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline) + is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading( + _pipeline + ) peft_kwargs = {} if is_peft_version(">=", "0.13.1"): peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage @@ -389,7 +400,7 @@ class UNet2DConditionLoadersMixin: if warn_msg: logger.warning(warn_msg) - return is_model_cpu_offload, is_sequential_cpu_offload + return is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload @classmethod # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading diff --git a/tests/lora/test_lora_layers_cogvideox.py b/tests/lora/test_lora_layers_cogvideox.py index bd7b33445c..565d6db697 100644 --- a/tests/lora/test_lora_layers_cogvideox.py +++ b/tests/lora/test_lora_layers_cogvideox.py @@ -16,6 +16,7 @@ import sys import unittest import torch +from parameterized import parameterized from transformers import AutoTokenizer, T5EncoderModel from diffusers import ( @@ -28,6 +29,7 @@ from diffusers import ( from diffusers.utils.testing_utils import ( floats_tensor, require_peft_backend, + require_torch_accelerator, ) @@ -127,6 +129,13 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): def test_lora_scale_kwargs_match_fusion(self): super().test_lora_scale_kwargs_match_fusion(expected_atol=9e-3, expected_rtol=9e-3) + @parameterized.expand([("block_level", True), ("leaf_level", False)]) + @require_torch_accelerator + def test_group_offloading_inference_denoiser(self, offload_type, use_stream): + # TODO: We don't run the (leaf_level, True) test here that is enabled for other models. 
+ # The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338 + super()._test_group_offloading_inference_denoiser(offload_type, use_stream) + @unittest.skip("Not supported in CogVideoX.") def test_simple_inference_with_text_denoiser_block_scale(self): pass diff --git a/tests/lora/test_lora_layers_cogview4.py b/tests/lora/test_lora_layers_cogview4.py index 23573bcb21..b7367d9b09 100644 --- a/tests/lora/test_lora_layers_cogview4.py +++ b/tests/lora/test_lora_layers_cogview4.py @@ -18,10 +18,17 @@ import unittest import numpy as np import torch +from parameterized import parameterized from transformers import AutoTokenizer, GlmModel from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler -from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps, torch_device +from diffusers.utils.testing_utils import ( + floats_tensor, + require_peft_backend, + require_torch_accelerator, + skip_mps, + torch_device, +) sys.path.append(".") @@ -141,6 +148,13 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): "Loading from saved checkpoints should give same results.", ) + @parameterized.expand([("block_level", True), ("leaf_level", False)]) + @require_torch_accelerator + def test_group_offloading_inference_denoiser(self, offload_type, use_stream): + # TODO: We don't run the (leaf_level, True) test here that is enabled for other models. + # The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338 + super()._test_group_offloading_inference_denoiser(offload_type, use_stream) + @unittest.skip("Not supported in CogView4.") def test_simple_inference_with_text_denoiser_block_scale(self): pass diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 93dc4a2c37..acd6f5f343 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -39,6 +39,7 @@ from diffusers.utils.testing_utils import ( is_torch_version, require_peft_backend, require_peft_version_greater, + require_torch_accelerator, require_transformers_version_greater, skip_mps, torch_device, @@ -2355,3 +2356,73 @@ class PeftLoraLoaderMixinTests: pipe.load_lora_weights(tmpdirname) output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3)) + + def _test_group_offloading_inference_denoiser(self, offload_type, use_stream): + from diffusers.hooks.group_offloading import _get_top_level_group_offload_hook + + onload_device = torch_device + offload_device = torch.device("cpu") + + components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0]) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config) + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + + with tempfile.TemporaryDirectory() as tmpdirname: + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights( + save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + + components, _, _ = 
self.get_dummy_components(self.scheduler_classes[0]) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + check_if_lora_correctly_set(denoiser) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + # Test group offloading with load_lora_weights + denoiser.enable_group_offload( + onload_device=onload_device, + offload_device=offload_device, + offload_type=offload_type, + num_blocks_per_group=1, + use_stream=use_stream, + ) + group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser) + self.assertTrue(group_offload_hook_1 is not None) + output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] + + # Test group offloading after removing the lora + pipe.unload_lora_weights() + group_offload_hook_2 = _get_top_level_group_offload_hook(denoiser) + self.assertTrue(group_offload_hook_2 is not None) + output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] # noqa: F841 + + # Add the lora again and check if group offloading works + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + check_if_lora_correctly_set(denoiser) + group_offload_hook_3 = _get_top_level_group_offload_hook(denoiser) + self.assertTrue(group_offload_hook_3 is not None) + output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0] + + self.assertTrue(np.allclose(output_1, output_3, atol=1e-3, rtol=1e-3)) + + @parameterized.expand([("block_level", True), ("leaf_level", False), ("leaf_level", True)]) + @require_torch_accelerator + def test_group_offloading_inference_denoiser(self, offload_type, use_stream): + for cls in inspect.getmro(self.__class__): + if "test_group_offloading_inference_denoiser" in cls.__dict__ and cls is not PeftLoraLoaderMixinTests: + # Skip this test if it is overwritten by child class. We need to do this because parameterized + # materializes the test methods on invocation which cannot be overridden. 
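+            # rebuild the pipeline from scratch so the LoRA saved above is loaded
+            # into a clean model before group offloading is enabled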
+                return
+        self._test_group_offloading_inference_denoiser(offload_type, use_stream)

From 05e7a854d0a5661f5b433f6dd5954c224b104f0b Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Sat, 28 Jun 2025 12:00:42 +0530
Subject: [PATCH 15/15] [lora] fix: lora unloading behaviour (#11822)

* fix: lora unloading behaviour

* fix

* update
---
 src/diffusers/loaders/peft.py |  2 ++
 tests/lora/utils.py           | 65 ++++++++++++++++++++++-------------
 2 files changed, 43 insertions(+), 24 deletions(-)

diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index 3670243de8..211f9da619 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -693,6 +693,8 @@ class PeftAdapterMixin:
         recurse_remove_peft_layers(self)
         if hasattr(self, "peft_config"):
             del self.peft_config
+        if hasattr(self, "_hf_peft_config_loaded"):
+            self._hf_peft_config_loaded = None

         _maybe_remove_and_reapply_group_offloading(self)

diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index acd6f5f343..8180f92245 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -291,9 +291,7 @@ class PeftLoraLoaderMixinTests:

         return modules_to_save

-    def check_if_adapters_added_correctly(
-        self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"
-    ):
+    def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
         if text_lora_config is not None:
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name)
@@ -345,7 +343,7 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) + pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules) self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}") @@ -1010,7 +1008,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - pipe, _ = self.check_if_adapters_added_correctly( + pipe, _ = self.add_adapters_to_pipeline( pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name ) @@ -1032,7 +1030,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - pipe, _ = self.check_if_adapters_added_correctly( + pipe, _ = self.add_adapters_to_pipeline( pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name ) @@ -1759,7 +1757,7 @@ class PeftLoraLoaderMixinTests: output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_dora_lora.shape == self.output_shape) - pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] @@ -1850,7 +1848,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True) @@ -1937,7 +1935,7 @@ class PeftLoraLoaderMixinTests: output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) lora_scale = 0.5 attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}} @@ -2119,7 +2117,7 @@ class PeftLoraLoaderMixinTests: pipe = pipe.to(torch_device, dtype=compute_dtype) pipe.set_progress_bar_config(disable=None) - pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) + pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) if storage_dtype is not None: denoiser.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) @@ -2237,7 +2235,7 @@ class PeftLoraLoaderMixinTests: ) pipe = self.pipeline_class(**components) - pipe, _ = self.check_if_adapters_added_correctly( + pipe, _ = self.add_adapters_to_pipeline( pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config ) @@ -2290,7 +2288,7 @@ class PeftLoraLoaderMixinTests: output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - pipe, _ = self.check_if_adapters_added_correctly( + pipe, _ = self.add_adapters_to_pipeline( pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config ) output_lora = pipe(**inputs, 
generator=torch.manual_seed(0))[0] @@ -2309,6 +2307,25 @@ class PeftLoraLoaderMixinTests: np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match." ) + def test_lora_unload_add_adapter(self): + """Tests if `unload_lora_weights()` -> `add_adapter()` works.""" + scheduler_cls = self.scheduler_classes[0] + components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) + pipe = self.pipeline_class(**components).to(torch_device) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe, _ = self.add_adapters_to_pipeline( + pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config + ) + _ = pipe(**inputs, generator=torch.manual_seed(0))[0] + + # unload and then add. + pipe.unload_lora_weights() + pipe, _ = self.add_adapters_to_pipeline( + pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config + ) + _ = pipe(**inputs, generator=torch.manual_seed(0))[0] + def test_inference_load_delete_load_adapters(self): "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works." for scheduler_cls in self.scheduler_classes:
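The `_hf_peft_config_loaded` reset in the `peft.py` hunk above is what the new test below exercises; the previously broken cycle, sketched (with `denoiser_lora_config` a `peft.LoraConfig` as elsewhere in these tests):

```python
pipe.transformer.add_adapter(denoiser_lora_config)  # first add works
pipe.unload_lora_weights()  # removes the PEFT layers and deletes peft_config
# without the reset, the stale _hf_peft_config_loaded flag could make a second
# add_adapter() fail; clearing it on unload lets the cycle run cleanly
pipe.transformer.add_adapter(denoiser_lora_config)
```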