From 05e7a854d0a5661f5b433f6dd5954c224b104f0b Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Sat, 28 Jun 2025 12:00:42 +0530
Subject: [PATCH 01/14] [lora] fix: lora unloading behaviour (#11822)

* fix: lora unloading behaviour
* fix
* update
---
 src/diffusers/loaders/peft.py |  2 ++
 tests/lora/utils.py           | 65 ++++++++++++++++++++++-------------
 2 files changed, 43 insertions(+), 24 deletions(-)

diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index 3670243de8..211f9da619 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -693,6 +693,8 @@ class PeftAdapterMixin:
         recurse_remove_peft_layers(self)
         if hasattr(self, "peft_config"):
             del self.peft_config
+        if hasattr(self, "_hf_peft_config_loaded"):
+            self._hf_peft_config_loaded = None

         _maybe_remove_and_reapply_group_offloading(self)

diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index acd6f5f343..8180f92245 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -291,9 +291,7 @@ class PeftLoraLoaderMixinTests:

         return modules_to_save

-    def check_if_adapters_added_correctly(
-        self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"
-    ):
+    def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
         if text_lora_config is not None:
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name)
@@ -345,7 +343,7 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
@@ -428,7 +426,7 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

@@ -484,7 +482,7 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
@@ -522,7 +520,7 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

         pipe.fuse_lora()
         # Fusing should still keep the LoRA layers
@@ -554,7 +552,7 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ =
self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) pipe.unload_lora_weights() # unloading should remove the LoRA layers @@ -589,7 +587,7 @@ class PeftLoraLoaderMixinTests: output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] @@ -640,7 +638,7 @@ class PeftLoraLoaderMixinTests: output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) state_dict = {} if "text_encoder" in self.pipeline_class._lora_loadable_modules: @@ -691,7 +689,7 @@ class PeftLoraLoaderMixinTests: output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] with tempfile.TemporaryDirectory() as tmpdirname: @@ -734,7 +732,7 @@ class PeftLoraLoaderMixinTests: output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] @@ -775,7 +773,7 @@ class PeftLoraLoaderMixinTests: output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue( @@ -819,7 +817,7 @@ class PeftLoraLoaderMixinTests: output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) + pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules) @@ -857,7 +855,7 @@ class PeftLoraLoaderMixinTests: output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) + pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) pipe.unload_lora_weights() # unloading should remove the LoRA layers @@ -893,7 +891,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) + 
pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules) self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}") @@ -1010,7 +1008,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - pipe, _ = self.check_if_adapters_added_correctly( + pipe, _ = self.add_adapters_to_pipeline( pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name ) @@ -1032,7 +1030,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - pipe, _ = self.check_if_adapters_added_correctly( + pipe, _ = self.add_adapters_to_pipeline( pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name ) @@ -1759,7 +1757,7 @@ class PeftLoraLoaderMixinTests: output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_dora_lora.shape == self.output_shape) - pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] @@ -1850,7 +1848,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True) @@ -1937,7 +1935,7 @@ class PeftLoraLoaderMixinTests: output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) lora_scale = 0.5 attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}} @@ -2119,7 +2117,7 @@ class PeftLoraLoaderMixinTests: pipe = pipe.to(torch_device, dtype=compute_dtype) pipe.set_progress_bar_config(disable=None) - pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config) + pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) if storage_dtype is not None: denoiser.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) @@ -2237,7 +2235,7 @@ class PeftLoraLoaderMixinTests: ) pipe = self.pipeline_class(**components) - pipe, _ = self.check_if_adapters_added_correctly( + pipe, _ = self.add_adapters_to_pipeline( pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config ) @@ -2290,7 +2288,7 @@ class PeftLoraLoaderMixinTests: output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - pipe, _ = self.check_if_adapters_added_correctly( + pipe, _ = self.add_adapters_to_pipeline( pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config ) output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] @@ -2309,6 +2307,25 @@ class PeftLoraLoaderMixinTests: 
            np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match."
         )

+    def test_lora_unload_add_adapter(self):
+        """Tests if `unload_lora_weights()` -> `add_adapter()` works."""
+        scheduler_cls = self.scheduler_classes[0]
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        pipe = self.pipeline_class(**components).to(torch_device)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        pipe, _ = self.add_adapters_to_pipeline(
+            pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+        )
+        _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        # unload and then add.
+        pipe.unload_lora_weights()
+        pipe, _ = self.add_adapters_to_pipeline(
+            pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+        )
+        _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
     def test_inference_load_delete_load_adapters(self):
         "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
         for scheduler_cls in self.scheduler_classes:

From bc34fa8386a3a63da69d5c92e7105532bb66faa3 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Mon, 30 Jun 2025 20:08:53 +0530
Subject: [PATCH 02/14] [lora]feat: use exclude modules to loraconfig. (#11806)

* feat: use exclude modules to loraconfig.
* version-guard.
* tests and version guard.
* remove print.
* describe the test
* more detailed warning message + shift to debug
* update
* update
* update
* remove test
---
 src/diffusers/loaders/peft.py      | 13 ++++--
 src/diffusers/utils/peft_utils.py  | 57 +++++++++++++++++++++----
 tests/lora/test_lora_layers_wan.py |  6 ++-
 tests/lora/utils.py                | 67 ++++++++++++++++++++++++++++++
 4 files changed, 131 insertions(+), 12 deletions(-)

diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index 211f9da619..4ade3374d8 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -244,13 +244,20 @@ class PeftAdapterMixin:
                 k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
             }

-            # create LoraConfig
-            lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank)
-
             # adapter_name
             if adapter_name is None:
                 adapter_name = get_adapter_name(self)

+            # create LoraConfig
+            lora_config = _create_lora_config(
+                state_dict,
+                network_alphas,
+                metadata,
+                rank,
+                model_state_dict=self.state_dict(),
+                adapter_name=adapter_name,
+            )
+
diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py
--- a/src/diffusers/utils/peft_utils.py
+++ b/src/diffusers/utils/peft_utils.py
@@ def get_peft_kwargs @@
+    exclude_modules = _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name)
+    if exclude_modules:
+        if not is_peft_version(">=", "0.14.0"):
+            msg = """
+It seems like there are certain modules that need to be excluded when initializing `LoraConfig`. Your current `peft`
+version doesn't support passing an `exclude_modules` to `LoraConfig`. Please update it by running `pip install -U
+peft`. For most cases, this can be completely ignored.
But if it seems unexpected, please file an issue - +https://github.com/huggingface/diffusers/issues/new + """ + logger.debug(msg) + else: + lora_config_kwargs.update({"exclude_modules": exclude_modules}) + return lora_config_kwargs @@ -294,11 +310,7 @@ def check_peft_version(min_version: str) -> None: def _create_lora_config( - state_dict, - network_alphas, - metadata, - rank_pattern_dict, - is_unet: bool = True, + state_dict, network_alphas, metadata, rank_pattern_dict, is_unet=True, model_state_dict=None, adapter_name=None ): from peft import LoraConfig @@ -306,7 +318,12 @@ def _create_lora_config( lora_config_kwargs = metadata else: lora_config_kwargs = get_peft_kwargs( - rank_pattern_dict, network_alpha_dict=network_alphas, peft_state_dict=state_dict, is_unet=is_unet + rank_pattern_dict, + network_alpha_dict=network_alphas, + peft_state_dict=state_dict, + is_unet=is_unet, + model_state_dict=model_state_dict, + adapter_name=adapter_name, ) _maybe_raise_error_for_ambiguous_keys(lora_config_kwargs) @@ -371,3 +388,27 @@ def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name): if warn_msg: logger.warning(warn_msg) + + +def _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name=None): + """ + Derives the modules to exclude while initializing `LoraConfig` through `exclude_modules`. It works by comparing the + `model_state_dict` and `peft_state_dict` and adds a module from `model_state_dict` to the exclusion set if it + doesn't exist in `peft_state_dict`. + """ + if model_state_dict is None: + return + all_modules = set() + string_to_replace = f"{adapter_name}." if adapter_name else "" + + for name in model_state_dict.keys(): + if string_to_replace: + name = name.replace(string_to_replace, "") + if "." in name: + module_name = name.rsplit(".", 1)[0] + all_modules.add(module_name) + + target_modules_set = {name.split(".lora")[0] for name in peft_state_dict.keys()} + exclude_modules = list(all_modules - target_modules_set) + + return exclude_modules diff --git a/tests/lora/test_lora_layers_wan.py b/tests/lora/test_lora_layers_wan.py index 95ec44b2bf..fe26a56e77 100644 --- a/tests/lora/test_lora_layers_wan.py +++ b/tests/lora/test_lora_layers_wan.py @@ -24,7 +24,11 @@ from diffusers import ( WanPipeline, WanTransformer3DModel, ) -from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps +from diffusers.utils.testing_utils import ( + floats_tensor, + require_peft_backend, + skip_mps, +) sys.path.append(".") diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 8180f92245..91ca188137 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import copy import inspect import os import re @@ -291,6 +292,20 @@ class PeftLoraLoaderMixinTests: return modules_to_save + def _get_exclude_modules(self, pipe): + from diffusers.utils.peft_utils import _derive_exclude_modules + + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + denoiser = "unet" if self.unet_kwargs is not None else "transformer" + modules_to_save = {k: v for k, v in modules_to_save.items() if k == denoiser} + denoiser_lora_state_dict = self._get_lora_state_dicts(modules_to_save)[f"{denoiser}_lora_layers"] + pipe.unload_lora_weights() + denoiser_state_dict = pipe.unet.state_dict() if self.unet_kwargs is not None else pipe.transformer.state_dict() + exclude_modules = _derive_exclude_modules( + denoiser_state_dict, denoiser_lora_state_dict, adapter_name="default" + ) + return exclude_modules + def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"): if text_lora_config is not None: if "text_encoder" in self.pipeline_class._lora_loadable_modules: @@ -2326,6 +2341,58 @@ class PeftLoraLoaderMixinTests: ) _ = pipe(**inputs, generator=torch.manual_seed(0))[0] + @require_peft_version_greater("0.13.2") + def test_lora_exclude_modules(self): + """ + Test to check if `exclude_modules` works or not. It works in the following way: + we first create a pipeline and insert LoRA config into it. We then derive a `set` + of modules to exclude by investigating its denoiser state dict and denoiser LoRA + state dict. + + We then create a new LoRA config to include the `exclude_modules` and perform tests. + """ + scheduler_cls = self.scheduler_classes[0] + components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) + pipe = self.pipeline_class(**components).to(torch_device) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertTrue(output_no_lora.shape == self.output_shape) + + # only supported for `denoiser` now + pipe_cp = copy.deepcopy(pipe) + pipe_cp, _ = self.add_adapters_to_pipeline( + pipe_cp, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config + ) + denoiser_exclude_modules = self._get_exclude_modules(pipe_cp) + pipe_cp.to("cpu") + del pipe_cp + + denoiser_lora_config.exclude_modules = denoiser_exclude_modules + pipe, _ = self.add_adapters_to_pipeline( + pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config + ) + output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0] + + with tempfile.TemporaryDirectory() as tmpdir: + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + lora_metadatas = self._get_lora_adapter_metadata(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas) + pipe.unload_lora_weights() + pipe.load_lora_weights(tmpdir) + + output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] + + self.assertTrue( + not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3), + "LoRA should change outputs.", + ) + self.assertTrue( + np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3), + "Lora outputs should match.", + ) + def test_inference_load_delete_load_adapters(self): "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works." 
for scheduler_cls in self.scheduler_classes: From 3b079ec3fadfc95240bc1c48ae86de28b72cc9f2 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Mon, 30 Jun 2025 16:55:56 +0200 Subject: [PATCH 03/14] ENH: Improve speed of function expanding LoRA scales (#11834) * ENH Improve speed of expanding LoRA scales Resolves #11816 The following call proved to be a bottleneck when setting a lot of LoRA adapters in diffusers: https://github.com/huggingface/diffusers/blob/cdaf84a708eadf17d731657f4be3fa39d09a12c0/src/diffusers/loaders/peft.py#L482 This is because we would repeatedly call unet.state_dict(), even though in the standard case, it is not necessary: https://github.com/huggingface/diffusers/blob/cdaf84a708eadf17d731657f4be3fa39d09a12c0/src/diffusers/loaders/unet_loader_utils.py#L55 This PR fixes this by deferring this call, so that it is only run when it's necessary, not earlier. * Small fix --------- Co-authored-by: Sayak Paul --- src/diffusers/loaders/unet_loader_utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/unet_loader_utils.py b/src/diffusers/loaders/unet_loader_utils.py index 274665204d..d5b0e83cbd 100644 --- a/src/diffusers/loaders/unet_loader_utils.py +++ b/src/diffusers/loaders/unet_loader_utils.py @@ -14,6 +14,8 @@ import copy from typing import TYPE_CHECKING, Dict, List, Union +from torch import nn + from ..utils import logging @@ -52,7 +54,7 @@ def _maybe_expand_lora_scales( weight_for_adapter, blocks_with_transformer, transformer_per_block, - unet.state_dict(), + model=unet, default_scale=default_scale, ) for weight_for_adapter in weight_scales @@ -65,7 +67,7 @@ def _maybe_expand_lora_scales_for_one_adapter( scales: Union[float, Dict], blocks_with_transformer: Dict[str, int], transformer_per_block: Dict[str, int], - state_dict: None, + model: nn.Module, default_scale: float = 1.0, ): """ @@ -154,6 +156,7 @@ def _maybe_expand_lora_scales_for_one_adapter( del scales[updown] + state_dict = model.state_dict() for layer in scales.keys(): if not any(_translate_into_actual_layer_name(layer) in module for module in state_dict.keys()): raise ValueError( From f064b3bf73e479051ed4255d98afad4259a6f012 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 1 Jul 2025 00:37:34 +0530 Subject: [PATCH 04/14] Remove print statement in SCM Scheduler (#11836) remove print --- src/diffusers/schedulers/scheduling_scm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_scm.py b/src/diffusers/schedulers/scheduling_scm.py index acff268c9b..63b4a109ff 100644 --- a/src/diffusers/schedulers/scheduling_scm.py +++ b/src/diffusers/schedulers/scheduling_scm.py @@ -168,7 +168,6 @@ class SCMScheduler(SchedulerMixin, ConfigMixin): else: # max_timesteps=arctan(80/0.5)=1.56454 is the default from sCM paper, we choose a different value here self.timesteps = torch.linspace(max_timesteps, 0, num_inference_steps + 1, device=device).float() - print(f"Set timesteps: {self.timesteps}") self._step_index = None self._begin_index = None From 87f83d3dd9247affcc0912175b2eff5f4a56e75a Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 1 Jul 2025 09:40:34 +0530 Subject: [PATCH 05/14] [tests] add test for hotswapping + compilation on resolution changes (#11825) * add resolution changes tests to hotswapping test suite. 
* fixes
* docs
* explain duck shapes
* fix
---
 .../en/tutorials/using_peft_for_inference.md |  2 +
 tests/models/test_modeling_common.py         | 46 ++++++++++++++++---
 .../test_models_transformer_flux.py          |  4 ++
 3 files changed, 46 insertions(+), 6 deletions(-)

diff --git a/docs/source/en/tutorials/using_peft_for_inference.md b/docs/source/en/tutorials/using_peft_for_inference.md
index b18977720c..5a382c1c94 100644
--- a/docs/source/en/tutorials/using_peft_for_inference.md
+++ b/docs/source/en/tutorials/using_peft_for_inference.md
@@ -315,6 +315,8 @@ pipeline.load_lora_weights(
 > [!TIP]
 > Move your code inside the `with torch._dynamo.config.patch(error_on_recompile=True)` context manager to detect if a model was recompiled. If a model is recompiled despite following all the steps above, please open an [issue](https://github.com/huggingface/diffusers/issues) with a reproducible example.

+If you expect varied resolutions during inference with this feature, make sure to set `dynamic=True` during compilation. Refer to [this document](../optimization/fp16#dynamic-shape-compilation) for more details.
+
 There are still scenarios where recompilation is unavoidable, such as when the hotswapped LoRA targets more layers than the initial adapter. Try to load the LoRA that targets the most layers *first*. For more details about this limitation, refer to the PEFT [hotswapping](https://huggingface.co/docs/peft/main/en/package_reference/hotswap#peft.utils.hotswap.hotswap_adapter) docs.

 ## Merge

diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index dcc7ae16a4..def81ecd64 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -1350,7 +1350,6 @@ class ModelTesterMixin:
             new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
             # Making sure part of the model will actually end up offloaded
             self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
-            print(f" new_model.hf_device_map:{new_model.hf_device_map}")

             self.check_device_map_is_respected(new_model, new_model.hf_device_map)

@@ -2019,6 +2018,8 @@ class LoraHotSwappingForModelTesterMixin:

     """

+    different_shapes_for_compilation = None
+
     def tearDown(self):
         # It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
         # there will be recompilation errors, as torch caches the model when run in the same process.
@@ -2056,11 +2057,13 @@ class LoraHotSwappingForModelTesterMixin:
         - hotswap the second adapter
         - check that the outputs are correct
         - optionally compile the model
+        - optionally check if recompilations happen on different shapes

         Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would
         fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is fine.
""" + different_shapes = self.different_shapes_for_compilation # create 2 adapters with different ranks and alphas torch.manual_seed(0) init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -2110,19 +2113,30 @@ class LoraHotSwappingForModelTesterMixin: model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None) if do_compile: - model = torch.compile(model, mode="reduce-overhead") + model = torch.compile(model, mode="reduce-overhead", dynamic=different_shapes is not None) with torch.inference_mode(): - output0_after = model(**inputs_dict)["sample"] - assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol) + # additionally check if dynamic compilation works. + if different_shapes is not None: + for height, width in different_shapes: + new_inputs_dict = self.prepare_dummy_input(height=height, width=width) + _ = model(**new_inputs_dict) + else: + output0_after = model(**inputs_dict)["sample"] + assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol) # hotswap the 2nd adapter model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None) # we need to call forward to potentially trigger recompilation with torch.inference_mode(): - output1_after = model(**inputs_dict)["sample"] - assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol) + if different_shapes is not None: + for height, width in different_shapes: + new_inputs_dict = self.prepare_dummy_input(height=height, width=width) + _ = model(**new_inputs_dict) + else: + output1_after = model(**inputs_dict)["sample"] + assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol) # check error when not passing valid adapter name name = "does-not-exist" @@ -2240,3 +2254,23 @@ class LoraHotSwappingForModelTesterMixin: do_compile=True, rank0=8, rank1=8, target_modules0=target_modules0, target_modules1=target_modules1 ) assert any("Hotswapping adapter0 was unsuccessful" in log for log in cm.output) + + @parameterized.expand([(11, 11), (7, 13), (13, 7)]) + @require_torch_version_greater("2.7.1") + def test_hotswapping_compile_on_different_shapes(self, rank0, rank1): + different_shapes_for_compilation = self.different_shapes_for_compilation + if different_shapes_for_compilation is None: + pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.") + # Specifying `use_duck_shape=False` instructs the compiler if it should use the same symbolic + # variable to represent input sizes that are the same. For more details, + # check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790). 
+ torch.fx.experimental._config.use_duck_shape = False + + target_modules = ["to_q", "to_k", "to_v", "to_out.0"] + with torch._dynamo.config.patch(error_on_recompile=True): + self.check_model_hotswap( + do_compile=True, + rank0=rank0, + rank1=rank1, + target_modules0=target_modules, + ) diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index 4552b2e1f5..68b5c02bc0 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -186,6 +186,10 @@ class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase): model_class = FluxTransformer2DModel + different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)] def prepare_init_args_and_inputs_for_common(self): return FluxTransformerTests().prepare_init_args_and_inputs_for_common() + + def prepare_dummy_input(self, height, width): + return FluxTransformerTests().prepare_dummy_input(height=height, width=width) From f3e131046983d2e025144e8a8ac7dfc93f1249eb Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 1 Jul 2025 12:36:54 +0800 Subject: [PATCH 06/14] reset deterministic in tearDownClass (#11785) * reset deterministic in tearDownClass Signed-off-by: jiqing-feng * fix deterministic setting Signed-off-by: jiqing-feng --------- Signed-off-by: jiqing-feng Co-authored-by: Sayak Paul --- tests/quantization/bnb/test_4bit.py | 9 ++++++++- tests/quantization/bnb/test_mixed_int8.py | 9 ++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index c5497d1c8d..06116cac3a 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -98,7 +98,14 @@ class Base4bitTests(unittest.TestCase): @classmethod def setUpClass(cls): - torch.use_deterministic_algorithms(True) + cls.is_deterministic_enabled = torch.are_deterministic_algorithms_enabled() + if not cls.is_deterministic_enabled: + torch.use_deterministic_algorithms(True) + + @classmethod + def tearDownClass(cls): + if not cls.is_deterministic_enabled: + torch.use_deterministic_algorithms(False) def get_dummy_inputs(self): prompt_embeds = load_pt( diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py index 383cdd6849..2ea4cdfde8 100644 --- a/tests/quantization/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -99,7 +99,14 @@ class Base8bitTests(unittest.TestCase): @classmethod def setUpClass(cls): - torch.use_deterministic_algorithms(True) + cls.is_deterministic_enabled = torch.are_deterministic_algorithms_enabled() + if not cls.is_deterministic_enabled: + torch.use_deterministic_algorithms(True) + + @classmethod + def tearDownClass(cls): + if not cls.is_deterministic_enabled: + torch.use_deterministic_algorithms(False) def get_dummy_inputs(self): prompt_embeds = load_pt( From 3f3f0c16a6418c7c5505c0a33088fddb5bc90317 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 1 Jul 2025 11:13:58 +0530 Subject: [PATCH 07/14] [tests] Fix failing float16 cuda tests (#11835) * update * update --------- Co-authored-by: Sayak Paul --- tests/pipelines/test_pipelines_common.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 69dd79bb56..f87778b260 100644 --- 
a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -1378,7 +1378,6 @@ class PipelineTesterMixin: for component in pipe_fp16.components.values(): if hasattr(component, "set_default_attn_processor"): component.set_default_attn_processor() - pipe_fp16.to(torch_device, torch.float16) pipe_fp16.set_progress_bar_config(disable=None) @@ -1386,17 +1385,20 @@ class PipelineTesterMixin: # Reset generator in case it is used inside dummy inputs if "generator" in inputs: inputs["generator"] = self.get_generator(0) - output = pipe(**inputs)[0] fp16_inputs = self.get_dummy_inputs(torch_device) # Reset generator in case it is used inside dummy inputs if "generator" in fp16_inputs: fp16_inputs["generator"] = self.get_generator(0) - output_fp16 = pipe_fp16(**fp16_inputs)[0] + + if isinstance(output, torch.Tensor): + output = output.cpu() + output_fp16 = output_fp16.cpu() + max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten()) - assert max_diff < 1e-2 + assert max_diff < expected_max_diff @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU") @require_accelerator From a79c3af6bbda8ba1ca5aa4e7855708fcc9b02238 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 1 Jul 2025 18:02:58 +0530 Subject: [PATCH 08/14] [single file] Cosmos (#11801) * update * update * update docs --- docs/source/en/api/pipelines/cosmos.md | 25 +++ scripts/convert_cosmos_to_diffusers.py | 1 - src/diffusers/loaders/single_file_model.py | 5 + src/diffusers/loaders/single_file_utils.py | 152 ++++++++++++++++++ .../models/transformers/transformer_cosmos.py | 3 +- 5 files changed, 184 insertions(+), 2 deletions(-) diff --git a/docs/source/en/api/pipelines/cosmos.md b/docs/source/en/api/pipelines/cosmos.md index 99deef37e1..dba807c5ce 100644 --- a/docs/source/en/api/pipelines/cosmos.md +++ b/docs/source/en/api/pipelines/cosmos.md @@ -24,6 +24,31 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) +## Loading original format checkpoints + +Original format checkpoints that have not been converted to diffusers-expected format can be loaded using the `from_single_file` method. + +```python +import torch +from diffusers import Cosmos2TextToImagePipeline, CosmosTransformer3DModel + +model_id = "nvidia/Cosmos-Predict2-2B-Text2Image" +transformer = CosmosTransformer3DModel.from_single_file( + "https://huggingface.co/nvidia/Cosmos-Predict2-2B-Text2Image/blob/main/model.pt", + torch_dtype=torch.bfloat16, +).to("cuda") +pipe = Cosmos2TextToImagePipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16) +pipe.to("cuda") + +prompt = "A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess." 
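+# the long negative prompt enumerates common failure modes (static frames, motion blur, banding, flicker) for generation to steer away from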
+negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality." + +output = pipe( + prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(1) +).images[0] +output.save("output.png") +``` + ## CosmosTextToWorldPipeline [[autodoc]] CosmosTextToWorldPipeline diff --git a/scripts/convert_cosmos_to_diffusers.py b/scripts/convert_cosmos_to_diffusers.py index 0c0426a1ef..6f6563ad64 100644 --- a/scripts/convert_cosmos_to_diffusers.py +++ b/scripts/convert_cosmos_to_diffusers.py @@ -95,7 +95,6 @@ TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 = { "mlp.layer1": "ff.net.0.proj", "mlp.layer2": "ff.net.2", "x_embedder.proj.1": "patch_embed.proj", - # "extra_pos_embedder": "learnable_pos_embed", "final_layer.adaln_modulation.1": "norm_out.linear_1", "final_layer.adaln_modulation.2": "norm_out.linear_2", "final_layer.linear": "proj_out", diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index 0c6f3cda66..2e99afbd51 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -31,6 +31,7 @@ from .single_file_utils import ( convert_autoencoder_dc_checkpoint_to_diffusers, convert_chroma_transformer_checkpoint_to_diffusers, convert_controlnet_checkpoint, + convert_cosmos_transformer_checkpoint_to_diffusers, convert_flux_transformer_checkpoint_to_diffusers, convert_hidream_transformer_to_diffusers, convert_hunyuan_video_transformer_to_diffusers, @@ -143,6 +144,10 @@ SINGLE_FILE_LOADABLE_CLASSES = { "checkpoint_mapping_fn": convert_hidream_transformer_to_diffusers, "default_subfolder": "transformer", }, + "CosmosTransformer3DModel": { + "checkpoint_mapping_fn": convert_cosmos_transformer_checkpoint_to_diffusers, + "default_subfolder": "transformer", + }, } diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index d8d183304e..3f81243693 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -127,6 +127,16 @@ CHECKPOINT_KEY_NAMES = { "wan": ["model.diffusion_model.head.modulation", "head.modulation"], "wan_vae": "decoder.middle.0.residual.0.gamma", "hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias", + "cosmos-1.0": [ + "net.x_embedder.proj.1.weight", + "net.blocks.block1.blocks.0.block.attn.to_q.0.weight", + "net.extra_pos_embedder.pos_emb_h", + ], + "cosmos-2.0": [ + "net.x_embedder.proj.1.weight", + "net.blocks.0.self_attn.q_proj.weight", + "net.pos_embedder.dim_spatial_range", + ], } DIFFUSERS_DEFAULT_PIPELINE_PATHS = { @@ -193,6 +203,14 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = { "wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"}, "wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"}, "hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"}, + "cosmos-1.0-t2w-7B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-7B-Text2World"}, + "cosmos-1.0-t2w-14B": {"pretrained_model_name_or_path": 
"nvidia/Cosmos-1.0-Diffusion-14B-Text2World"}, + "cosmos-1.0-v2w-7B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-7B-Video2World"}, + "cosmos-1.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-14B-Video2World"}, + "cosmos-2.0-t2i-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Text2Image"}, + "cosmos-2.0-t2i-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Text2Image"}, + "cosmos-2.0-v2w-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Video2World"}, + "cosmos-2.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Video2World"}, } # Use to configure model sample size when original config is provided @@ -704,11 +722,32 @@ def infer_diffusers_model_type(checkpoint): model_type = "wan-t2v-14B" else: model_type = "wan-i2v-14B" + elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint: # All Wan models use the same VAE so we can use the same default model repo to fetch the config model_type = "wan-t2v-14B" + elif CHECKPOINT_KEY_NAMES["hidream"] in checkpoint: model_type = "hidream" + + elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["cosmos-1.0"]): + x_embedder_shape = checkpoint[CHECKPOINT_KEY_NAMES["cosmos-1.0"][0]].shape + if x_embedder_shape[1] == 68: + model_type = "cosmos-1.0-t2w-7B" if x_embedder_shape[0] == 4096 else "cosmos-1.0-t2w-14B" + elif x_embedder_shape[1] == 72: + model_type = "cosmos-1.0-v2w-7B" if x_embedder_shape[0] == 4096 else "cosmos-1.0-v2w-14B" + else: + raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 1.0 model.") + + elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["cosmos-2.0"]): + x_embedder_shape = checkpoint[CHECKPOINT_KEY_NAMES["cosmos-2.0"][0]].shape + if x_embedder_shape[1] == 68: + model_type = "cosmos-2.0-t2i-2B" if x_embedder_shape[0] == 2048 else "cosmos-2.0-t2i-14B" + elif x_embedder_shape[1] == 72: + model_type = "cosmos-2.0-v2w-2B" if x_embedder_shape[0] == 2048 else "cosmos-2.0-v2w-14B" + else: + raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 2.0 model.") + else: model_type = "v1" @@ -3479,3 +3518,116 @@ def convert_chroma_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias") return converted_state_dict + + +def convert_cosmos_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())} + + def remove_keys_(key: str, state_dict): + state_dict.pop(key) + + def rename_transformer_blocks_(key: str, state_dict): + block_index = int(key.split(".")[1].removeprefix("block")) + new_key = key + old_prefix = f"blocks.block{block_index}" + new_prefix = f"transformer_blocks.{block_index}" + new_key = new_prefix + new_key.removeprefix(old_prefix) + state_dict[new_key] = state_dict.pop(key) + + TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0 = { + "t_embedder.1": "time_embed.t_embedder", + "affline_norm": "time_embed.norm", + ".blocks.0.block.attn": ".attn1", + ".blocks.1.block.attn": ".attn2", + ".blocks.2.block": ".ff", + ".blocks.0.adaLN_modulation.1": ".norm1.linear_1", + ".blocks.0.adaLN_modulation.2": ".norm1.linear_2", + ".blocks.1.adaLN_modulation.1": ".norm2.linear_1", + ".blocks.1.adaLN_modulation.2": ".norm2.linear_2", + ".blocks.2.adaLN_modulation.1": ".norm3.linear_1", + ".blocks.2.adaLN_modulation.2": ".norm3.linear_2", + "to_q.0": "to_q", + "to_q.1": "norm_q", + "to_k.0": "to_k", + 
"to_k.1": "norm_k", + "to_v.0": "to_v", + "layer1": "net.0.proj", + "layer2": "net.2", + "proj.1": "proj", + "x_embedder": "patch_embed", + "extra_pos_embedder": "learnable_pos_embed", + "final_layer.adaLN_modulation.1": "norm_out.linear_1", + "final_layer.adaLN_modulation.2": "norm_out.linear_2", + "final_layer.linear": "proj_out", + } + + TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0 = { + "blocks.block": rename_transformer_blocks_, + "logvar.0.freqs": remove_keys_, + "logvar.0.phases": remove_keys_, + "logvar.1.weight": remove_keys_, + "pos_embedder.seq": remove_keys_, + } + + TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 = { + "t_embedder.1": "time_embed.t_embedder", + "t_embedding_norm": "time_embed.norm", + "blocks": "transformer_blocks", + "adaln_modulation_self_attn.1": "norm1.linear_1", + "adaln_modulation_self_attn.2": "norm1.linear_2", + "adaln_modulation_cross_attn.1": "norm2.linear_1", + "adaln_modulation_cross_attn.2": "norm2.linear_2", + "adaln_modulation_mlp.1": "norm3.linear_1", + "adaln_modulation_mlp.2": "norm3.linear_2", + "self_attn": "attn1", + "cross_attn": "attn2", + "q_proj": "to_q", + "k_proj": "to_k", + "v_proj": "to_v", + "output_proj": "to_out.0", + "q_norm": "norm_q", + "k_norm": "norm_k", + "mlp.layer1": "ff.net.0.proj", + "mlp.layer2": "ff.net.2", + "x_embedder.proj.1": "patch_embed.proj", + "final_layer.adaln_modulation.1": "norm_out.linear_1", + "final_layer.adaln_modulation.2": "norm_out.linear_2", + "final_layer.linear": "proj_out", + } + + TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0 = { + "accum_video_sample_counter": remove_keys_, + "accum_image_sample_counter": remove_keys_, + "accum_iteration": remove_keys_, + "accum_train_in_hours": remove_keys_, + "pos_embedder.seq": remove_keys_, + "pos_embedder.dim_spatial_range": remove_keys_, + "pos_embedder.dim_temporal_range": remove_keys_, + "_extra_state": remove_keys_, + } + + PREFIX_KEY = "net." 
+ if "net.blocks.block1.blocks.0.block.attn.to_q.0.weight" in checkpoint: + TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0 + TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0 + else: + TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 + TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0 + + state_dict_keys = list(converted_state_dict.keys()) + for key in state_dict_keys: + new_key = key[:] + if new_key.startswith(PREFIX_KEY): + new_key = new_key.removeprefix(PREFIX_KEY) + for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + converted_state_dict[new_key] = converted_state_dict.pop(key) + + state_dict_keys = list(converted_state_dict.keys()) + for key in state_dict_keys: + for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, converted_state_dict) + + return converted_state_dict diff --git a/src/diffusers/models/transformers/transformer_cosmos.py b/src/diffusers/models/transformers/transformer_cosmos.py index 6c312b7a5a..3a6cb1ce6e 100644 --- a/src/diffusers/models/transformers/transformer_cosmos.py +++ b/src/diffusers/models/transformers/transformer_cosmos.py @@ -20,6 +20,7 @@ import torch.nn as nn import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin from ...utils import is_torchvision_available from ..attention import FeedForward from ..attention_processor import Attention @@ -377,7 +378,7 @@ class CosmosLearnablePositionalEmbed(nn.Module): return (emb / norm).type_as(hidden_states) -class CosmosTransformer3DModel(ModelMixin, ConfigMixin): +class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): r""" A Transformer model for video-like data used in [Cosmos](https://github.com/NVIDIA/Cosmos). From 470458623e8a9fd0d546a2e15808443b45fe89e4 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 1 Jul 2025 21:23:27 +0530 Subject: [PATCH 09/14] [docs] fix single_file example. (#11847) fix single_file example. 
--- docs/source/en/api/pipelines/wan.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/en/api/pipelines/wan.md b/docs/source/en/api/pipelines/wan.md index 18b8207e3b..81cd242151 100644 --- a/docs/source/en/api/pipelines/wan.md +++ b/docs/source/en/api/pipelines/wan.md @@ -302,12 +302,12 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip ```py # pip install ftfy import torch - from diffusers import WanPipeline, AutoModel + from diffusers import WanPipeline, WanTransformer3DModel, AutoencoderKLWan - vae = AutoModel.from_single_file( + vae = AutoencoderKLWan.from_single_file( "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/vae/wan_2.1_vae.safetensors" ) - transformer = AutoModel.from_single_file( + transformer = WanTransformer3DModel.from_single_file( "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors", torch_dtype=torch.bfloat16 ) From 62e847db5ff99a3319ae2f8f84184709316ba01f Mon Sep 17 00:00:00 2001 From: Mikko Tukiainen Date: Wed, 2 Jul 2025 02:57:19 +0300 Subject: [PATCH 10/14] Use real-valued instead of complex tensors in Wan2.1 RoPE (#11649) * use real instead of complex tensors in Wan2.1 RoPE * remove the redundant type conversion * unpack rotary_emb * register rotary embedding frequencies as non-persistent buffers * Apply style fixes --------- Co-authored-by: Aryan Co-authored-by: github-actions[bot] --- .../models/transformers/transformer_wan.py | 84 ++++++++++++------- 1 file changed, 56 insertions(+), 28 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index 0ae7f2c00d..5fb71b69f7 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -71,14 +71,22 @@ class WanAttnProcessor2_0: if rotary_emb is not None: - def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): - dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64 - x_rotated = torch.view_as_complex(hidden_states.to(dtype).unflatten(3, (-1, 2))) - x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4) - return x_out.type_as(hidden_states) + def apply_rotary_emb( + hidden_states: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + ): + x = hidden_states.view(*hidden_states.shape[:-1], -1, 2) + x1, x2 = x[..., 0], x[..., 1] + cos = freqs_cos[..., 0::2] + sin = freqs_sin[..., 1::2] + out = torch.empty_like(hidden_states) + out[..., 0::2] = x1 * cos - x2 * sin + out[..., 1::2] = x1 * sin + x2 * cos + return out.type_as(hidden_states) - query = apply_rotary_emb(query, rotary_emb) - key = apply_rotary_emb(key, rotary_emb) + query = apply_rotary_emb(query, *rotary_emb) + key = apply_rotary_emb(key, *rotary_emb) # I2V task hidden_states_img = None @@ -179,7 +187,11 @@ class WanTimeTextImageEmbedding(nn.Module): class WanRotaryPosEmbed(nn.Module): def __init__( - self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0 + self, + attention_head_dim: int, + patch_size: Tuple[int, int, int], + max_seq_len: int, + theta: float = 10000.0, ): super().__init__() @@ -189,36 +201,52 @@ class WanRotaryPosEmbed(nn.Module): h_dim = w_dim = 2 * (attention_head_dim // 6) t_dim = attention_head_dim - h_dim - w_dim - - freqs = [] freqs_dtype = torch.float32 if torch.backends.mps.is_available() else 
torch.float64 + + freqs_cos = [] + freqs_sin = [] + for dim in [t_dim, h_dim, w_dim]: - freq = get_1d_rotary_pos_embed( - dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=freqs_dtype + freq_cos, freq_sin = get_1d_rotary_pos_embed( + dim, + max_seq_len, + theta, + use_real=True, + repeat_interleave_real=True, + freqs_dtype=freqs_dtype, ) - freqs.append(freq) - self.freqs = torch.cat(freqs, dim=1) + freqs_cos.append(freq_cos) + freqs_sin.append(freq_sin) + + self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False) + self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, num_channels, num_frames, height, width = hidden_states.shape p_t, p_h, p_w = self.patch_size ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w - freqs = self.freqs.to(hidden_states.device) - freqs = freqs.split_with_sizes( - [ - self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6), - self.attention_head_dim // 6, - self.attention_head_dim // 6, - ], - dim=1, - ) + split_sizes = [ + self.attention_head_dim - 2 * (self.attention_head_dim // 3), + self.attention_head_dim // 3, + self.attention_head_dim // 3, + ] - freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) - freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1) - freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) - freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1) - return freqs + freqs_cos = self.freqs_cos.split(split_sizes, dim=1) + freqs_sin = self.freqs_sin.split(split_sizes, dim=1) + + freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) + freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1) + freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) + + freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) + freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1) + freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) + + freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1) + freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1) + + return freqs_cos, freqs_sin class WanTransformerBlock(nn.Module): From d31b8cea3e2cf15154255364b1ee9c544c4ae371 Mon Sep 17 00:00:00 2001 From: Steven Liu <59462357+stevhliu@users.noreply.github.com> Date: Tue, 1 Jul 2025 17:00:20 -0700 Subject: [PATCH 11/14] [docs] Batch generation (#11841) * draft * fix * fix * feedback * feedback --- docs/source/en/_toctree.yml | 2 + .../en/using-diffusers/batched_inference.md | 264 ++++++++++++++++++ .../en/using-diffusers/reusing_seeds.md | 50 ---- 3 files changed, 266 insertions(+), 50 deletions(-) create mode 100644 docs/source/en/using-diffusers/batched_inference.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 283efeef72..770093438e 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -64,6 +64,8 @@ title: Overview - local: using-diffusers/create_a_server title: Create a server + - local: using-diffusers/batched_inference + title: Batch inference - local: training/distributed_inference title: Distributed inference - local: using-diffusers/scheduler_features diff --git 
a/docs/source/en/using-diffusers/batched_inference.md b/docs/source/en/using-diffusers/batched_inference.md
new file mode 100644
index 0000000000..b5e55c27ca
--- /dev/null
+++ b/docs/source/en/using-diffusers/batched_inference.md
@@ -0,0 +1,264 @@
+
+
+# Batch inference
+
+Batch inference processes multiple prompts at a time to increase throughput. It is more efficient because processing multiple prompts at once maximizes GPU usage versus processing a single prompt and underutilizing the GPU.
+
+The downside is increased latency because you must wait for the entire batch to complete, and more GPU memory is required for large batches.
+
+
+
+
+
+For text-to-image, pass a list of prompts to the pipeline.
+
+```py
+import torch
+import matplotlib.pyplot as plt
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    torch_dtype=torch.float16
+).to("cuda")
+
+prompts = [
+    "cinematic photo of A beautiful sunset over mountains, 35mm photograph, film, professional, 4k, highly detailed",
+    "cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain",
+    "pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics"
+]
+
+images = pipeline(
+    prompt=prompts,
+).images
+
+fig, axes = plt.subplots(2, 2, figsize=(12, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+    axes[i].imshow(image)
+    axes[i].set_title(f"Image {i+1}")
+    axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+To generate multiple variations of one prompt, use the `num_images_per_prompt` argument.
+
+```py
+import torch
+import matplotlib.pyplot as plt
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    torch_dtype=torch.float16
+).to("cuda")
+
+images = pipeline(
+    prompt="pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics",
+    num_images_per_prompt=4
+).images
+
+fig, axes = plt.subplots(2, 2, figsize=(12, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+    axes[i].imshow(image)
+    axes[i].set_title(f"Image {i+1}")
+    axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+Combine both approaches to generate different variations of different prompts.
+
+```py
+images = pipeline(
+    prompt=prompts,
+    num_images_per_prompt=2,
+).images
+
+fig, axes = plt.subplots(2, 3, figsize=(18, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+    axes[i].imshow(image)
+    axes[i].set_title(f"Image {i+1}")
+    axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+
+
+
+
+For image-to-image, pass a list of input images and prompts to the pipeline.
+
+```py
+import torch
+import matplotlib.pyplot as plt
+from diffusers.utils import load_image
+from diffusers import AutoPipelineForImage2Image
+
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    torch_dtype=torch.float16
+).to("cuda")
+
+input_images = [
+    load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png"),
+    load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"),
+    load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/detail-prompt.png")
+]
+
+prompts = [
+    "cinematic photo of a beautiful sunset over mountains, 35mm photograph, film, professional, 4k, highly detailed",
+    "cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain",
+    "pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics"
+]
+
+images = pipeline(
+    prompt=prompts,
+    image=input_images,
+    guidance_scale=8.0,
+    strength=0.5
+).images
+
+fig, axes = plt.subplots(2, 2, figsize=(12, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+    axes[i].imshow(image)
+    axes[i].set_title(f"Image {i+1}")
+    axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+To generate multiple variations of one prompt, use the `num_images_per_prompt` argument.
+
+```py
+import torch
+import matplotlib.pyplot as plt
+from diffusers.utils import load_image
+from diffusers import AutoPipelineForImage2Image
+
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    torch_dtype=torch.float16
+).to("cuda")
+
+input_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/detail-prompt.png")
+
+images = pipeline(
+    prompt="pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics",
+    image=input_image,
+    num_images_per_prompt=4
+).images
+
+fig, axes = plt.subplots(2, 2, figsize=(12, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+    axes[i].imshow(image)
+    axes[i].set_title(f"Image {i+1}")
+    axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+Combine both approaches to generate different variations of different prompts.
+
+```py
+input_images = [
+    load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"),
+    load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/detail-prompt.png")
+]
+
+prompts = [
+    "cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain",
+    "pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics"
+]
+
+images = pipeline(
+    prompt=prompts,
+    image=input_images,
+    num_images_per_prompt=2,
+).images
+
+fig, axes = plt.subplots(2, 2, figsize=(12, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+    axes[i].imshow(image)
+    axes[i].set_title(f"Image {i+1}")
+    axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+
+
+
+
+## Deterministic generation
+
+Enable reproducible batch generation by passing a list of [Generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) objects to the pipeline and tying each `Generator` to a seed so you can reuse it.
+</hfoption>
+</hfoptions>
+
+## Deterministic generation
+
+Enable reproducible batch generation by passing a list of [Generators](https://pytorch.org/docs/stable/generated/torch.Generator.html) to the pipeline and tying each `Generator` to a seed so you can reuse it.
+
+Use a list comprehension that iterates over the batch size specified in `range()` to create a unique `Generator` object for each image in the batch.
+
+Don't multiply a single `Generator` by the batch size, because that creates only *one* `Generator` object that is reused sequentially for every image in the batch.
+
+```py
+# this creates one Generator shared across the whole batch, not three
+generator = [torch.Generator(device="cuda").manual_seed(0)] * 3
+```
+
+Pass the `generator` list to the pipeline.
+
+```py
+import torch
+import matplotlib.pyplot as plt
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    torch_dtype=torch.float16
+).to("cuda")
+
+generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(3)]
+prompts = [
+    "cinematic photo of a beautiful sunset over mountains, 35mm photograph, film, professional, 4k, highly detailed",
+    "cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain",
+    "pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics"
+]
+
+images = pipeline(
+    prompt=prompts,
+    generator=generator
+).images
+
+fig, axes = plt.subplots(2, 2, figsize=(12, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+    axes[i].imshow(image)
+    axes[i].set_title(f"Image {i+1}")
+    axes[i].axis('off')
+
+# hide the unused fourth subplot
+for ax in axes[len(images):]:
+    ax.axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+You can use this to iteratively select an image associated with a seed and then improve on it by crafting a more detailed prompt.
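+For example, a minimal sketch, assuming the result from seed `1` is the one you want to refine (the added prompt details are illustrative):
+
+```py
+# reuse the winning seed and enrich only its prompt
+generator = torch.Generator(device="cuda").manual_seed(1)
+image = pipeline(
+    prompt=prompts[1] + ", sharp focus, warm golden light",
+    generator=generator
+).images[0]
+```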
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/reusing_seeds.md b/docs/source/en/using-diffusers/reusing_seeds.md
index 60b8fee754..ac9350f24c 100644
--- a/docs/source/en/using-diffusers/reusing_seeds.md
+++ b/docs/source/en/using-diffusers/reusing_seeds.md
@@ -136,53 +136,3 @@ result2 = pipe(prompt=prompt, num_inference_steps=50, generator=g, output_type="
 print("L_inf dist =", abs(result1 - result2).max())
 "L_inf dist = tensor(0., device='cuda:0')"
 ```
-
-## Deterministic batch generation
-
-A practical application of creating reproducible pipelines is *deterministic batch generation*. You generate a batch of images and select one image to improve with a more detailed prompt. The main idea is to pass a list of [Generator's](https://pytorch.org/docs/stable/generated/torch.Generator.html) to the pipeline and tie each `Generator` to a seed so you can reuse it.
-
-Let's use the [stable-diffusion-v1-5/stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) checkpoint and generate a batch of images.
-
-```py
-import torch
-from diffusers import DiffusionPipeline
-from diffusers.utils import make_image_grid
-
-pipeline = DiffusionPipeline.from_pretrained(
-    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
-)
-pipeline = pipeline.to("cuda")
-```
-
-Define four different `Generator`s and assign each `Generator` a seed (`0` to `3`). Then generate a batch of images and pick one to iterate on.
-
-> [!WARNING]
-> Use a list comprehension that iterates over the batch size specified in `range()` to create a unique `Generator` object for each image in the batch. If you multiply the `Generator` by the batch size integer, it only creates *one* `Generator` object that is used sequentially for each image in the batch.
->
-> ```py
-> [torch.Generator().manual_seed(seed)] * 4
-> ```
-
-```python
-generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(4)]
-prompt = "Labrador in the style of Vermeer"
-images = pipeline(prompt, generator=generator, num_images_per_prompt=4).images[0]
-make_image_grid(images, rows=2, cols=2)
-```
-
-
-
-
-Let's improve the first image (you can choose any image you want) which corresponds to the `Generator` with seed `0`. Add some additional text to your prompt and then make sure you reuse the same `Generator` with seed `0`. All the generated images should resemble the first image.
-
-```python
-prompt = [prompt + t for t in [", highly realistic", ", artsy", ", trending", ", colorful"]]
-generator = [torch.Generator(device="cuda").manual_seed(0) for i in range(4)]
-images = pipeline(prompt, generator=generator).images
-make_image_grid(images, rows=2, cols=2)
-```
-
-
-
From 64a9210315459b8217259792f673106ff0053c13 Mon Sep 17 00:00:00 2001 From: Steven Liu <59462357+stevhliu@users.noreply.github.com> Date: Tue, 1 Jul 2025 17:02:54 -0700 Subject: [PATCH 12/14] [docs] Deprecated pipelines (#11838) add warning Co-authored-by: Sayak Paul --- docs/source/en/api/pipelines/amused.md | 3 +++ docs/source/en/api/pipelines/attend_and_excite.md | 3 +++ docs/source/en/api/pipelines/audioldm.md | 3 +++ docs/source/en/api/pipelines/blip_diffusion.md | 3 +++ docs/source/en/api/pipelines/controlnetxs.md | 3 +++ docs/source/en/api/pipelines/controlnetxs_sdxl.md | 3 +++ docs/source/en/api/pipelines/dance_diffusion.md | 3 +++ docs/source/en/api/pipelines/diffedit.md | 3 +++ docs/source/en/api/pipelines/i2vgenxl.md | 3 +++ docs/source/en/api/pipelines/musicldm.md | 3 +++ docs/source/en/api/pipelines/paint_by_example.md | 3 +++ docs/source/en/api/pipelines/panorama.md | 3 +++ docs/source/en/api/pipelines/pia.md | 3 +++ docs/source/en/api/pipelines/self_attention_guidance.md | 3 +++ docs/source/en/api/pipelines/semantic_stable_diffusion.md | 3 +++ docs/source/en/api/pipelines/stable_diffusion/gligen.md | 3 +++ .../en/api/pipelines/stable_diffusion/k_diffusion.md | 3 +++ .../en/api/pipelines/stable_diffusion/ldm3d_diffusion.md | 3 +++ .../pipelines/stable_diffusion/stable_diffusion_safe.md | 3 +++ docs/source/en/api/pipelines/text_to_video.md | 7 ++----- docs/source/en/api/pipelines/text_to_video_zero.md | 3 +++ docs/source/en/api/pipelines/unclip.md | 3 +++ docs/source/en/api/pipelines/unidiffuser.md | 3 +++ docs/source/en/api/pipelines/wuerstchen.md | 3 +++ 24 files changed, 71 insertions(+), 5 deletions(-) diff --git a/docs/source/en/api/pipelines/amused.md b/docs/source/en/api/pipelines/amused.md index eb78c8b704..ad292abca2 100644 --- a/docs/source/en/api/pipelines/amused.md +++ b/docs/source/en/api/pipelines/amused.md @@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> +> [!WARNING] +> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model. + # aMUSEd aMUSEd was introduced in [aMUSEd: An Open MUSE Reproduction](https://huggingface.co/papers/2401.01808) by Suraj Patil, William Berman, Robin Rombach, and Patrick von Platen. diff --git a/docs/source/en/api/pipelines/attend_and_excite.md b/docs/source/en/api/pipelines/attend_and_excite.md index ca0aa7af98..b5ce3bb767 100644 --- a/docs/source/en/api/pipelines/attend_and_excite.md +++ b/docs/source/en/api/pipelines/attend_and_excite.md @@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> +> [!WARNING] +> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model. + # Attend-and-Excite Attend-and-Excite for Stable Diffusion was proposed in [Attend-and-Excite: Attention-Based Semantic Guidance for Text-to-Image Diffusion Models](https://attendandexcite.github.io/Attend-and-Excite/) and provides textual attention control over image generation. 
diff --git a/docs/source/en/api/pipelines/audioldm.md b/docs/source/en/api/pipelines/audioldm.md index a5ef9c4872..6b143d2990 100644 --- a/docs/source/en/api/pipelines/audioldm.md +++ b/docs/source/en/api/pipelines/audioldm.md @@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> +> [!WARNING] +> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model. + # AudioLDM AudioLDM was proposed in [AudioLDM: Text-to-Audio Generation with Latent Diffusion Models](https://huggingface.co/papers/2301.12503) by Haohe Liu et al. Inspired by [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview), AudioLDM diff --git a/docs/source/en/api/pipelines/blip_diffusion.md b/docs/source/en/api/pipelines/blip_diffusion.md index c13288d489..d94281a4a9 100644 --- a/docs/source/en/api/pipelines/blip_diffusion.md +++ b/docs/source/en/api/pipelines/blip_diffusion.md @@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> +> [!WARNING] +> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model. + # BLIP-Diffusion BLIP-Diffusion was proposed in [BLIP-Diffusion: Pre-trained Subject Representation for Controllable Text-to-Image Generation and Editing](https://huggingface.co/papers/2305.14720). It enables zero-shot subject-driven generation and control-guided zero-shot generation. diff --git a/docs/source/en/api/pipelines/controlnetxs.md b/docs/source/en/api/pipelines/controlnetxs.md index 2eebcc6b74..aea8cb2e86 100644 --- a/docs/source/en/api/pipelines/controlnetxs.md +++ b/docs/source/en/api/pipelines/controlnetxs.md @@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> +> [!WARNING] +> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model. + # ControlNet-XS
diff --git a/docs/source/en/api/pipelines/controlnetxs_sdxl.md b/docs/source/en/api/pipelines/controlnetxs_sdxl.md index 0862a5d798..76937b16c5 100644 --- a/docs/source/en/api/pipelines/controlnetxs_sdxl.md +++ b/docs/source/en/api/pipelines/controlnetxs_sdxl.md @@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> +> [!WARNING] +> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model. + # ControlNet-XS with Stable Diffusion XL ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results. diff --git a/docs/source/en/api/pipelines/dance_diffusion.md b/docs/source/en/api/pipelines/dance_diffusion.md index 64a738f17c..5805561e49 100644 --- a/docs/source/en/api/pipelines/dance_diffusion.md +++ b/docs/source/en/api/pipelines/dance_diffusion.md @@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> +> [!WARNING] +> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model. + # Dance Diffusion [Dance Diffusion](https://github.com/Harmonai-org/sample-generator) is by Zach Evans. diff --git a/docs/source/en/api/pipelines/diffedit.md b/docs/source/en/api/pipelines/diffedit.md index 02a76cf589..9734ca2eab 100644 --- a/docs/source/en/api/pipelines/diffedit.md +++ b/docs/source/en/api/pipelines/diffedit.md @@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> +> [!WARNING] +> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model. + # DiffEdit [DiffEdit: Diffusion-based semantic image editing with mask guidance](https://huggingface.co/papers/2210.11427) is by Guillaume Couairon, Jakob Verbeek, Holger Schwenk, and Matthieu Cord. diff --git a/docs/source/en/api/pipelines/i2vgenxl.md b/docs/source/en/api/pipelines/i2vgenxl.md index eea7eeab19..76a51a6cd5 100644 --- a/docs/source/en/api/pipelines/i2vgenxl.md +++ b/docs/source/en/api/pipelines/i2vgenxl.md @@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> +> [!WARNING] +> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model. 
+ # I2VGen-XL [I2VGen-XL: High-Quality Image-to-Video Synthesis via Cascaded Diffusion Models](https://hf.co/papers/2311.04145.pdf) by Shiwei Zhang, Jiayu Wang, Yingya Zhang, Kang Zhao, Hangjie Yuan, Zhiwu Qin, Xiang Wang, Deli Zhao, and Jingren Zhou. diff --git a/docs/source/en/api/pipelines/musicldm.md b/docs/source/en/api/pipelines/musicldm.md index 5072bcc4fb..c2297162f7 100644 --- a/docs/source/en/api/pipelines/musicldm.md +++ b/docs/source/en/api/pipelines/musicldm.md @@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> +> [!WARNING] +> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model. + # MusicLDM MusicLDM was proposed in [MusicLDM: Enhancing Novelty in Text-to-Music Generation Using Beat-Synchronous Mixup Strategies](https://huggingface.co/papers/2308.01546) by Ke Chen, Yusong Wu, Haohe Liu, Marianna Nezhurina, Taylor Berg-Kirkpatrick, Shlomo Dubnov. diff --git a/docs/source/en/api/pipelines/paint_by_example.md b/docs/source/en/api/pipelines/paint_by_example.md index 769156643b..362c26de68 100644 --- a/docs/source/en/api/pipelines/paint_by_example.md +++ b/docs/source/en/api/pipelines/paint_by_example.md @@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> +> [!WARNING] +> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model. + # Paint by Example [Paint by Example: Exemplar-based Image Editing with Diffusion Models](https://huggingface.co/papers/2211.13227) is by Binxin Yang, Shuyang Gu, Bo Zhang, Ting Zhang, Xuejin Chen, Xiaoyan Sun, Dong Chen, Fang Wen. diff --git a/docs/source/en/api/pipelines/panorama.md b/docs/source/en/api/pipelines/panorama.md index a9a95759d6..9f61388dd5 100644 --- a/docs/source/en/api/pipelines/panorama.md +++ b/docs/source/en/api/pipelines/panorama.md @@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> +> [!WARNING] +> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model. + # MultiDiffusion
diff --git a/docs/source/en/api/pipelines/pia.md b/docs/source/en/api/pipelines/pia.md index a58d7fbe8d..7bd480b49a 100644 --- a/docs/source/en/api/pipelines/pia.md +++ b/docs/source/en/api/pipelines/pia.md @@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> +> [!WARNING] +> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model. + # Image-to-Video Generation with PIA (Personalized Image Animator)
diff --git a/docs/source/en/api/pipelines/self_attention_guidance.md b/docs/source/en/api/pipelines/self_attention_guidance.md index f86cbc0b6f..5578fdfa63 100644 --- a/docs/source/en/api/pipelines/self_attention_guidance.md +++ b/docs/source/en/api/pipelines/self_attention_guidance.md @@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> +> [!WARNING] +> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model. + # Self-Attention Guidance [Improving Sample Quality of Diffusion Models Using Self-Attention Guidance](https://huggingface.co/papers/2210.00939) is by Susung Hong et al. diff --git a/docs/source/en/api/pipelines/semantic_stable_diffusion.md b/docs/source/en/api/pipelines/semantic_stable_diffusion.md index 99395e75a9..1ce44cf2de 100644 --- a/docs/source/en/api/pipelines/semantic_stable_diffusion.md +++ b/docs/source/en/api/pipelines/semantic_stable_diffusion.md @@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> +> [!WARNING] +> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model. + # Semantic Guidance Semantic Guidance for Diffusion Models was proposed in [SEGA: Instructing Text-to-Image Models using Semantic Guidance](https://huggingface.co/papers/2301.12247) and provides strong semantic control over image generation. diff --git a/docs/source/en/api/pipelines/stable_diffusion/gligen.md b/docs/source/en/api/pipelines/stable_diffusion/gligen.md index 73be0b4ca8..e9704fc1de 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/gligen.md +++ b/docs/source/en/api/pipelines/stable_diffusion/gligen.md @@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> +> [!WARNING] +> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model. + # GLIGEN (Grounded Language-to-Image Generation) The GLIGEN model was created by researchers and engineers from [University of Wisconsin-Madison, Columbia University, and Microsoft](https://github.com/gligen/GLIGEN). The [`StableDiffusionGLIGENPipeline`] and [`StableDiffusionGLIGENTextImagePipeline`] can generate photorealistic images conditioned on grounding inputs. Along with text and bounding boxes with [`StableDiffusionGLIGENPipeline`], if input images are given, [`StableDiffusionGLIGENTextImagePipeline`] can insert objects described by text at the region defined by bounding boxes. Otherwise, it'll generate an image described by the caption/prompt and insert objects described by text at the region defined by bounding boxes. It's trained on COCO2014D and COCO2014CD datasets, and the model uses a frozen CLIP ViT-L/14 text encoder to condition itself on grounding inputs. 
diff --git a/docs/source/en/api/pipelines/stable_diffusion/k_diffusion.md b/docs/source/en/api/pipelines/stable_diffusion/k_diffusion.md index 4d7fda2a0c..75f052b08f 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/k_diffusion.md +++ b/docs/source/en/api/pipelines/stable_diffusion/k_diffusion.md @@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> +> [!WARNING] +> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model. + # K-Diffusion [k-diffusion](https://github.com/crowsonkb/k-diffusion) is a popular library created by [Katherine Crowson](https://github.com/crowsonkb/). We provide `StableDiffusionKDiffusionPipeline` and `StableDiffusionXLKDiffusionPipeline` that allow you to run Stable DIffusion with samplers from k-diffusion. diff --git a/docs/source/en/api/pipelines/stable_diffusion/ldm3d_diffusion.md b/docs/source/en/api/pipelines/stable_diffusion/ldm3d_diffusion.md index 9f54538968..4c52ed90f0 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/ldm3d_diffusion.md +++ b/docs/source/en/api/pipelines/stable_diffusion/ldm3d_diffusion.md @@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> +> [!WARNING] +> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model. + # Text-to-(RGB, depth)
diff --git a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_safe.md b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_safe.md index ac5b97b672..1736491107 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_safe.md +++ b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_safe.md @@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> +> [!WARNING] +> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model. + # Safe Stable Diffusion Safe Stable Diffusion was proposed in [Safe Latent Diffusion: Mitigating Inappropriate Degeneration in Diffusion Models](https://huggingface.co/papers/2211.05105) and mitigates inappropriate degeneration from Stable Diffusion models because they're trained on unfiltered web-crawled datasets. For instance Stable Diffusion may unexpectedly generate nudity, violence, images depicting self-harm, and otherwise offensive content. Safe Stable Diffusion is an extension of Stable Diffusion that drastically reduces this type of content. diff --git a/docs/source/en/api/pipelines/text_to_video.md b/docs/source/en/api/pipelines/text_to_video.md index 116aea736f..7faf88d133 100644 --- a/docs/source/en/api/pipelines/text_to_video.md +++ b/docs/source/en/api/pipelines/text_to_video.md @@ -10,11 +10,8 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> - - -🧪 This pipeline is for research purposes only. - - +> [!WARNING] +> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model. # Text-to-video diff --git a/docs/source/en/api/pipelines/text_to_video_zero.md b/docs/source/en/api/pipelines/text_to_video_zero.md index 7966f43390..5fe3789d82 100644 --- a/docs/source/en/api/pipelines/text_to_video_zero.md +++ b/docs/source/en/api/pipelines/text_to_video_zero.md @@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> +> [!WARNING] +> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model. + # Text2Video-Zero
diff --git a/docs/source/en/api/pipelines/unclip.md b/docs/source/en/api/pipelines/unclip.md index c9a3164226..8011a4b533 100644 --- a/docs/source/en/api/pipelines/unclip.md +++ b/docs/source/en/api/pipelines/unclip.md @@ -7,6 +7,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> +> [!WARNING] +> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model. + # unCLIP [Hierarchical Text-Conditional Image Generation with CLIP Latents](https://huggingface.co/papers/2204.06125) is by Aditya Ramesh, Prafulla Dhariwal, Alex Nichol, Casey Chu, Mark Chen. The unCLIP model in 🤗 Diffusers comes from kakaobrain's [karlo](https://github.com/kakaobrain/karlo). diff --git a/docs/source/en/api/pipelines/unidiffuser.md b/docs/source/en/api/pipelines/unidiffuser.md index bce55b67ed..7d767f2db5 100644 --- a/docs/source/en/api/pipelines/unidiffuser.md +++ b/docs/source/en/api/pipelines/unidiffuser.md @@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> +> [!WARNING] +> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model. + # UniDiffuser
diff --git a/docs/source/en/api/pipelines/wuerstchen.md b/docs/source/en/api/pipelines/wuerstchen.md index 561df2017d..2be3631d84 100644 --- a/docs/source/en/api/pipelines/wuerstchen.md +++ b/docs/source/en/api/pipelines/wuerstchen.md @@ -12,6 +12,9 @@ specific language governing permissions and limitations under the License. # Würstchen +> [!WARNING] +> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model. +
LoRA
From 5ef74fd5f641367c7be6b6cfab95338048d18580 Mon Sep 17 00:00:00 2001 From: Luo Yihang Date: Wed, 2 Jul 2025 11:37:54 +0800 Subject: [PATCH 13/14] fix norm not training in train_control_lora_flux.py (#11832) --- examples/flux-control/train_control_lora_flux.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/flux-control/train_control_lora_flux.py b/examples/flux-control/train_control_lora_flux.py index 3c8b75a088..53ee0f89e2 100644 --- a/examples/flux-control/train_control_lora_flux.py +++ b/examples/flux-control/train_control_lora_flux.py @@ -837,11 +837,6 @@ def main(args): assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0) flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels) - if args.train_norm_layers: - for name, param in flux_transformer.named_parameters(): - if any(k in name for k in NORM_LAYER_PREFIXES): - param.requires_grad = True - if args.lora_layers is not None: if args.lora_layers != "all-linear": target_modules = [layer.strip() for layer in args.lora_layers.split(",")] @@ -879,6 +874,11 @@ def main(args): ) flux_transformer.add_adapter(transformer_lora_config) + if args.train_norm_layers: + for name, param in flux_transformer.named_parameters(): + if any(k in name for k in NORM_LAYER_PREFIXES): + param.requires_grad = True + def unwrap_model(model): model = accelerator.unwrap_model(model) model = model._orig_mod if is_compiled_module(model) else model From 0e95aa853edb85e6bf66634d544939c407f78d2f Mon Sep 17 00:00:00 2001 From: Ju Hoon Park Date: Wed, 2 Jul 2025 12:55:36 +0900 Subject: [PATCH 14/14] [From Single File] support `from_single_file` method for `WanVACE3DTransformer` (#11807) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add `WandVACETransformer3DModel` in`SINGLE_FILE_LOADABLE_CLASSES` * add rename keys for `VACE` add rename keys for `VACE` * fix typo Sincere thanks to @nitinmukesh 🙇‍♂️ * support for `1.3B VACE` model Sincere thanks to @nitinmukesh again🙇‍♂️ * update * update * Apply style fixes --------- Co-authored-by: Dhruv Nair Co-authored-by: github-actions[bot] --- src/diffusers/loaders/single_file_model.py | 4 ++ src/diffusers/loaders/single_file_utils.py | 14 ++++- tests/quantization/gguf/test_gguf.py | 70 ++++++++++++++++++++++ 3 files changed, 87 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index 2e99afbd51..17ac81ca26 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -136,6 +136,10 @@ SINGLE_FILE_LOADABLE_CLASSES = { "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers, "default_subfolder": "transformer", }, + "WanVACETransformer3DModel": { + "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers, + "default_subfolder": "transformer", + }, "AutoencoderKLWan": { "checkpoint_mapping_fn": convert_wan_vae_to_diffusers, "default_subfolder": "vae", diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 3f81243693..ee0786aa2d 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -126,6 +126,7 @@ CHECKPOINT_KEY_NAMES = { ], "wan": ["model.diffusion_model.head.modulation", "head.modulation"], "wan_vae": "decoder.middle.0.residual.0.gamma", + "wan_vace": "vace_blocks.0.after_proj.bias", "hidream": 
"double_stream_blocks.0.block.adaLN_modulation.1.bias", "cosmos-1.0": [ "net.x_embedder.proj.1.weight", @@ -202,6 +203,8 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = { "wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"}, "wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"}, "wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"}, + "wan-vace-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-1.3B-diffusers"}, + "wan-vace-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-14B-diffusers"}, "hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"}, "cosmos-1.0-t2w-7B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-7B-Text2World"}, "cosmos-1.0-t2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-14B-Text2World"}, @@ -716,7 +719,13 @@ def infer_diffusers_model_type(checkpoint): else: target_key = "patch_embedding.weight" - if checkpoint[target_key].shape[0] == 1536: + if CHECKPOINT_KEY_NAMES["wan_vace"] in checkpoint: + if checkpoint[target_key].shape[0] == 1536: + model_type = "wan-vace-1.3B" + elif checkpoint[target_key].shape[0] == 5120: + model_type = "wan-vace-14B" + + elif checkpoint[target_key].shape[0] == 1536: model_type = "wan-t2v-1.3B" elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16: model_type = "wan-t2v-14B" @@ -3132,6 +3141,9 @@ def convert_wan_transformer_to_diffusers(checkpoint, **kwargs): "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj", "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2", "img_emb.proj.4": "condition_embedder.image_embedder.norm2", + # For the VACE model + "before_proj": "proj_in", + "after_proj": "proj_out", } for key in list(checkpoint.keys()): diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py index 5d1fa4c22e..0d786de7e7 100644 --- a/tests/quantization/gguf/test_gguf.py +++ b/tests/quantization/gguf/test_gguf.py @@ -15,6 +15,8 @@ from diffusers import ( HiDreamImageTransformer2DModel, SD3Transformer2DModel, StableDiffusion3Pipeline, + WanTransformer3DModel, + WanVACETransformer3DModel, ) from diffusers.utils import load_image from diffusers.utils.testing_utils import ( @@ -577,3 +579,71 @@ class HiDreamGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase): ).to(torch_device, self.torch_dtype), "timesteps": torch.tensor([1]).to(torch_device, self.torch_dtype), } + + +class WanGGUFTexttoVideoSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase): + ckpt_path = "https://huggingface.co/city96/Wan2.1-T2V-14B-gguf/blob/main/wan2.1-t2v-14b-Q3_K_S.gguf" + torch_dtype = torch.bfloat16 + model_cls = WanTransformer3DModel + expected_memory_use_in_gb = 9 + + def get_dummy_inputs(self): + return { + "hidden_states": torch.randn((1, 36, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to( + torch_device, self.torch_dtype + ), + "encoder_hidden_states": torch.randn( + (1, 512, 4096), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype), + } + + +class WanGGUFImagetoVideoSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase): + ckpt_path = "https://huggingface.co/city96/Wan2.1-I2V-14B-480P-gguf/blob/main/wan2.1-i2v-14b-480p-Q3_K_S.gguf" + torch_dtype = torch.bfloat16 + model_cls = WanTransformer3DModel + expected_memory_use_in_gb = 9 + + def get_dummy_inputs(self): + 
return { + "hidden_states": torch.randn((1, 36, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to( + torch_device, self.torch_dtype + ), + "encoder_hidden_states": torch.randn( + (1, 512, 4096), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "encoder_hidden_states_image": torch.randn( + (1, 257, 1280), generator=torch.Generator("cpu").manual_seed(0) + ).to(torch_device, self.torch_dtype), + "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype), + } + + +class WanVACEGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase): + ckpt_path = "https://huggingface.co/QuantStack/Wan2.1_14B_VACE-GGUF/blob/main/Wan2.1_14B_VACE-Q3_K_S.gguf" + torch_dtype = torch.bfloat16 + model_cls = WanVACETransformer3DModel + expected_memory_use_in_gb = 9 + + def get_dummy_inputs(self): + return { + "hidden_states": torch.randn((1, 16, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to( + torch_device, self.torch_dtype + ), + "encoder_hidden_states": torch.randn( + (1, 512, 4096), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "control_hidden_states": torch.randn( + (1, 96, 2, 64, 64), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "control_hidden_states_scale": torch.randn( + (8,), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype), + }