1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
Aryan
2025-04-04 01:41:34 +02:00
parent 77324c40c4
commit 46643564a3
6 changed files with 99 additions and 24 deletions

View File

@@ -42,6 +42,8 @@ class AdaptiveProjectedGuidance(GuidanceMixin):
we use the diffusers-native implementation that has been in the codebase for a long time.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
def __init__(
self,
guidance_scale: float = 7.5,
@@ -51,6 +53,8 @@ class AdaptiveProjectedGuidance(GuidanceMixin):
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
):
super().__init__()
self.guidance_scale = guidance_scale
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
@@ -68,7 +72,7 @@ class AdaptiveProjectedGuidance(GuidanceMixin):
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if math.isclose(self.guidance_scale, 1.0):
if self._is_cfg_enabled():
pred = pred_cond
else:
pred = normalized_guidance(
@@ -89,10 +93,16 @@ class AdaptiveProjectedGuidance(GuidanceMixin):
@property
def num_conditions(self) -> int:
num_conditions = 1
if not math.isclose(self.guidance_scale, 1.0):
if self._is_cfg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if self.use_original_formulation:
return not math.isclose(self.guidance_scale, 0.0)
else:
return not math.isclose(self.guidance_scale, 1.0)
class MomentumBuffer:
def __init__(self, momentum: float):

View File

@@ -56,9 +56,13 @@ class ClassifierFreeGuidance(GuidanceMixin):
we use the diffusers-native implementation that has been in the codebase for a long time.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
def __init__(
self, guidance_scale: float = 7.5, guidance_rescale: float = 0.0, use_original_formulation: bool = False
):
super().__init__()
self.guidance_scale = guidance_scale
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
@@ -66,7 +70,7 @@ class ClassifierFreeGuidance(GuidanceMixin):
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if math.isclose(self.guidance_scale, 1.0):
if not self._is_cfg_enabled():
pred = pred_cond
else:
shift = pred_cond - pred_uncond
@@ -81,6 +85,12 @@ class ClassifierFreeGuidance(GuidanceMixin):
@property
def num_conditions(self) -> int:
num_conditions = 1
if not math.isclose(self.guidance_scale, 1.0):
if self._is_cfg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if self.use_original_formulation:
return not math.isclose(self.guidance_scale, 0.0)
else:
return not math.isclose(self.guidance_scale, 1.0)

View File

@@ -48,6 +48,8 @@ class ClassifierFreeZeroStarGuidance(GuidanceMixin):
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
def __init__(
self,
guidance_scale: float = 7.5,
@@ -55,6 +57,8 @@ class ClassifierFreeZeroStarGuidance(GuidanceMixin):
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
):
super().__init__()
self.guidance_scale = guidance_scale
self.zero_init_steps = zero_init_steps
self.guidance_rescale = guidance_rescale
@@ -65,7 +69,7 @@ class ClassifierFreeZeroStarGuidance(GuidanceMixin):
if self._step < self.zero_init_steps:
pred = torch.zeros_like(pred_cond)
elif math.isclose(self.guidance_scale, 1.0):
elif self._is_cfg_enabled():
pred = pred_cond
else:
shift = pred_cond - pred_uncond
@@ -85,10 +89,16 @@ class ClassifierFreeZeroStarGuidance(GuidanceMixin):
@property
def num_conditions(self) -> int:
num_conditions = 1
if not math.isclose(self.guidance_scale, 1.0):
if self._is_cfg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if self.use_original_formulation:
return not math.isclose(self.guidance_scale, 0.0)
else:
return not math.isclose(self.guidance_scale, 1.0)
def cfg_zero_star_scale(cond: torch.Tensor, uncond: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
cond = cond.float()

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
@@ -25,15 +25,26 @@ logger = get_logger(__name__) # pylint: disable=invalid-name
class GuidanceMixin:
r"""Base mixin class providing the skeleton for implementing guidance techniques."""
_input_predictions = None
def __init__(self):
self._step: int = None
self._num_inference_steps: int = None
self._timestep: torch.LongTensor = None
self._preds: Dict[str, torch.Tensor] = {}
self._num_outputs_prepared: int = 0
if self._input_predictions is None or not isinstance(self._input_predictions, list):
raise ValueError(
"`_input_predictions` must be a list of required prediction names for the guidance technique."
)
def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
self._step = step
self._num_inference_steps = num_inference_steps
self._timestep = timestep
self._preds = {}
self._num_outputs_prepared = 0
def prepare_models(self, denoiser: torch.nn.Module) -> None:
pass
@@ -63,15 +74,22 @@ class GuidanceMixin:
)
return tuple(list_of_inputs)
def prepare_outputs(self, pred: torch.Tensor) -> None:
self._num_outputs_prepared += 1
if self._num_outputs_prepared > self.num_conditions:
raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
key = self._input_predictions[self._num_outputs_prepared - 1]
self._preds[key] = pred
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
pass
def __call__(self, *args) -> Any:
if len(args) != self.num_conditions:
def __call__(self, **kwargs) -> Any:
if len(kwargs) != self.num_conditions:
raise ValueError(
f"Expected {self.num_conditions} arguments, but got {len(args)}. Please provide the correct number of arguments."
f"Expected {self.num_conditions} arguments, but got {len(kwargs)}. Please provide the correct number of arguments."
)
return self.forward(*args)
return self.forward(**kwargs)
def forward(self, *args, **kwargs) -> Any:
raise NotImplementedError("GuidanceMixin::forward must be implemented in subclasses.")
@@ -80,6 +98,10 @@ class GuidanceMixin:
def num_conditions(self) -> int:
raise NotImplementedError("GuidanceMixin::num_conditions must be implemented in subclasses.")
@property
def outputs(self) -> Dict[str, torch.Tensor]:
return self._preds
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""

View File

@@ -71,6 +71,8 @@ class SkipLayerGuidance(GuidanceMixin):
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
"""
_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
def __init__(
self,
guidance_scale: float = 7.5,
@@ -82,6 +84,8 @@ class SkipLayerGuidance(GuidanceMixin):
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
):
super().__init__()
self.guidance_scale = guidance_scale
self.skip_layer_guidance_scale = skip_layer_guidance_scale
self.skip_layer_guidance_start = skip_layer_guidance_start
@@ -157,6 +161,18 @@ class SkipLayerGuidance(GuidanceMixin):
)
return tuple(list_of_inputs)
def prepare_outputs(self, pred: torch.Tensor) -> None:
self._num_outputs_prepared += 1
if self._num_outputs_prepared > self.num_conditions:
raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
key = self._input_predictions[self._num_outputs_prepared - 1]
if not self._is_cfg_enabled() and self._is_slg_enabled():
# If we're predicting pred_cond and pred_cond_skip only, we need to set the key to pred_cond_skip
# to avoid writing into pred_uncond which is not used
if self._num_outputs_prepared == 2:
key = "pred_cond_skip"
self._preds[key] = pred
def cleanup_models(self, denoiser: torch.nn.Module):
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
# Remove the hooks after inference
@@ -173,16 +189,16 @@ class SkipLayerGuidance(GuidanceMixin):
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
if math.isclose(self.guidance_scale, 1.0) and math.isclose(self.skip_layer_guidance_scale, 1.0):
if not self._is_cfg_enabled() and not self._is_slg_enabled():
pred = pred_cond
elif math.isclose(self.guidance_scale, 1.0):
elif not self._is_cfg_enabled():
if skip_start_step < self._step < skip_stop_step:
shift = pred_cond - pred_cond_skip
pred = pred_cond if self.use_original_formulation else pred_cond_skip
pred = pred + self.skip_layer_guidance_scale * shift
else:
pred = pred_cond
elif math.isclose(self.skip_layer_guidance_scale, 1.0):
elif not self._is_slg_enabled():
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
@@ -203,12 +219,19 @@ class SkipLayerGuidance(GuidanceMixin):
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
if self._is_slg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if self.use_original_formulation:
return not math.isclose(self.guidance_scale, 0.0)
else:
return not math.isclose(self.guidance_scale, 1.0)
def _is_slg_enabled(self) -> bool:
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
if not math.isclose(self.guidance_scale, 1.0):
num_conditions += 1
if not math.isclose(self.skip_layer_guidance_scale, 1.0) and skip_start_step < self._step < skip_stop_step:
num_conditions += 1
return num_conditions
return skip_start_step < self._step < skip_stop_step

View File

@@ -635,7 +635,6 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
crops_coords_top_left[0],
)
noise_preds = []
for batch_index, (latent, condition, original_size_c, target_size_c, crop_coord_c) in enumerate(
zip(latents, prompt_embeds, original_size, target_size, crops_coords_top_left)
):
@@ -652,9 +651,10 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_preds.append(noise_pred)
guidance.prepare_outputs(noise_pred)
noise_pred = guidance(*noise_preds)
outputs = guidance.outputs
noise_pred = guidance(**outputs)
latents = self.scheduler.step(noise_pred, t, latents[0], return_dict=False)[0]
guidance.cleanup_models(self.transformer)