mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

classifier-free guidance

This commit is contained in:
Aryan
2025-04-03 00:13:15 +02:00
parent c76e1cc17e
commit 594e8d663f
6 changed files with 260 additions and 34 deletions

View File

@@ -33,6 +33,7 @@ from .utils import (
 
 _import_structure = {
     "configuration_utils": ["ConfigMixin"],
+    "guiders": [],
     "hooks": [],
     "loaders": ["FromOriginalModelMixin"],
     "models": [],
@@ -129,6 +130,7 @@ except OptionalDependencyNotAvailable:
     _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
 
 else:
+    _import_structure["guiders"].extend(["ClassifierFreeGuidance"])
     _import_structure["hooks"].extend(
         [
             "FasterCacheConfig",
@@ -710,6 +712,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     except OptionalDependencyNotAvailable:
        from .utils.dummy_pt_objects import *  # noqa F403
     else:
+        from .guiders import ClassifierFreeGuidance
        from .hooks import (
            FasterCacheConfig,
            FirstBlockCacheConfig,
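
With these registrations, the new guider becomes importable from both the package root and the new `diffusers.guiders` subpackage. A quick sanity-check sketch (assuming a torch install; without torch, the dummy object from `utils.dummy_pt_objects` is returned instead):

# Both import paths are registered by this commit.
from diffusers import ClassifierFreeGuidance
from diffusers.guiders import ClassifierFreeGuidance as CFG

guider = ClassifierFreeGuidance(scale=7.5)
print(guider.num_conditions)  # 2 whenever scale != 1.0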

View File

@@ -0,0 +1,20 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..utils import is_torch_available
+
+
+if is_torch_available():
+    from .classifier_free_guidance import ClassifierFreeGuidance
+    from .guider_utils import GuidanceMixin, _raise_guidance_deprecation_warning

View File

@@ -0,0 +1,86 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Optional
+
+import torch
+
+from .guider_utils import GuidanceMixin, rescale_noise_cfg
+
+
+class ClassifierFreeGuidance(GuidanceMixin):
+    """
+    Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598
+
+    CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by
+    jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during
+    inference. This allows the model to trade off between generation quality and sample diversity.
+
+    The original paper proposes scaling and shifting the conditional distribution based on the difference between
+    conditional and unconditional predictions: [x_pred = x_cond + scale * (x_cond - x_uncond)]
+
+    Diffusers implements the scaling and shifting on the unconditional prediction instead, which is theoretically
+    equivalent to the original formulation: [x_pred = x_uncond + scale * (x_cond - x_uncond)]
+
+    The intuition behind the original formulation is that it moves the conditional distribution estimates further
+    away from the unconditional distribution estimates, while the diffusers-native implementation can be thought of
+    as moving the unconditional distribution estimates towards the conditional ones, suppressing the unconditional
+    predictions (usually negative features like "bad quality, bad anatomy, watermarks", etc.)
+
+    The `use_original_formulation` argument can be set to `True` to use the original CFG formulation from the
+    paper. By default, the diffusers-native implementation, which has long been in the codebase, is used.
+
+    Args:
+        scale (`float`, defaults to `7.5`):
+            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the
+            text prompt, while lower values allow for more freedom in generation. Higher values may lead to
+            saturation and deterioration of image quality.
+        rescale (`float`, defaults to `0.0`):
+            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
+            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+            Flawed](https://huggingface.co/papers/2305.08891).
+        use_original_formulation (`bool`, defaults to `False`):
+            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By
+            default, the diffusers-native implementation is used.
+    """
+
+    def __init__(self, scale: float = 7.5, rescale: float = 0.0, use_original_formulation: bool = False):
+        self.scale = scale
+        self.rescale = rescale
+        self.use_original_formulation = use_original_formulation
+
+    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
+        if math.isclose(self.scale, 1.0):
+            return pred_cond
+
+        shift = pred_cond - pred_uncond
+        pred = pred_cond if self.use_original_formulation else pred_uncond
+        pred = pred + self.scale * shift
+
+        if self.rescale > 0.0:
+            pred = rescale_noise_cfg(pred, pred_cond, self.rescale)
+
+        return pred
+
+    @property
+    def num_conditions(self) -> int:
+        if math.isclose(self.scale, 1.0):
+            return 1
+        return 2
+
+    @property
+    def guidance_scale(self) -> float:
+        return self.scale
+
+    @property
+    def guidance_rescale(self) -> float:
+        return self.rescale
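
For reference, a minimal sketch of using this guider on its own, based only on the API shown above (the tensor shapes are illustrative assumptions, not part of this commit):

import torch

from diffusers.guiders import ClassifierFreeGuidance

guidance = ClassifierFreeGuidance(scale=7.5)

# Illustrative noise predictions; any pair of matching shapes works.
pred_cond = torch.randn(1, 4, 64, 64)
pred_uncond = torch.randn(1, 4, 64, 64)

# __call__ validates the argument count against num_conditions (2 here), then
# forward computes pred_uncond + scale * (pred_cond - pred_uncond).
noise_pred = guidance(pred_cond, pred_uncond)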

View File

@@ -0,0 +1,96 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, List, Optional, Tuple, Union
+
+import torch
+
+from ..utils import deprecate, get_logger
+
+
+logger = get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class GuidanceMixin:
+    r"""Base mixin class providing the skeleton for implementing guidance techniques."""
+
+    def prepare_inputs(self, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
+        num_conditions = self.num_conditions
+        list_of_inputs = []
+        for arg in args:
+            if isinstance(arg, torch.Tensor):
+                list_of_inputs.append([arg] * num_conditions)
+            elif isinstance(arg, (tuple, list)):
+                inputs = [x for x in arg if x is not None]
+                if len(inputs) < num_conditions:
+                    raise ValueError(f"Required at least {num_conditions} inputs, but got {len(inputs)}.")
+                list_of_inputs.append(inputs[:num_conditions])
+            else:
+                raise ValueError(
+                    f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list."
+                )
+        return tuple(list_of_inputs)
+
+    def __call__(self, *args) -> Any:
+        if len(args) != self.num_conditions:
+            raise ValueError(
+                f"Expected {self.num_conditions} arguments, but got {len(args)}. Please provide the correct number of arguments."
+            )
+        return self.forward(*args)
+
+    def forward(self, *args, **kwargs) -> Any:
+        raise NotImplementedError("GuidanceMixin::forward must be implemented in subclasses.")
+
+    @property
+    def num_conditions(self) -> int:
+        raise NotImplementedError("GuidanceMixin::num_conditions must be implemented in subclasses.")
+
+
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+    r"""
+    Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+    Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+    Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+    Args:
+        noise_cfg (`torch.Tensor`):
+            The predicted noise tensor for the guided diffusion process.
+        noise_pred_text (`torch.Tensor`):
+            The predicted noise tensor for the text-guided diffusion process.
+        guidance_rescale (`float`, *optional*, defaults to 0.0):
+            A rescale factor applied to the noise predictions.
+
+    Returns:
+        noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
+    """
+    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+    # rescale the results from guidance (fixes overexposure)
+    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+    return noise_cfg
+
+
+def _raise_guidance_deprecation_warning(
+    *,
+    guidance_scale: Optional[float] = None,
+    guidance_rescale: Optional[float] = None,
+) -> None:
+    if guidance_scale is not None:
+        msg = "The `guidance_scale` argument is deprecated and will be removed in version 1.0.0. Please pass a `GuidanceMixin` object for the `guidance` argument instead."
+        deprecate("guidance_scale", "1.0.0", msg, standard_warn=False)
+    if guidance_rescale is not None:
+        msg = "The `guidance_rescale` argument is deprecated and will be removed in version 1.0.0. Please pass a `GuidanceMixin` object for the `guidance` argument instead."
+        deprecate("guidance_rescale", "1.0.0", msg, standard_warn=False)

View File

@@ -21,6 +21,7 @@ import torch
 from transformers import AutoTokenizer, GlmModel
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...guiders import ClassifierFreeGuidance, GuidanceMixin, _raise_guidance_deprecation_warning
 from ...image_processor import VaeImageProcessor
 from ...loaders import CogView4LoraLoaderMixin
 from ...models import AutoencoderKL, CogView4Transformer2DModel
@@ -428,6 +429,7 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
         ] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 1024,
+        guidance: Optional[GuidanceMixin] = None,
     ) -> Union[CogView4PipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
@@ -516,6 +518,10 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
                 `tuple`. When returning a tuple, the first element is a list with the generated images.
         """
+        _raise_guidance_deprecation_warning(guidance_scale=guidance_scale)
+        if guidance is None:
+            guidance = ClassifierFreeGuidance(scale=guidance_scale)
+
         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
@@ -606,52 +612,45 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
         )
         self._num_timesteps = len(timesteps)
 
+        latents, prompt_embeds, original_size, target_size, crops_coords_top_left = guidance.prepare_inputs(
+            latents,
+            (prompt_embeds, negative_prompt_embeds),
+            original_size,
+            target_size,
+            crops_coords_top_left,
+        )
+
         # Denoising loop
         transformer_dtype = self.transformer.dtype
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
 
         with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
             for i, t in enumerate(timesteps):
+                self._current_timestep = t
                 if self.interrupt:
                     continue
 
-                self._current_timestep = t
-                latent_model_input = latents.to(transformer_dtype)
-
-                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latents.shape[0])
-
-                cc.mark_state("cond")
-                noise_pred_cond = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    original_size=original_size,
-                    target_size=target_size,
-                    crop_coords=crops_coords_top_left,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                # perform guidance
-                if self.do_classifier_free_guidance:
-                    cc.mark_state("uncond")
-                    noise_pred_uncond = self.transformer(
-                        hidden_states=latent_model_input,
-                        encoder_hidden_states=negative_prompt_embeds,
+                # Run the transformer once per guidance condition (cond, uncond, ...).
+                noise_preds = []
+                for j, (latent, condition, original_size_c, target_size_c, crop_coord_c) in enumerate(
+                    zip(latents, prompt_embeds, original_size, target_size, crops_coords_top_left)
+                ):
+                    cc.mark_state(f"batch_{j}")
+                    latent = latent.to(transformer_dtype)
+                    timestep = t.expand(latent.shape[0])
+                    noise_pred = self.transformer(
+                        hidden_states=latent,
+                        encoder_hidden_states=condition,
                         timestep=timestep,
-                        original_size=original_size,
-                        target_size=target_size,
-                        crop_coords=crops_coords_top_left,
+                        original_size=original_size_c,
+                        target_size=target_size_c,
+                        crop_coords=crop_coord_c,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
+                    noise_preds.append(noise_pred)
 
-                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
-                else:
-                    noise_pred = noise_pred_cond
-
-                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+                noise_pred = guidance(*noise_preds)
+                latents = self.scheduler.step(noise_pred, t, latents[0], return_dict=False)[0]
 
                 # call the callback, if provided
                 if callback_on_step_end is not None:
@@ -660,8 +659,14 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
                         callback_kwargs[k] = locals()[k]
                     callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs)
                     latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    prompt_embeds = [callback_outputs.pop("prompt_embeds", prompt_embeds[0])]
+                    negative_prompt_embeds = [
+                        callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds[0])
+                    ]
+
+                    latents, prompt_embeds = guidance.prepare_inputs(
+                        latents, (prompt_embeds[0], negative_prompt_embeds[0])
+                    )
 
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
@@ -670,6 +675,7 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
                     xm.mark_step()
 
         self._current_timestep = None
+        latents = latents[0]
 
         if not output_type == "latent":
             latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
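
With this change, callers can construct the guider themselves instead of passing the deprecated `guidance_scale` float. A minimal usage sketch of the new `guidance` argument (the checkpoint ID, dtype, and prompt are illustrative assumptions):

import torch

from diffusers import CogView4Pipeline
from diffusers.guiders import ClassifierFreeGuidance

# Illustrative checkpoint; any CogView4 checkpoint would do.
pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Pass a guider object instead of `guidance_scale`.
guidance = ClassifierFreeGuidance(scale=5.0, rescale=0.0)
image = pipe("a photo of an astronaut riding a horse", guidance=guidance).images[0]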

View File

@@ -2,6 +2,21 @@
 from ..utils import DummyObject, requires_backends
 
 
+class ClassifierFreeGuidance(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class FasterCacheConfig(metaclass=DummyObject):
     _backends = ["torch"]