diff --git a/setup.py b/setup.py
index 799150fd03..e0c810a920 100644
--- a/setup.py
+++ b/setup.py
@@ -116,7 +116,7 @@ _deps = [
     "librosa",
     "numpy",
     "parameterized",
-    "peft>=0.15.0",
+    "peft>=0.17.0",
     "protobuf>=3.20.3,<4",
     "pytest",
     "pytest-timeout",
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 1c25a65f50..6d2b88aef0 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -139,6 +139,7 @@ else:
         "AutoGuidance",
         "ClassifierFreeGuidance",
         "ClassifierFreeZeroStarGuidance",
+        "FrequencyDecoupledGuidance",
         "PerturbedAttentionGuidance",
         "SkipLayerGuidance",
         "SmoothedEnergyGuidance",
@@ -804,6 +805,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         AutoGuidance,
         ClassifierFreeGuidance,
         ClassifierFreeZeroStarGuidance,
+        FrequencyDecoupledGuidance,
         PerturbedAttentionGuidance,
         SkipLayerGuidance,
         SmoothedEnergyGuidance,
diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py
index 3d14a8b3e0..a3832cf9b8 100644
--- a/src/diffusers/dependency_versions_table.py
+++ b/src/diffusers/dependency_versions_table.py
@@ -23,7 +23,7 @@ deps = {
     "librosa": "librosa",
     "numpy": "numpy",
     "parameterized": "parameterized",
-    "peft": "peft>=0.15.0",
+    "peft": "peft>=0.17.0",
     "protobuf": "protobuf>=3.20.3,<4",
     "pytest": "pytest",
     "pytest-timeout": "pytest-timeout",
diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py
index 1c288f00f0..23cb7a0a71 100644
--- a/src/diffusers/guiders/__init__.py
+++ b/src/diffusers/guiders/__init__.py
@@ -22,6 +22,7 @@ if is_torch_available():
     from .auto_guidance import AutoGuidance
     from .classifier_free_guidance import ClassifierFreeGuidance
     from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
+    from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
     from .perturbed_attention_guidance import PerturbedAttentionGuidance
     from .skip_layer_guidance import SkipLayerGuidance
     from .smoothed_energy_guidance import SmoothedEnergyGuidance
@@ -32,6 +33,7 @@ if is_torch_available():
         AutoGuidance,
         ClassifierFreeGuidance,
         ClassifierFreeZeroStarGuidance,
+        FrequencyDecoupledGuidance,
         PerturbedAttentionGuidance,
         SkipLayerGuidance,
         SmoothedEnergyGuidance,
diff --git a/src/diffusers/guiders/frequency_decoupled_guidance.py b/src/diffusers/guiders/frequency_decoupled_guidance.py
new file mode 100644
index 0000000000..35bc99ac4d
--- /dev/null
+++ b/src/diffusers/guiders/frequency_decoupled_guidance.py
@@ -0,0 +1,327 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+import torch
+
+from ..configuration_utils import register_to_config
+from ..utils import is_kornia_available
+from .guider_utils import BaseGuidance, rescale_noise_cfg
+
+
+if TYPE_CHECKING:
+    from ..modular_pipelines.modular_pipeline import BlockState
+
+
+_CAN_USE_KORNIA = is_kornia_available()
+
+
+if _CAN_USE_KORNIA:
+    from kornia.geometry import pyrup as upsample_and_blur_func
+    from kornia.geometry.transform import build_laplacian_pyramid as build_laplacian_pyramid_func
+else:
+    upsample_and_blur_func = None
+    build_laplacian_pyramid_func = None
+
+
+def project(v0: torch.Tensor, v1: torch.Tensor, upcast_to_double: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Project vector v0 onto vector v1, returning the parallel and orthogonal components of v0. Implementation from
+    the paper (Algorithm 2).
+    """
+    # v0 shape: [B, ...]
+    # v1 shape: [B, ...]
+    # Assume the first dim is a batch dim and all other dims are channel or "spatial" dims
+    all_dims_but_first = list(range(1, len(v0.shape)))
+    if upcast_to_double:
+        dtype = v0.dtype
+        v0, v1 = v0.double(), v1.double()
+    v1 = torch.nn.functional.normalize(v1, dim=all_dims_but_first)
+    v0_parallel = (v0 * v1).sum(dim=all_dims_but_first, keepdim=True) * v1
+    v0_orthogonal = v0 - v0_parallel
+    if upcast_to_double:
+        v0_parallel = v0_parallel.to(dtype)
+        v0_orthogonal = v0_orthogonal.to(dtype)
+    return v0_parallel, v0_orthogonal
+
+
+def build_image_from_pyramid(pyramid: List[torch.Tensor]) -> torch.Tensor:
+    """
+    Recovers the data space latents from the Laplacian pyramid frequency space. Implementation from the paper
+    (Algorithm 2).
+    """
+    # pyramid shapes: [[B, C, H, W], [B, C, H/2, W/2], ...]
+    img = pyramid[-1]
+    for i in range(len(pyramid) - 2, -1, -1):
+        img = upsample_and_blur_func(img) + pyramid[i]
+    return img
+
+
+class FrequencyDecoupledGuidance(BaseGuidance):
+    """
+    Frequency-Decoupled Guidance (FDG): https://huggingface.co/papers/2506.19713
+
+    FDG is a technique similar to (and based on) classifier-free guidance (CFG), used to improve generation quality
+    and condition-following in diffusion models. Like CFG, the model is jointly trained on both conditional and
+    unconditional data, and a combination of the two is used during inference. (If you want more details on how CFG
+    works, you can check out the CFG guider.)
+
+    FDG differs from CFG in that the usual CFG prediction is decoupled into low- and high-frequency components using
+    a frequency transform (such as a Laplacian pyramid). The CFG update is then performed in frequency space,
+    separately for the low- and high-frequency components, with different guidance scales. Finally, the inverse
+    frequency transform maps the per-frequency CFG predictions back to data space (e.g. pixel space for images) to
+    form the final FDG prediction.
+
+    For images, the FDG authors found that using low guidance scales for the low-frequency components retains
+    sample diversity and realistic color composition, while using high guidance scales for high-frequency components
+    enhances sample quality (such as better visual details). Therefore, they recommend using low guidance scales
+    (low w_low) for the low-frequency components and high guidance scales (high w_high) for the high-frequency
+    components. As an example, they suggest w_low = 5.0 and w_high = 10.0 for Stable Diffusion XL (see Table 8 in
+    the paper).
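+
+    Concretely, each pyramid level gets the CFG update applied in frequency space: with per-level scale w_level,
+    pred_level = uncond_level + w_level * (cond_level - uncond_level), after which the inverse pyramid transform
+    reassembles the final prediction from the per-level results.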
+
+    As with CFG, Diffusers implements the scaling and shifting on the unconditional prediction based on the [Imagen
+    paper](https://huggingface.co/papers/2205.11487), which is theoretically equivalent to what the original CFG
+    paper proposed: [x_pred = x_uncond + scale * (x_cond - x_uncond)]
+
+    The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in
+    the paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
+
+    Args:
+        guidance_scales (`List[float]`, defaults to `[10.0, 5.0]`):
+            The scale parameter for frequency-decoupled guidance for each frequency component, listed from the
+            highest frequency level to the lowest. Higher values result in stronger conditioning on the text prompt,
+            while lower values allow for more freedom in generation. Higher values may lead to saturation and
+            deterioration of image quality. The FDG authors recommend using higher guidance scales for higher
+            frequency components and lower guidance scales for lower frequency components (so `guidance_scales`
+            should typically be sorted in descending order).
+        guidance_rescale (`float` or `List[float]`, defaults to `0.0`):
+            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
+            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+            Flawed](https://huggingface.co/papers/2305.08891). If a list is supplied, it should be the same length
+            as `guidance_scales`.
+        parallel_weights (`float` or `List[float]`, *optional*):
+            Optional weights for the parallel component of each frequency component of the projected CFG shift. If
+            not set, the weights default to `1.0` for all components, which corresponds to using the normal CFG
+            shift (that is, equal weights for the parallel and orthogonal components). If set, a value in `[0, 1]`
+            is recommended. If a list is supplied, it should be the same length as `guidance_scales`.
+        use_original_formulation (`bool`, defaults to `False`):
+            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By
+            default, we use the diffusers-native implementation that has been in the codebase for a long time. See
+            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
+        start (`float` or `List[float]`, defaults to `0.0`):
+            The fraction of the total number of denoising steps after which guidance starts. If a list is supplied,
+            it should be the same length as `guidance_scales`.
+        stop (`float` or `List[float]`, defaults to `1.0`):
+            The fraction of the total number of denoising steps after which guidance stops. If a list is supplied,
+            it should be the same length as `guidance_scales`.
+        guidance_rescale_space (`str`, defaults to `"data"`):
+            Whether to perform guidance rescaling in `"data"` space (after the full FDG update in data space) or in
+            `"freq"` space (right after the CFG update, for each frequency level). Note that frequency space
+            rescaling is speculative and may not produce the expected results. If `"data"` is set, the first
+            `guidance_rescale` value is used; otherwise, per-frequency-level guidance rescale values are used if
+            available.
+        upcast_to_double (`bool`, defaults to `True`):
+            Whether to upcast certain operations, such as the projection operation when using `parallel_weights`,
+            to float64 when performing guidance. This may result in better quality at the cost of increased runtime.
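+
+    Example (an illustrative sketch: constructing the guider as below follows this file, while how a guider is
+    wired into a pipeline depends on the modular pipelines interface and is not shown here):
+
+    ```py
+    from diffusers import FrequencyDecoupledGuidance
+
+    # Two frequency levels: w_high = 10.0 for the high-frequency component and w_low = 5.0 for the
+    # low-frequency component, the paper's suggestion for Stable Diffusion XL.
+    guider = FrequencyDecoupledGuidance(guidance_scales=[10.0, 5.0])
+    ```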
+ """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + @register_to_config + def __init__( + self, + guidance_scales: Union[List[float], Tuple[float]] = [10.0, 5.0], + guidance_rescale: Union[float, List[float], Tuple[float]] = 0.0, + parallel_weights: Optional[Union[float, List[float], Tuple[float]]] = None, + use_original_formulation: bool = False, + start: Union[float, List[float], Tuple[float]] = 0.0, + stop: Union[float, List[float], Tuple[float]] = 1.0, + guidance_rescale_space: str = "data", + upcast_to_double: bool = True, + ): + if not _CAN_USE_KORNIA: + raise ImportError( + "The `FrequencyDecoupledGuidance` guider cannot be instantiated because the `kornia` library on which " + "it depends is not available in the current environment. You can install `kornia` with `pip install " + "kornia`." + ) + + # Set start to earliest start for any freq component and stop to latest stop for any freq component + min_start = start if isinstance(start, float) else min(start) + max_stop = stop if isinstance(stop, float) else max(stop) + super().__init__(min_start, max_stop) + + self.guidance_scales = guidance_scales + self.levels = len(guidance_scales) + + if isinstance(guidance_rescale, float): + self.guidance_rescale = [guidance_rescale] * self.levels + elif len(guidance_rescale) == self.levels: + self.guidance_rescale = guidance_rescale + else: + raise ValueError( + f"`guidance_rescale` has length {len(guidance_rescale)} but should have the same length as " + f"`guidance_scales` ({len(self.guidance_scales)})" + ) + # Whether to perform guidance rescaling in frequency space (right after the CFG update) or data space (after + # transforming from frequency space back to data space) + if guidance_rescale_space not in ["data", "freq"]: + raise ValueError( + f"Guidance rescale space is {guidance_rescale_space} but must be one of `data` or `freq`." 
+            )
+        self.guidance_rescale_space = guidance_rescale_space
+
+        if parallel_weights is None:
+            # Use the normal CFG shift (equal weights for parallel and orthogonal components)
+            self.parallel_weights = [1.0] * self.levels
+        elif isinstance(parallel_weights, float):
+            self.parallel_weights = [parallel_weights] * self.levels
+        elif len(parallel_weights) == self.levels:
+            self.parallel_weights = parallel_weights
+        else:
+            raise ValueError(
+                f"`parallel_weights` has length {len(parallel_weights)} but should have the same length as "
+                f"`guidance_scales` ({len(self.guidance_scales)})"
+            )
+
+        self.use_original_formulation = use_original_formulation
+        self.upcast_to_double = upcast_to_double
+
+        if isinstance(start, float):
+            self.guidance_start = [start] * self.levels
+        elif len(start) == self.levels:
+            self.guidance_start = start
+        else:
+            raise ValueError(
+                f"`start` has length {len(start)} but should have the same length as `guidance_scales` "
+                f"({len(self.guidance_scales)})"
+            )
+        if isinstance(stop, float):
+            self.guidance_stop = [stop] * self.levels
+        elif len(stop) == self.levels:
+            self.guidance_stop = stop
+        else:
+            raise ValueError(
+                f"`stop` has length {len(stop)} but should have the same length as `guidance_scales` "
+                f"({len(self.guidance_scales)})"
+            )
+
+    def prepare_inputs(
+        self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
+    ) -> List["BlockState"]:
+        if input_fields is None:
+            input_fields = self._input_fields
+
+        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
+        data_batches = []
+        for i in range(self.num_conditions):
+            data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
+            data_batches.append(data_batch)
+        return data_batches
+
+    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Dict]:
+        pred = None
+
+        if not self._is_fdg_enabled():
+            pred = pred_cond
+        else:
+            # Apply the frequency transform (e.g. Laplacian pyramid) to the conditional and unconditional
+            # predictions.
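+            # `build_laplacian_pyramid` returns `self.levels` tensors ordered from the finest (highest-frequency)
+            # residual down to the coarsest (lowest-frequency) level: [B, C, H, W], [B, C, H/2, W/2], ...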
+            pred_cond_pyramid = build_laplacian_pyramid_func(pred_cond, self.levels)
+            pred_uncond_pyramid = build_laplacian_pyramid_func(pred_uncond, self.levels)
+
+            # From high frequencies to low frequencies, following the paper implementation
+            pred_guided_pyramid = []
+            parameters = zip(self.guidance_scales, self.parallel_weights, self.guidance_rescale)
+            for level, (guidance_scale, parallel_weight, guidance_rescale) in enumerate(parameters):
+                if self._is_fdg_enabled_for_level(level):
+                    # Get the cond/uncond preds (in freq space) at the current frequency level
+                    pred_cond_freq = pred_cond_pyramid[level]
+                    pred_uncond_freq = pred_uncond_pyramid[level]
+
+                    shift = pred_cond_freq - pred_uncond_freq
+
+                    # Apply parallel weights, if used (1.0 corresponds to using the normal CFG shift)
+                    if not math.isclose(parallel_weight, 1.0):
+                        shift_parallel, shift_orthogonal = project(shift, pred_cond_freq, self.upcast_to_double)
+                        shift = parallel_weight * shift_parallel + shift_orthogonal
+
+                    # Apply the CFG update for the current frequency level
+                    pred = pred_cond_freq if self.use_original_formulation else pred_uncond_freq
+                    pred = pred + guidance_scale * shift
+
+                    if self.guidance_rescale_space == "freq" and guidance_rescale > 0.0:
+                        pred = rescale_noise_cfg(pred, pred_cond_freq, guidance_rescale)
+
+                    # Add the current FDG guided level to the FDG prediction pyramid
+                    pred_guided_pyramid.append(pred)
+                else:
+                    # Add the current pred_cond_pyramid level as the "non-FDG" prediction
+                    pred_guided_pyramid.append(pred_cond_freq)
+
+            # Convert from frequency space back to data (e.g. pixel) space by applying the inverse freq transform
+            pred = build_image_from_pyramid(pred_guided_pyramid)
+
+            # If rescaling in data space, use the first elem of self.guidance_rescale as the "global" rescale value
+            # across all freq levels
+            if self.guidance_rescale_space == "data" and self.guidance_rescale[0] > 0.0:
+                pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale[0])
+
+        return pred, {}
+
+    @property
+    def is_conditional(self) -> bool:
+        return self._count_prepared == 1
+
+    @property
+    def num_conditions(self) -> int:
+        num_conditions = 1
+        if self._is_fdg_enabled():
+            num_conditions += 1
+        return num_conditions
+
+    def _is_fdg_enabled(self) -> bool:
+        if not self._enabled:
+            return False
+
+        is_within_range = True
+        if self._num_inference_steps is not None:
+            skip_start_step = int(self._start * self._num_inference_steps)
+            skip_stop_step = int(self._stop * self._num_inference_steps)
+            is_within_range = skip_start_step <= self._step < skip_stop_step
+
+        is_close = False
+        if self.use_original_formulation:
+            is_close = all(math.isclose(guidance_scale, 0.0) for guidance_scale in self.guidance_scales)
+        else:
+            is_close = all(math.isclose(guidance_scale, 1.0) for guidance_scale in self.guidance_scales)
+
+        return is_within_range and not is_close
+
+    def _is_fdg_enabled_for_level(self, level: int) -> bool:
+        if not self._enabled:
+            return False
+
+        is_within_range = True
+        if self._num_inference_steps is not None:
+            skip_start_step = int(self.guidance_start[level] * self._num_inference_steps)
+            skip_stop_step = int(self.guidance_stop[level] * self._num_inference_steps)
+            is_within_range = skip_start_step <= self._step < skip_stop_step
+
+        is_close = False
+        if self.use_original_formulation:
+            is_close = math.isclose(self.guidance_scales[level], 0.0)
+        else:
+            is_close = math.isclose(self.guidance_scales[level], 1.0)
+
+        return is_within_range and not is_close
diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py
index f328078ce4..c36c0c31ea 100644
--- a/src/diffusers/hooks/_helpers.py
+++ b/src/diffusers/hooks/_helpers.py
@@ -133,6 +133,7 @@ def _register_attention_processors_metadata():
             skip_processor_output_fn=_skip_proc_output_fn_Attention_WanAttnProcessor2_0,
         ),
     )

+    # FluxAttnProcessor
     AttentionProcessorRegistry.register(
         model_class=FluxAttnProcessor,
diff --git a/src/diffusers/hooks/utils.py b/src/diffusers/hooks/utils.py
new file mode 100644
index 0000000000..c5260eeebe
--- /dev/null
+++ b/src/diffusers/hooks/utils.py
@@ -0,0 +1,43 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES
+
+
+def _get_identifiable_transformer_blocks_in_module(module: torch.nn.Module):
+    module_list_with_transformer_blocks = []
+    for name, submodule in module.named_modules():
+        name_endswith_identifier = any(name.endswith(identifier) for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS)
+        is_modulelist = isinstance(submodule, torch.nn.ModuleList)
+        if name_endswith_identifier and is_modulelist:
+            module_list_with_transformer_blocks.append((name, submodule))
+    return module_list_with_transformer_blocks
+
+
+def _get_identifiable_attention_layers_in_module(module: torch.nn.Module):
+    attention_layers = []
+    for name, submodule in module.named_modules():
+        if isinstance(submodule, _ATTENTION_CLASSES):
+            attention_layers.append((name, submodule))
+    return attention_layers
+
+
+def _get_identifiable_feedforward_layers_in_module(module: torch.nn.Module):
+    feedforward_layers = []
+    for name, submodule in module.named_modules():
+        if isinstance(submodule, _FEEDFORWARD_CLASSES):
+            feedforward_layers.append((name, submodule))
+    return feedforward_layers
diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
index ba96dccbe3..6e8b356055 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -817,7 +817,11 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
     # has both `peft` and non-peft state dict.
     has_peft_state_dict = any(k.startswith("transformer.") for k in state_dict)
     if has_peft_state_dict:
-        state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
+        state_dict = {
+            k.replace("lora_down.weight", "lora_A.weight").replace("lora_up.weight", "lora_B.weight"): v
+            for k, v in state_dict.items()
+            if k.startswith("transformer.")
+        }
         return state_dict

     # Another weird one.
diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index d048298fd4..2381ccfef3 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -320,7 +320,9 @@ class PeftAdapterMixin:
                 # it to None
                 incompatible_keys = None
             else:
-                inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
+                inject_adapter_in_model(
+                    lora_config, self, adapter_name=adapter_name, state_dict=state_dict, **peft_kwargs
+                )
                 incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)

             if self._prepare_lora_hotswap_kwargs is not None:
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 9080cd508d..60c7eb1dba 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -384,7 +384,7 @@ class FluxSingleTransformerBlock(nn.Module):
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         text_seq_len = encoder_hidden_states.shape[1]
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py
index 7211fb5693..124e611bd0 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux.py
@@ -310,7 +310,7 @@ class FluxPipeline(
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
+        prompt_2: Optional[Union[str, List[str]]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control.py b/src/diffusers/pipelines/flux/pipeline_flux_control.py
index 5a057f94cf..51d6ecbe31 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_control.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_control.py
@@ -324,7 +324,7 @@ class FluxControlPipeline(
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
+        prompt_2: Optional[Union[str, List[str]]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py
index 8d5439daf6..c61d46daef 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py
@@ -335,7 +335,7 @@ class FluxControlImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSin
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
+        prompt_2: Optional[Union[str, List[str]]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py
index 872bcf177c..3de636361b 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py
@@ -374,7 +374,7 @@ class FluxControlInpaintPipeline(
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
+        prompt_2: Optional[Union[str, List[str]]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
index 1438d4a902..a39b9c9ce2 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
@@ -341,7 +341,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
+        prompt_2: Optional[Union[str, List[str]]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
index 52e15de53b..582c7bbad8 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
@@ -335,7 +335,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
+        prompt_2: Optional[Union[str, List[str]]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
index d1e874d0b8..f7f34ef231 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
@@ -346,7 +346,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
+        prompt_2: Optional[Union[str, List[str]]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py
index ddfb284eaf..d50db407a8 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_fill.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py
@@ -419,7 +419,7 @@ class FluxFillPipeline(
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
+        prompt_2: Optional[Union[str, List[str]]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py
index 1c4cf3b1cd..08e2f12778 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py
@@ -333,7 +333,7 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
+        prompt_2: Optional[Union[str, List[str]]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
index eeacd9b19b..0494146693 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
@@ -337,7 +337,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
+        prompt_2: Optional[Union[str, List[str]]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_kontext.py b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py
index 3c78aeaf36..ce2941f3dd 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_kontext.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py
@@ -358,7 +358,7 @@ class FluxKontextPipeline(
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
+        prompt_2: Optional[Union[str, List[str]]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py
index 6dc621901c..56a5e934a4 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py
@@ -391,7 +391,7 @@ class FluxKontextInpaintPipeline(
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
+        prompt_2: Optional[Union[str, List[str]]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
index b5ccfb31a3..e79db337b2 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
@@ -292,7 +292,7 @@ class FluxPriorReduxPipeline(DiffusionPipeline):
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
+        prompt_2: Optional[Union[str, List[str]]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
index 03f6f73b44..47549ab4af 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
@@ -201,7 +201,7 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         txt = [template.format(e) for e in prompt]
         txt_tokens = self.tokenizer(
             txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
-        ).to(self.device)
+        ).to(device)
         encoder_hidden_states = self.text_encoder(
             input_ids=txt_tokens.input_ids,
             attention_mask=txt_tokens.attention_mask,
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index 75a2bdd13e..5f49f5e757 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -82,6 +82,7 @@ from .import_utils import (
     is_k_diffusion_available,
    is_k_diffusion_version,
     is_kernels_available,
+    is_kornia_available,
     is_librosa_available,
     is_matplotlib_available,
     is_nltk_available,
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 35df559ce4..08a816ce4b 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -62,6 +62,21 @@ class ClassifierFreeZeroStarGuidance(metaclass=DummyObject):
         requires_backends(cls, ["torch"])


+class FrequencyDecoupledGuidance(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class PerturbedAttentionGuidance(metaclass=DummyObject):
     _backends = ["torch"]
diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py
index d8b26bda46..ac209afb74 100644
--- a/src/diffusers/utils/import_utils.py
+++ b/src/diffusers/utils/import_utils.py
@@ -224,6 +224,7 @@ _cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("
 _sageattention_available, _sageattention_version = _is_package_available("sageattention")
 _flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
 _flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")
+_kornia_available, _kornia_version = _is_package_available("kornia")


 def is_torch_available():
@@ -398,6 +399,10 @@ def is_flash_attn_3_available():
     return _flash_attn_3_available


+def is_kornia_available():
+    return _kornia_available
+
+
 # docstyle-ignore
 FLAX_IMPORT_ERROR = """
 {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py
index 651fa27294..12066ee3f8 100644
--- a/src/diffusers/utils/peft_utils.py
+++ b/src/diffusers/utils/peft_utils.py
@@ -197,20 +197,6 @@ def get_peft_kwargs(
         "lora_bias": lora_bias,
     }

-    # Example: try load FusionX LoRA into Wan VACE
-    exclude_modules = _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name)
-    if exclude_modules:
-        if not is_peft_version(">=", "0.14.0"):
-            msg = """
-It seems like there are certain modules that need to be excluded when initializing `LoraConfig`. Your current `peft`
-version doesn't support passing an `exclude_modules` to `LoraConfig`. Please update it by running `pip install -U
-peft`. For most cases, this can be completely ignored. But if it seems unexpected, please file an issue -
-https://github.com/huggingface/diffusers/issues/new
-            """
-            logger.debug(msg)
-        else:
-            lora_config_kwargs.update({"exclude_modules": exclude_modules})
-
     return lora_config_kwargs
@@ -388,27 +374,3 @@ def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):

     if warn_msg:
         logger.warning(warn_msg)
-
-
-def _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name=None):
-    """
-    Derives the modules to exclude while initializing `LoraConfig` through `exclude_modules`. It works by comparing
-    the `model_state_dict` and `peft_state_dict` and adds a module from `model_state_dict` to the exclusion set if
-    it doesn't exist in `peft_state_dict`.
-    """
-    if model_state_dict is None:
-        return
-    all_modules = set()
-    string_to_replace = f"{adapter_name}." if adapter_name else ""
-
-    for name in model_state_dict.keys():
-        if string_to_replace:
-            name = name.replace(string_to_replace, "")
-        if "." in name:
-            module_name = name.rsplit(".", 1)[0]
-            all_modules.add(module_name)
-
-    target_modules_set = {name.split(".lora")[0] for name in peft_state_dict.keys()}
-    exclude_modules = list(all_modules - target_modules_set)
-
-    return exclude_modules
diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index 9edaeafc71..f09f0d8ecb 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import copy
 import inspect
 import os
 import re
@@ -292,20 +291,6 @@ class PeftLoraLoaderMixinTests:

         return modules_to_save

-    def _get_exclude_modules(self, pipe):
-        from diffusers.utils.peft_utils import _derive_exclude_modules
-
-        modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
-        denoiser = "unet" if self.unet_kwargs is not None else "transformer"
-        modules_to_save = {k: v for k, v in modules_to_save.items() if k == denoiser}
-        denoiser_lora_state_dict = self._get_lora_state_dicts(modules_to_save)[f"{denoiser}_lora_layers"]
-        pipe.unload_lora_weights()
-        denoiser_state_dict = pipe.unet.state_dict() if self.unet_kwargs is not None else pipe.transformer.state_dict()
-        exclude_modules = _derive_exclude_modules(
-            denoiser_state_dict, denoiser_lora_state_dict, adapter_name="default"
-        )
-        return exclude_modules
-
     def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
         if text_lora_config is not None:
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -2342,58 +2327,6 @@ class PeftLoraLoaderMixinTests:
             )
             _ = pipe(**inputs, generator=torch.manual_seed(0))[0]

-    @require_peft_version_greater("0.13.2")
-    def test_lora_exclude_modules(self):
-        """
-        Test to check if `exclude_modules` works or not. It works in the following way:
-        we first create a pipeline and insert LoRA config into it. We then derive a `set`
-        of modules to exclude by investigating its denoiser state dict and denoiser LoRA
-        state dict.
-
-        We then create a new LoRA config to include the `exclude_modules` and perform tests.
- """ - scheduler_cls = self.scheduler_classes[0] - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) - pipe = self.pipeline_class(**components).to(torch_device) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue(output_no_lora.shape == self.output_shape) - - # only supported for `denoiser` now - pipe_cp = copy.deepcopy(pipe) - pipe_cp, _ = self.add_adapters_to_pipeline( - pipe_cp, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config - ) - denoiser_exclude_modules = self._get_exclude_modules(pipe_cp) - pipe_cp.to("cpu") - del pipe_cp - - denoiser_lora_config.exclude_modules = denoiser_exclude_modules - pipe, _ = self.add_adapters_to_pipeline( - pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config - ) - output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0] - - with tempfile.TemporaryDirectory() as tmpdir: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - lora_metadatas = self._get_lora_adapter_metadata(modules_to_save) - self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas) - pipe.unload_lora_weights() - pipe.load_lora_weights(tmpdir) - - output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue( - not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3), - "LoRA should change outputs.", - ) - self.assertTrue( - np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3), - "Lora outputs should match.", - ) - def test_inference_load_delete_load_adapters(self): "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works." for scheduler_cls in self.scheduler_classes: @@ -2467,7 +2400,6 @@ class PeftLoraLoaderMixinTests: components, _, _ = self.get_dummy_components(self.scheduler_classes[0]) pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet @@ -2483,6 +2415,10 @@ class PeftLoraLoaderMixinTests: num_blocks_per_group=1, use_stream=use_stream, ) + # Place other model-level components on `torch_device`. 
+        for _, component in pipe.components.items():
+            if isinstance(component, torch.nn.Module):
+                component.to(torch_device)

         group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser)
         self.assertTrue(group_offload_hook_1 is not None)
         output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py
index 68b5c02bc0..14ef6f1514 100644
--- a/tests/models/transformers/test_models_transformer_flux.py
+++ b/tests/models/transformers/test_models_transformer_flux.py
@@ -20,7 +20,7 @@ import torch
 from diffusers import FluxTransformer2DModel
 from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
 from diffusers.models.embeddings import ImageProjection
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from diffusers.utils.testing_utils import enable_full_determinism, is_peft_available, torch_device

 from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin

@@ -172,6 +172,35 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
         expected_set = {"FluxTransformer2DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

+    # The test exists for cases like
+    # https://github.com/huggingface/diffusers/issues/11874
+    @unittest.skipIf(not is_peft_available(), "Only with PEFT")
+    def test_lora_exclude_modules(self):
+        from peft import LoraConfig, get_peft_model_state_dict, inject_adapter_in_model, set_peft_model_state_dict
+
+        lora_rank = 4
+        target_module = "single_transformer_blocks.0.proj_out"
+        adapter_name = "foo"
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        state_dict = model.state_dict()
+        target_mod_shape = state_dict[f"{target_module}.weight"].shape
+        lora_state_dict = {
+            f"{target_module}.lora_A.weight": torch.ones(lora_rank, target_mod_shape[1]) * 22,
+            f"{target_module}.lora_B.weight": torch.ones(target_mod_shape[0], lora_rank) * 33,
+        }
+        # Passing `exclude_modules` should no longer be necessary (or even passing `target_modules`, for that matter).
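+        # This relies on the `state_dict` argument passed to `inject_adapter_in_model` below (cf. the
+        # `loaders/peft.py` change in this PR), which lets PEFT derive the modules to adapt from the LoRA
+        # state dict itself under the `peft>=0.17.0` floor this PR pins.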
+        config = LoraConfig(
+            r=lora_rank, target_modules=["single_transformer_blocks.0.proj_out"], exclude_modules=["proj_out"]
+        )
+        inject_adapter_in_model(config, model, adapter_name=adapter_name, state_dict=lora_state_dict)
+        set_peft_model_state_dict(model, lora_state_dict, adapter_name)
+
+        retrieved_lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name)
+        assert len(retrieved_lora_state_dict) == len(lora_state_dict)
+        assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_A.weight"] == 22).all()
+        assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_B.weight"] == 33).all()
+

 class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
     model_class = FluxTransformer2DModel
diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py
index 8e2a8515c6..08c0fee43b 100644
--- a/tests/quantization/bnb/test_4bit.py
+++ b/tests/quantization/bnb/test_4bit.py
@@ -886,6 +886,7 @@ class Bnb4BitCompileTests(QuantCompileTests, unittest.TestCase):
             components_to_quantize=["transformer", "text_encoder_2"],
         )

+    @require_bitsandbytes_version_greater("0.46.1")
     def test_torch_compile(self):
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
         super().test_torch_compile()
diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
index 64f56b02b0..8ddbf11cfd 100644
--- a/tests/quantization/bnb/test_mixed_int8.py
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -847,6 +847,10 @@ class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase):
             components_to_quantize=["transformer", "text_encoder_2"],
         )

+    @pytest.mark.xfail(
+        reason="Test fails because of an offloading problem from Accelerate with confusion in hooks."
+        " Test passes without the recompilation context manager. Refer to https://github.com/huggingface/diffusers/pull/12002/files#r2240462757 for details."
+    )
     def test_torch_compile(self):
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
         super()._test_torch_compile(torch_dtype=torch.float16)
diff --git a/tests/quantization/test_torch_compile_utils.py b/tests/quantization/test_torch_compile_utils.py
index c742927646..91ed173fc6 100644
--- a/tests/quantization/test_torch_compile_utils.py
+++ b/tests/quantization/test_torch_compile_utils.py
@@ -56,12 +56,18 @@ class QuantCompileTests:
         pipe.transformer.compile(fullgraph=True)

         # small resolutions to ensure speedy execution.
-        pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
+        with torch._dynamo.config.patch(error_on_recompile=True):
+            pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)

     def _test_torch_compile_with_cpu_offload(self, torch_dtype=torch.bfloat16):
         pipe = self._init_pipeline(self.quantization_config, torch_dtype)
         pipe.enable_model_cpu_offload()
-        pipe.transformer.compile()
+        # regional compilation is better for offloading.
+        # see: https://pytorch.org/blog/torch-compile-and-diffusers-a-hands-on-guide-to-peak-performance/
+        if getattr(pipe.transformer, "_repeated_blocks", None):
+            pipe.transformer.compile_repeated_blocks(fullgraph=True)
+        else:
+            pipe.transformer.compile()

         # small resolutions to ensure speedy execution.
         pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)