Merge branch 'main' into custom-code-updates
setup.py
@@ -116,7 +116,7 @@ _deps = [
     "librosa",
     "numpy",
     "parameterized",
-    "peft>=0.15.0",
+    "peft>=0.17.0",
     "protobuf>=3.20.3,<4",
     "pytest",
     "pytest-timeout",
@@ -139,6 +139,7 @@ else:
         "AutoGuidance",
         "ClassifierFreeGuidance",
         "ClassifierFreeZeroStarGuidance",
+        "FrequencyDecoupledGuidance",
         "PerturbedAttentionGuidance",
         "SkipLayerGuidance",
         "SmoothedEnergyGuidance",
@@ -804,6 +805,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         AutoGuidance,
         ClassifierFreeGuidance,
         ClassifierFreeZeroStarGuidance,
+        FrequencyDecoupledGuidance,
         PerturbedAttentionGuidance,
         SkipLayerGuidance,
         SmoothedEnergyGuidance,
@@ -23,7 +23,7 @@ deps = {
     "librosa": "librosa",
     "numpy": "numpy",
     "parameterized": "parameterized",
-    "peft": "peft>=0.15.0",
+    "peft": "peft>=0.17.0",
     "protobuf": "protobuf>=3.20.3,<4",
     "pytest": "pytest",
     "pytest-timeout": "pytest-timeout",
@@ -22,6 +22,7 @@ if is_torch_available():
     from .auto_guidance import AutoGuidance
     from .classifier_free_guidance import ClassifierFreeGuidance
     from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
+    from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
     from .perturbed_attention_guidance import PerturbedAttentionGuidance
     from .skip_layer_guidance import SkipLayerGuidance
     from .smoothed_energy_guidance import SmoothedEnergyGuidance
@@ -32,6 +33,7 @@ if is_torch_available():
         AutoGuidance,
         ClassifierFreeGuidance,
         ClassifierFreeZeroStarGuidance,
+        FrequencyDecoupledGuidance,
         PerturbedAttentionGuidance,
         SkipLayerGuidance,
         SmoothedEnergyGuidance,
src/diffusers/guiders/frequency_decoupled_guidance.py (new file, 327 lines)
@@ -0,0 +1,327 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch

from ..configuration_utils import register_to_config
from ..utils import is_kornia_available
from .guider_utils import BaseGuidance, rescale_noise_cfg


if TYPE_CHECKING:
    from ..modular_pipelines.modular_pipeline import BlockState


_CAN_USE_KORNIA = is_kornia_available()


if _CAN_USE_KORNIA:
    from kornia.geometry import pyrup as upsample_and_blur_func
    from kornia.geometry.transform import build_laplacian_pyramid as build_laplacian_pyramid_func
else:
    upsample_and_blur_func = None
    build_laplacian_pyramid_func = None

def project(v0: torch.Tensor, v1: torch.Tensor, upcast_to_double: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Project vector v0 onto vector v1, returning the parallel and orthogonal components of v0. Implementation from
    the paper (Algorithm 2).
    """
    # v0 shape: [B, ...]
    # v1 shape: [B, ...]
    # Assume the first dim is a batch dim and all other dims are channel or "spatial" dims
    all_dims_but_first = list(range(1, len(v0.shape)))
    if upcast_to_double:
        dtype = v0.dtype
        v0, v1 = v0.double(), v1.double()
    v1 = torch.nn.functional.normalize(v1, dim=all_dims_but_first)
    v0_parallel = (v0 * v1).sum(dim=all_dims_but_first, keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    if upcast_to_double:
        v0_parallel = v0_parallel.to(dtype)
        v0_orthogonal = v0_orthogonal.to(dtype)
    return v0_parallel, v0_orthogonal

def build_image_from_pyramid(pyramid: List[torch.Tensor]) -> torch.Tensor:
    """
    Recovers the data space latents from the Laplacian pyramid frequency space. Implementation from the paper
    (Algorithm 2).
    """
    # pyramid shapes: [[B, C, H, W], [B, C, H/2, W/2], ...]
    img = pyramid[-1]
    for i in range(len(pyramid) - 2, -1, -1):
        img = upsample_and_blur_func(img) + pyramid[i]
    return img

class FrequencyDecoupledGuidance(BaseGuidance):
    """
    Frequency-Decoupled Guidance (FDG): https://huggingface.co/papers/2506.19713

    FDG is a technique similar to (and based on) classifier-free guidance (CFG) which is used to improve generation
    quality and condition-following in diffusion models. Like CFG, during training we jointly train the model on both
    conditional and unconditional data, and use a combination of the two during inference. (If you want more details
    on how CFG works, you can check out the CFG guider.)

    FDG differs from CFG in that the normal CFG prediction is instead decoupled into low- and high-frequency
    components using a frequency transform (such as a Laplacian pyramid). The CFG update is then performed in
    frequency space separately for the low- and high-frequency components, with different guidance scales. Finally,
    the inverse frequency transform is used to map the CFG frequency predictions back to data space (e.g. pixel space
    for images) to form the final FDG prediction.

    For images, the FDG authors found that using low guidance scales for the low-frequency components retains sample
    diversity and realistic color composition, while using high guidance scales for high-frequency components
    enhances sample quality (such as better visual details). Therefore, they recommend using low guidance scales
    (low w_low) for the low-frequency components and high guidance scales (high w_high) for the high-frequency
    components. As an example, they suggest w_low = 5.0 and w_high = 10.0 for Stable Diffusion XL (see Table 8 in
    the paper).

    As with CFG, Diffusers implements the scaling and shifting on the unconditional prediction based on the [Imagen
    paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original CFG paper proposed in
    theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]

    The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
    paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.

    Args:
        guidance_scales (`List[float]`, defaults to `[10.0, 5.0]`):
            The scale parameter for frequency-decoupled guidance for each frequency component, listed from highest
            frequency level to lowest. Higher values result in stronger conditioning on the text prompt, while lower
            values allow for more freedom in generation. Higher values may lead to saturation and deterioration of
            image quality. The FDG authors recommend using higher guidance scales for higher frequency components and
            lower guidance scales for lower frequency components (so `guidance_scales` should typically be sorted in
            descending order).
        guidance_rescale (`float` or `List[float]`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891). If a list is supplied, it should be the same length as
            `guidance_scales`.
        parallel_weights (`float` or `List[float]`, *optional*):
            Optional weights for the parallel component of each frequency component of the projected CFG shift. If
            not set, the weights will default to `1.0` for all components, which corresponds to using the normal CFG
            shift (that is, equal weights for the parallel and orthogonal components). If set, a value in `[0, 1]` is
            recommended. If a list is supplied, it should be the same length as `guidance_scales`.
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float` or `List[float]`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts. If a list is supplied,
            it should be the same length as `guidance_scales`.
        stop (`float` or `List[float]`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops. If a list is supplied, it
            should be the same length as `guidance_scales`.
        guidance_rescale_space (`str`, defaults to `"data"`):
            Whether to perform guidance rescaling in `"data"` space (after the full FDG update in data space) or in
            `"freq"` space (right after the CFG update, for each frequency level). Note that frequency space
            rescaling is speculative and may not produce expected results. If `"data"` is set, the first
            `guidance_rescale` value will be used; otherwise, per-frequency-level guidance rescale values will be
            used if available.
        upcast_to_double (`bool`, defaults to `True`):
            Whether to upcast certain operations, such as the projection operation when using `parallel_weights`, to
            float64 when performing guidance. This may result in better quality at the cost of increased runtime.
    """

    _input_predictions = ["pred_cond", "pred_uncond"]

    @register_to_config
    def __init__(
        self,
        guidance_scales: Union[List[float], Tuple[float]] = [10.0, 5.0],
        guidance_rescale: Union[float, List[float], Tuple[float]] = 0.0,
        parallel_weights: Optional[Union[float, List[float], Tuple[float]]] = None,
        use_original_formulation: bool = False,
        start: Union[float, List[float], Tuple[float]] = 0.0,
        stop: Union[float, List[float], Tuple[float]] = 1.0,
        guidance_rescale_space: str = "data",
        upcast_to_double: bool = True,
    ):
        if not _CAN_USE_KORNIA:
            raise ImportError(
                "The `FrequencyDecoupledGuidance` guider cannot be instantiated because the `kornia` library on which "
                "it depends is not available in the current environment. You can install `kornia` with `pip install "
                "kornia`."
            )

        # Set start to the earliest start for any freq component and stop to the latest stop for any freq component
        min_start = start if isinstance(start, float) else min(start)
        max_stop = stop if isinstance(stop, float) else max(stop)
        super().__init__(min_start, max_stop)

        self.guidance_scales = guidance_scales
        self.levels = len(guidance_scales)

        if isinstance(guidance_rescale, float):
            self.guidance_rescale = [guidance_rescale] * self.levels
        elif len(guidance_rescale) == self.levels:
            self.guidance_rescale = guidance_rescale
        else:
            raise ValueError(
                f"`guidance_rescale` has length {len(guidance_rescale)} but should have the same length as "
                f"`guidance_scales` ({len(self.guidance_scales)})"
            )
        # Whether to perform guidance rescaling in frequency space (right after the CFG update) or data space (after
        # transforming from frequency space back to data space)
        if guidance_rescale_space not in ["data", "freq"]:
            raise ValueError(
                f"Guidance rescale space is {guidance_rescale_space} but must be one of `data` or `freq`."
            )
        self.guidance_rescale_space = guidance_rescale_space

        if parallel_weights is None:
            # Use the normal CFG shift (equal weights for parallel and orthogonal components)
            self.parallel_weights = [1.0] * self.levels
        elif isinstance(parallel_weights, float):
            self.parallel_weights = [parallel_weights] * self.levels
        elif len(parallel_weights) == self.levels:
            self.parallel_weights = parallel_weights
        else:
            raise ValueError(
                f"`parallel_weights` has length {len(parallel_weights)} but should have the same length as "
                f"`guidance_scales` ({len(self.guidance_scales)})"
            )

        self.use_original_formulation = use_original_formulation
        self.upcast_to_double = upcast_to_double

        if isinstance(start, float):
            self.guidance_start = [start] * self.levels
        elif len(start) == self.levels:
            self.guidance_start = start
        else:
            raise ValueError(
                f"`start` has length {len(start)} but should have the same length as `guidance_scales` "
                f"({len(self.guidance_scales)})"
            )
        if isinstance(stop, float):
            self.guidance_stop = [stop] * self.levels
        elif len(stop) == self.levels:
            self.guidance_stop = stop
        else:
            raise ValueError(
                f"`stop` has length {len(stop)} but should have the same length as `guidance_scales` "
                f"({len(self.guidance_scales)})"
            )

    def prepare_inputs(
        self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
    ) -> List["BlockState"]:
        if input_fields is None:
            input_fields = self._input_fields

        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for i in range(self.num_conditions):
            data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
            data_batches.append(data_batch)
        return data_batches

    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
        pred = None

        if not self._is_fdg_enabled():
            pred = pred_cond
        else:
            # Apply the frequency transform (e.g. Laplacian pyramid) to the conditional and unconditional predictions.
            pred_cond_pyramid = build_laplacian_pyramid_func(pred_cond, self.levels)
            pred_uncond_pyramid = build_laplacian_pyramid_func(pred_uncond, self.levels)

            # From high frequencies to low frequencies, following the paper implementation
            pred_guided_pyramid = []
            parameters = zip(self.guidance_scales, self.parallel_weights, self.guidance_rescale)
            for level, (guidance_scale, parallel_weight, guidance_rescale) in enumerate(parameters):
                if self._is_fdg_enabled_for_level(level):
                    # Get the cond/uncond preds (in freq space) at the current frequency level
                    pred_cond_freq = pred_cond_pyramid[level]
                    pred_uncond_freq = pred_uncond_pyramid[level]

                    shift = pred_cond_freq - pred_uncond_freq

                    # Apply parallel weights, if used (1.0 corresponds to using the normal CFG shift)
                    if not math.isclose(parallel_weight, 1.0):
                        shift_parallel, shift_orthogonal = project(shift, pred_cond_freq, self.upcast_to_double)
                        shift = parallel_weight * shift_parallel + shift_orthogonal

                    # Apply the CFG update for the current frequency level
                    pred = pred_cond_freq if self.use_original_formulation else pred_uncond_freq
                    pred = pred + guidance_scale * shift

                    if self.guidance_rescale_space == "freq" and guidance_rescale > 0.0:
                        pred = rescale_noise_cfg(pred, pred_cond_freq, guidance_rescale)

                    # Add the current FDG guided level to the FDG prediction pyramid
                    pred_guided_pyramid.append(pred)
                else:
                    # Add the current conditional pyramid level as the "non-FDG" prediction
                    # (index the pyramid directly; pred_cond_freq is only bound in the guided branch)
                    pred_guided_pyramid.append(pred_cond_pyramid[level])

            # Convert from frequency space back to data (e.g. pixel) space by applying the inverse freq transform
            pred = build_image_from_pyramid(pred_guided_pyramid)

            # If rescaling in data space, use the first elem of self.guidance_rescale as the "global" rescale value
            # across all freq levels
            if self.guidance_rescale_space == "data" and self.guidance_rescale[0] > 0.0:
                pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale[0])

        return pred, {}

    @property
    def is_conditional(self) -> bool:
        return self._count_prepared == 1

    @property
    def num_conditions(self) -> int:
        num_conditions = 1
        if self._is_fdg_enabled():
            num_conditions += 1
        return num_conditions

    def _is_fdg_enabled(self) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = all(math.isclose(guidance_scale, 0.0) for guidance_scale in self.guidance_scales)
        else:
            is_close = all(math.isclose(guidance_scale, 1.0) for guidance_scale in self.guidance_scales)

        return is_within_range and not is_close

    def _is_fdg_enabled_for_level(self, level: int) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self.guidance_start[level] * self._num_inference_steps)
            skip_stop_step = int(self.guidance_stop[level] * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scales[level], 0.0)
        else:
            is_close = math.isclose(self.guidance_scales[level], 1.0)

        return is_within_range and not is_close
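To make the flow of `forward` concrete, here is a minimal, self-contained sketch (not part of the commit; the guidance scales, parallel weight, and tensor shapes are illustrative) of the decompose / per-level CFG / recompose update, including the optional parallel-weight projection:

import torch
from kornia.geometry import pyrup
from kornia.geometry.transform import build_laplacian_pyramid


def fdg_update(pred_cond, pred_uncond, scales=(10.0, 5.0), parallel_weight=1.0):
    # Decompose both predictions into frequency levels, highest frequency first.
    cond_pyr = build_laplacian_pyramid(pred_cond, len(scales))
    uncond_pyr = build_laplacian_pyramid(pred_uncond, len(scales))
    guided = []
    for cond, uncond, scale in zip(cond_pyr, uncond_pyr, scales):
        shift = cond - uncond
        if parallel_weight != 1.0:
            # Re-weight the component of the shift parallel to the conditional
            # prediction, mirroring project() above (Algorithm 2 in the paper).
            dims = list(range(1, cond.ndim))
            cond_hat = torch.nn.functional.normalize(cond, dim=dims)
            parallel = (shift * cond_hat).sum(dim=dims, keepdim=True) * cond_hat
            shift = parallel_weight * parallel + (shift - parallel)
        # Diffusers-native CFG update at this frequency level.
        guided.append(uncond + scale * shift)
    # Inverse transform: upsample-and-blur each coarser level and add the residual.
    img = guided[-1]
    for level in reversed(guided[:-1]):
        img = pyrup(img) + level
    return img


pred_cond, pred_uncond = torch.randn(2, 1, 4, 64, 64)
print(fdg_update(pred_cond, pred_uncond, parallel_weight=0.5).shape)  # torch.Size([1, 4, 64, 64])

The per-level loop is where FDG departs from plain CFG: with scales (10.0, 5.0), high-frequency detail is guided harder than the low-frequency base, matching the class defaults above.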
@@ -133,6 +133,7 @@ def _register_attention_processors_metadata():
            skip_processor_output_fn=_skip_proc_output_fn_Attention_WanAttnProcessor2_0,
        ),
    )

    # FluxAttnProcessor
    AttentionProcessorRegistry.register(
        model_class=FluxAttnProcessor,
src/diffusers/hooks/utils.py (new file, 43 lines)
@@ -0,0 +1,43 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES


def _get_identifiable_transformer_blocks_in_module(module: torch.nn.Module):
    module_list_with_transformer_blocks = []
    for name, submodule in module.named_modules():
        name_endswith_identifier = any(name.endswith(identifier) for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS)
        is_modulelist = isinstance(submodule, torch.nn.ModuleList)
        if name_endswith_identifier and is_modulelist:
            module_list_with_transformer_blocks.append((name, submodule))
    return module_list_with_transformer_blocks


def _get_identifiable_attention_layers_in_module(module: torch.nn.Module):
    attention_layers = []
    for name, submodule in module.named_modules():
        if isinstance(submodule, _ATTENTION_CLASSES):
            attention_layers.append((name, submodule))
    return attention_layers


def _get_identifiable_feedforward_layers_in_module(module: torch.nn.Module):
    feedforward_layers = []
    for name, submodule in module.named_modules():
        if isinstance(submodule, _FEEDFORWARD_CLASSES):
            feedforward_layers.append((name, submodule))
    return feedforward_layers
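A small illustration (not from the commit; `torch.nn.MultiheadAttention` stands in for the private `_ATTENTION_CLASSES` tuple) of how these `named_modules` scans behave on a toy model:

import torch


def find_layers(module, classes):
    # Same shape as the helpers above: scan named_modules and filter by isinstance.
    return [(name, m) for name, m in module.named_modules() if isinstance(m, classes)]


toy = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.MultiheadAttention(8, 2))
print([name for name, _ in find_layers(toy, torch.nn.MultiheadAttention)])  # ['1']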
@@ -817,7 +817,11 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
     # has both `peft` and non-peft state dict.
     has_peft_state_dict = any(k.startswith("transformer.") for k in state_dict)
     if has_peft_state_dict:
-        state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
+        state_dict = {
+            k.replace("lora_down.weight", "lora_A.weight").replace("lora_up.weight", "lora_B.weight"): v
+            for k, v in state_dict.items()
+            if k.startswith("transformer.")
+        }
         return state_dict

     # Another weird one.
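The effect of the new dict comprehension, sketched on a toy state dict (keys are illustrative): non-`transformer.` keys are dropped and Kohya-style `lora_down`/`lora_up` names are normalized to PEFT's `lora_A`/`lora_B` in a single pass:

state_dict = {
    "transformer.blocks.0.lora_down.weight": "A",
    "transformer.blocks.0.lora_up.weight": "B",
    "text_encoder.lora_down.weight": "dropped",
}
state_dict = {
    k.replace("lora_down.weight", "lora_A.weight").replace("lora_up.weight", "lora_B.weight"): v
    for k, v in state_dict.items()
    if k.startswith("transformer.")
}
print(state_dict)
# {'transformer.blocks.0.lora_A.weight': 'A', 'transformer.blocks.0.lora_B.weight': 'B'}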
@@ -320,7 +320,9 @@ class PeftAdapterMixin:
                 # it to None
                 incompatible_keys = None
             else:
-                inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
+                inject_adapter_in_model(
+                    lora_config, self, adapter_name=adapter_name, state_dict=state_dict, **peft_kwargs
+                )
                 incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)

             if self._prepare_lora_hotswap_kwargs is not None:
@@ -384,7 +384,7 @@ class FluxSingleTransformerBlock(nn.Module):
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         text_seq_len = encoder_hidden_states.shape[1]
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
@@ -310,7 +310,7 @@ class FluxPipeline(
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
+        prompt_2: Optional[Union[str, List[str]]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,
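The same one-line signature change recurs across all the Flux pipelines below: `prompt_2` goes from a required positional argument to an optional one. In miniature (the fallback behavior is an assumption for illustration; the real pipelines resolve `prompt_2` internally):

from typing import List, Optional, Union


def encode_prompt(prompt: Union[str, List[str]], prompt_2: Optional[Union[str, List[str]]] = None):
    prompt_2 = prompt_2 or prompt  # assumed fallback when no second prompt is given
    return prompt, prompt_2


print(encode_prompt("a cat"))  # ('a cat', 'a cat'); previously prompt_2 had to be passed explicitly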
@@ -324,7 +324,7 @@ class FluxControlPipeline(
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
+        prompt_2: Optional[Union[str, List[str]]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,

@@ -335,7 +335,7 @@ class FluxControlImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSin
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
+        prompt_2: Optional[Union[str, List[str]]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,

@@ -374,7 +374,7 @@ class FluxControlInpaintPipeline(
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
+        prompt_2: Optional[Union[str, List[str]]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,

@@ -341,7 +341,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
+        prompt_2: Optional[Union[str, List[str]]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,

@@ -335,7 +335,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
+        prompt_2: Optional[Union[str, List[str]]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,

@@ -346,7 +346,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
+        prompt_2: Optional[Union[str, List[str]]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,

@@ -419,7 +419,7 @@ class FluxFillPipeline(
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
+        prompt_2: Optional[Union[str, List[str]]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,

@@ -333,7 +333,7 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
+        prompt_2: Optional[Union[str, List[str]]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,

@@ -337,7 +337,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
+        prompt_2: Optional[Union[str, List[str]]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,

@@ -358,7 +358,7 @@ class FluxKontextPipeline(
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
+        prompt_2: Optional[Union[str, List[str]]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,

@@ -391,7 +391,7 @@ class FluxKontextInpaintPipeline(
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
+        prompt_2: Optional[Union[str, List[str]]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,

@@ -292,7 +292,7 @@ class FluxPriorReduxPipeline(DiffusionPipeline):
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
+        prompt_2: Optional[Union[str, List[str]]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -201,7 +201,7 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         txt = [template.format(e) for e in prompt]
         txt_tokens = self.tokenizer(
             txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
-        ).to(self.device)
+        ).to(device)
         encoder_hidden_states = self.text_encoder(
             input_ids=txt_tokens.input_ids,
             attention_mask=txt_tokens.attention_mask,
@@ -82,6 +82,7 @@ from .import_utils import (
     is_k_diffusion_available,
     is_k_diffusion_version,
     is_kernels_available,
+    is_kornia_available,
     is_librosa_available,
     is_matplotlib_available,
     is_nltk_available,
@@ -62,6 +62,21 @@ class ClassifierFreeZeroStarGuidance(metaclass=DummyObject):
         requires_backends(cls, ["torch"])


+class FrequencyDecoupledGuidance(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class PerturbedAttentionGuidance(metaclass=DummyObject):
     _backends = ["torch"]
@@ -224,6 +224,7 @@ _cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("
 _sageattention_available, _sageattention_version = _is_package_available("sageattention")
 _flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
 _flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")
+_kornia_available, _kornia_version = _is_package_available("kornia")


 def is_torch_available():
@@ -398,6 +399,10 @@ def is_flash_attn_3_available():
     return _flash_attn_3_available


+def is_kornia_available():
+    return _kornia_available
+
+
 # docstyle-ignore
 FLAX_IMPORT_ERROR = """
 {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -197,20 +197,6 @@ def get_peft_kwargs(
         "lora_bias": lora_bias,
     }

-    # Example: try load FusionX LoRA into Wan VACE
-    exclude_modules = _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name)
-    if exclude_modules:
-        if not is_peft_version(">=", "0.14.0"):
-            msg = """
-It seems like there are certain modules that need to be excluded when initializing `LoraConfig`. Your current `peft`
-version doesn't support passing an `exclude_modules` to `LoraConfig`. Please update it by running `pip install -U
-peft`. For most cases, this can be completely ignored. But if it seems unexpected, please file an issue -
-https://github.com/huggingface/diffusers/issues/new
-"""
-            logger.debug(msg)
-        else:
-            lora_config_kwargs.update({"exclude_modules": exclude_modules})
-
     return lora_config_kwargs
@@ -388,27 +374,3 @@ def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):

     if warn_msg:
         logger.warning(warn_msg)
-
-
-def _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name=None):
-    """
-    Derives the modules to exclude while initializing `LoraConfig` through `exclude_modules`. It works by comparing
-    the `model_state_dict` and `peft_state_dict` and adds a module from `model_state_dict` to the exclusion set if it
-    doesn't exist in `peft_state_dict`.
-    """
-    if model_state_dict is None:
-        return
-    all_modules = set()
-    string_to_replace = f"{adapter_name}." if adapter_name else ""
-
-    for name in model_state_dict.keys():
-        if string_to_replace:
-            name = name.replace(string_to_replace, "")
-        if "." in name:
-            module_name = name.rsplit(".", 1)[0]
-            all_modules.add(module_name)
-
-    target_modules_set = {name.split(".lora")[0] for name in peft_state_dict.keys()}
-    exclude_modules = list(all_modules - target_modules_set)
-
-    return exclude_modules
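For reference, the set difference that the removed helper computed, in a simplified toy version (adapter-name stripping omitted; module names are illustrative):

model_state_dict = {"blocks.0.attn.weight": 0, "blocks.1.attn.weight": 0}
peft_state_dict = {"blocks.0.attn.lora_A.default.weight": 0}

all_modules = {name.rsplit(".", 1)[0] for name in model_state_dict}    # {'blocks.0.attn', 'blocks.1.attn'}
target_modules = {name.split(".lora")[0] for name in peft_state_dict}  # {'blocks.0.attn'}
print(sorted(all_modules - target_modules))                            # ['blocks.1.attn']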
@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import copy
 import inspect
 import os
 import re
@@ -292,20 +291,6 @@ class PeftLoraLoaderMixinTests:

         return modules_to_save

-    def _get_exclude_modules(self, pipe):
-        from diffusers.utils.peft_utils import _derive_exclude_modules
-
-        modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
-        denoiser = "unet" if self.unet_kwargs is not None else "transformer"
-        modules_to_save = {k: v for k, v in modules_to_save.items() if k == denoiser}
-        denoiser_lora_state_dict = self._get_lora_state_dicts(modules_to_save)[f"{denoiser}_lora_layers"]
-        pipe.unload_lora_weights()
-        denoiser_state_dict = pipe.unet.state_dict() if self.unet_kwargs is not None else pipe.transformer.state_dict()
-        exclude_modules = _derive_exclude_modules(
-            denoiser_state_dict, denoiser_lora_state_dict, adapter_name="default"
-        )
-        return exclude_modules
-
     def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
         if text_lora_config is not None:
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -2342,58 +2327,6 @@ class PeftLoraLoaderMixinTests:
             )
             _ = pipe(**inputs, generator=torch.manual_seed(0))[0]

-    @require_peft_version_greater("0.13.2")
-    def test_lora_exclude_modules(self):
-        """
-        Test to check if `exclude_modules` works or not. It works in the following way:
-        we first create a pipeline and insert LoRA config into it. We then derive a `set`
-        of modules to exclude by investigating its denoiser state dict and denoiser LoRA
-        state dict.
-
-        We then create a new LoRA config to include the `exclude_modules` and perform tests.
-        """
-        scheduler_cls = self.scheduler_classes[0]
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
-        pipe = self.pipeline_class(**components).to(torch_device)
-        _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
-
-        # only supported for `denoiser` now
-        pipe_cp = copy.deepcopy(pipe)
-        pipe_cp, _ = self.add_adapters_to_pipeline(
-            pipe_cp, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
-        )
-        denoiser_exclude_modules = self._get_exclude_modules(pipe_cp)
-        pipe_cp.to("cpu")
-        del pipe_cp
-
-        denoiser_lora_config.exclude_modules = denoiser_exclude_modules
-        pipe, _ = self.add_adapters_to_pipeline(
-            pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
-        )
-        output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
-        with tempfile.TemporaryDirectory() as tmpdir:
-            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
-            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
-            lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
-            self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
-            pipe.unload_lora_weights()
-            pipe.load_lora_weights(tmpdir)
-
-            output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
-            self.assertTrue(
-                not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3),
-                "LoRA should change outputs.",
-            )
-            self.assertTrue(
-                np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
-                "Lora outputs should match.",
-            )
-
     def test_inference_load_delete_load_adapters(self):
         "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
         for scheduler_cls in self.scheduler_classes:
@@ -2467,7 +2400,6 @@ class PeftLoraLoaderMixinTests:

         components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
         pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
@@ -2483,6 +2415,10 @@ class PeftLoraLoaderMixinTests:
                 num_blocks_per_group=1,
                 use_stream=use_stream,
             )
+            # Place other model-level components on `torch_device`.
+            for _, component in pipe.components.items():
+                if isinstance(component, torch.nn.Module):
+                    component.to(torch_device)
             group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser)
             self.assertTrue(group_offload_hook_1 is not None)
             output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -20,7 +20,7 @@ import torch
 from diffusers import FluxTransformer2DModel
 from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
 from diffusers.models.embeddings import ImageProjection
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from diffusers.utils.testing_utils import enable_full_determinism, is_peft_available, torch_device

 from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
@@ -172,6 +172,35 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
         expected_set = {"FluxTransformer2DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

+    # The test exists for cases like
+    # https://github.com/huggingface/diffusers/issues/11874
+    @unittest.skipIf(not is_peft_available(), "Only with PEFT")
+    def test_lora_exclude_modules(self):
+        from peft import LoraConfig, get_peft_model_state_dict, inject_adapter_in_model, set_peft_model_state_dict
+
+        lora_rank = 4
+        target_module = "single_transformer_blocks.0.proj_out"
+        adapter_name = "foo"
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        state_dict = model.state_dict()
+        target_mod_shape = state_dict[f"{target_module}.weight"].shape
+        lora_state_dict = {
+            f"{target_module}.lora_A.weight": torch.ones(lora_rank, target_mod_shape[1]) * 22,
+            f"{target_module}.lora_B.weight": torch.ones(target_mod_shape[0], lora_rank) * 33,
+        }
+        # Passing exclude_modules should no longer be necessary (or even passing target_modules, for that matter).
+        config = LoraConfig(
+            r=lora_rank, target_modules=["single_transformer_blocks.0.proj_out"], exclude_modules=["proj_out"]
+        )
+        inject_adapter_in_model(config, model, adapter_name=adapter_name, state_dict=lora_state_dict)
+        set_peft_model_state_dict(model, lora_state_dict, adapter_name)
+        retrieved_lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name)
+        assert len(retrieved_lora_state_dict) == len(lora_state_dict)
+        assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_A.weight"] == 22).all()
+        assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_B.weight"] == 33).all()
+

 class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
     model_class = FluxTransformer2DModel
@@ -886,6 +886,7 @@ class Bnb4BitCompileTests(QuantCompileTests, unittest.TestCase):
             components_to_quantize=["transformer", "text_encoder_2"],
         )

+    @require_bitsandbytes_version_greater("0.46.1")
     def test_torch_compile(self):
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
         super().test_torch_compile()
@@ -847,6 +847,10 @@ class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase):
             components_to_quantize=["transformer", "text_encoder_2"],
         )

+    @pytest.mark.xfail(
+        reason="Test fails because of an offloading problem from Accelerate with confusion in hooks."
+        " Test passes without recompilation context manager. Refer to https://github.com/huggingface/diffusers/pull/12002/files#r2240462757 for details."
+    )
     def test_torch_compile(self):
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
         super()._test_torch_compile(torch_dtype=torch.float16)
@@ -56,12 +56,18 @@ class QuantCompileTests:
         pipe.transformer.compile(fullgraph=True)

         # small resolutions to ensure speedy execution.
-        pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
+        with torch._dynamo.config.patch(error_on_recompile=True):
+            pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)

     def _test_torch_compile_with_cpu_offload(self, torch_dtype=torch.bfloat16):
         pipe = self._init_pipeline(self.quantization_config, torch_dtype)
         pipe.enable_model_cpu_offload()
-        pipe.transformer.compile()
+        # regional compilation is better for offloading.
+        # see: https://pytorch.org/blog/torch-compile-and-diffusers-a-hands-on-guide-to-peak-performance/
+        if getattr(pipe.transformer, "_repeated_blocks"):
+            pipe.transformer.compile_repeated_blocks(fullgraph=True)
+        else:
+            pipe.transformer.compile()

         # small resolutions to ensure speedy execution.
         pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)