mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
refactor
This commit is contained in:
@@ -42,6 +42,8 @@ class AdaptiveProjectedGuidance(GuidanceMixin):
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time.
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scale: float = 7.5,
|
||||
@@ -51,6 +53,8 @@ class AdaptiveProjectedGuidance(GuidanceMixin):
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
|
||||
self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
|
||||
@@ -68,7 +72,7 @@ class AdaptiveProjectedGuidance(GuidanceMixin):
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
pred = None
|
||||
|
||||
if math.isclose(self.guidance_scale, 1.0):
|
||||
if self._is_cfg_enabled():
|
||||
pred = pred_cond
|
||||
else:
|
||||
pred = normalized_guidance(
|
||||
@@ -89,10 +93,16 @@ class AdaptiveProjectedGuidance(GuidanceMixin):
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if not math.isclose(self.guidance_scale, 1.0):
|
||||
if self._is_cfg_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
def _is_cfg_enabled(self) -> bool:
|
||||
if self.use_original_formulation:
|
||||
return not math.isclose(self.guidance_scale, 0.0)
|
||||
else:
|
||||
return not math.isclose(self.guidance_scale, 1.0)
|
||||
|
||||
|
||||
class MomentumBuffer:
|
||||
def __init__(self, momentum: float):
|
||||
|
||||
@@ -56,9 +56,13 @@ class ClassifierFreeGuidance(GuidanceMixin):
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time.
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond"]
|
||||
|
||||
def __init__(
|
||||
self, guidance_scale: float = 7.5, guidance_rescale: float = 0.0, use_original_formulation: bool = False
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
@@ -66,7 +70,7 @@ class ClassifierFreeGuidance(GuidanceMixin):
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
pred = None
|
||||
|
||||
if math.isclose(self.guidance_scale, 1.0):
|
||||
if not self._is_cfg_enabled():
|
||||
pred = pred_cond
|
||||
else:
|
||||
shift = pred_cond - pred_uncond
|
||||
@@ -81,6 +85,12 @@ class ClassifierFreeGuidance(GuidanceMixin):
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if not math.isclose(self.guidance_scale, 1.0):
|
||||
if self._is_cfg_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
def _is_cfg_enabled(self) -> bool:
|
||||
if self.use_original_formulation:
|
||||
return not math.isclose(self.guidance_scale, 0.0)
|
||||
else:
|
||||
return not math.isclose(self.guidance_scale, 1.0)
|
||||
|
||||
@@ -48,6 +48,8 @@ class ClassifierFreeZeroStarGuidance(GuidanceMixin):
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scale: float = 7.5,
|
||||
@@ -55,6 +57,8 @@ class ClassifierFreeZeroStarGuidance(GuidanceMixin):
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.zero_init_steps = zero_init_steps
|
||||
self.guidance_rescale = guidance_rescale
|
||||
@@ -65,7 +69,7 @@ class ClassifierFreeZeroStarGuidance(GuidanceMixin):
|
||||
|
||||
if self._step < self.zero_init_steps:
|
||||
pred = torch.zeros_like(pred_cond)
|
||||
elif math.isclose(self.guidance_scale, 1.0):
|
||||
elif self._is_cfg_enabled():
|
||||
pred = pred_cond
|
||||
else:
|
||||
shift = pred_cond - pred_uncond
|
||||
@@ -85,10 +89,16 @@ class ClassifierFreeZeroStarGuidance(GuidanceMixin):
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if not math.isclose(self.guidance_scale, 1.0):
|
||||
if self._is_cfg_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
def _is_cfg_enabled(self) -> bool:
|
||||
if self.use_original_formulation:
|
||||
return not math.isclose(self.guidance_scale, 0.0)
|
||||
else:
|
||||
return not math.isclose(self.guidance_scale, 1.0)
|
||||
|
||||
|
||||
def cfg_zero_star_scale(cond: torch.Tensor, uncond: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
|
||||
cond = cond.float()
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -25,15 +25,26 @@ logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
class GuidanceMixin:
|
||||
r"""Base mixin class providing the skeleton for implementing guidance techniques."""
|
||||
|
||||
_input_predictions = None
|
||||
|
||||
def __init__(self):
|
||||
self._step: int = None
|
||||
self._num_inference_steps: int = None
|
||||
self._timestep: torch.LongTensor = None
|
||||
self._preds: Dict[str, torch.Tensor] = {}
|
||||
self._num_outputs_prepared: int = 0
|
||||
|
||||
if self._input_predictions is None or not isinstance(self._input_predictions, list):
|
||||
raise ValueError(
|
||||
"`_input_predictions` must be a list of required prediction names for the guidance technique."
|
||||
)
|
||||
|
||||
def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
|
||||
self._step = step
|
||||
self._num_inference_steps = num_inference_steps
|
||||
self._timestep = timestep
|
||||
self._preds = {}
|
||||
self._num_outputs_prepared = 0
|
||||
|
||||
def prepare_models(self, denoiser: torch.nn.Module) -> None:
|
||||
pass
|
||||
@@ -63,15 +74,22 @@ class GuidanceMixin:
|
||||
)
|
||||
return tuple(list_of_inputs)
|
||||
|
||||
def prepare_outputs(self, pred: torch.Tensor) -> None:
|
||||
self._num_outputs_prepared += 1
|
||||
if self._num_outputs_prepared > self.num_conditions:
|
||||
raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
|
||||
key = self._input_predictions[self._num_outputs_prepared - 1]
|
||||
self._preds[key] = pred
|
||||
|
||||
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
|
||||
pass
|
||||
|
||||
def __call__(self, *args) -> Any:
|
||||
if len(args) != self.num_conditions:
|
||||
def __call__(self, **kwargs) -> Any:
|
||||
if len(kwargs) != self.num_conditions:
|
||||
raise ValueError(
|
||||
f"Expected {self.num_conditions} arguments, but got {len(args)}. Please provide the correct number of arguments."
|
||||
f"Expected {self.num_conditions} arguments, but got {len(kwargs)}. Please provide the correct number of arguments."
|
||||
)
|
||||
return self.forward(*args)
|
||||
return self.forward(**kwargs)
|
||||
|
||||
def forward(self, *args, **kwargs) -> Any:
|
||||
raise NotImplementedError("GuidanceMixin::forward must be implemented in subclasses.")
|
||||
@@ -80,6 +98,10 @@ class GuidanceMixin:
|
||||
def num_conditions(self) -> int:
|
||||
raise NotImplementedError("GuidanceMixin::num_conditions must be implemented in subclasses.")
|
||||
|
||||
@property
|
||||
def outputs(self) -> Dict[str, torch.Tensor]:
|
||||
return self._preds
|
||||
|
||||
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
r"""
|
||||
|
||||
@@ -71,6 +71,8 @@ class SkipLayerGuidance(GuidanceMixin):
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scale: float = 7.5,
|
||||
@@ -82,6 +84,8 @@ class SkipLayerGuidance(GuidanceMixin):
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.skip_layer_guidance_scale = skip_layer_guidance_scale
|
||||
self.skip_layer_guidance_start = skip_layer_guidance_start
|
||||
@@ -157,6 +161,18 @@ class SkipLayerGuidance(GuidanceMixin):
|
||||
)
|
||||
return tuple(list_of_inputs)
|
||||
|
||||
def prepare_outputs(self, pred: torch.Tensor) -> None:
|
||||
self._num_outputs_prepared += 1
|
||||
if self._num_outputs_prepared > self.num_conditions:
|
||||
raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
|
||||
key = self._input_predictions[self._num_outputs_prepared - 1]
|
||||
if not self._is_cfg_enabled() and self._is_slg_enabled():
|
||||
# If we're predicting pred_cond and pred_cond_skip only, we need to set the key to pred_cond_skip
|
||||
# to avoid writing into pred_uncond which is not used
|
||||
if self._num_outputs_prepared == 2:
|
||||
key = "pred_cond_skip"
|
||||
self._preds[key] = pred
|
||||
|
||||
def cleanup_models(self, denoiser: torch.nn.Module):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
|
||||
# Remove the hooks after inference
|
||||
@@ -173,16 +189,16 @@ class SkipLayerGuidance(GuidanceMixin):
|
||||
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
|
||||
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
|
||||
|
||||
if math.isclose(self.guidance_scale, 1.0) and math.isclose(self.skip_layer_guidance_scale, 1.0):
|
||||
if not self._is_cfg_enabled() and not self._is_slg_enabled():
|
||||
pred = pred_cond
|
||||
elif math.isclose(self.guidance_scale, 1.0):
|
||||
elif not self._is_cfg_enabled():
|
||||
if skip_start_step < self._step < skip_stop_step:
|
||||
shift = pred_cond - pred_cond_skip
|
||||
pred = pred_cond if self.use_original_formulation else pred_cond_skip
|
||||
pred = pred + self.skip_layer_guidance_scale * shift
|
||||
else:
|
||||
pred = pred_cond
|
||||
elif math.isclose(self.skip_layer_guidance_scale, 1.0):
|
||||
elif not self._is_slg_enabled():
|
||||
shift = pred_cond - pred_uncond
|
||||
pred = pred_cond if self.use_original_formulation else pred_uncond
|
||||
pred = pred + self.guidance_scale * shift
|
||||
@@ -203,12 +219,19 @@ class SkipLayerGuidance(GuidanceMixin):
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if self._is_cfg_enabled():
|
||||
num_conditions += 1
|
||||
if self._is_slg_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
def _is_cfg_enabled(self) -> bool:
|
||||
if self.use_original_formulation:
|
||||
return not math.isclose(self.guidance_scale, 0.0)
|
||||
else:
|
||||
return not math.isclose(self.guidance_scale, 1.0)
|
||||
|
||||
def _is_slg_enabled(self) -> bool:
|
||||
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
|
||||
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
|
||||
|
||||
if not math.isclose(self.guidance_scale, 1.0):
|
||||
num_conditions += 1
|
||||
if not math.isclose(self.skip_layer_guidance_scale, 1.0) and skip_start_step < self._step < skip_stop_step:
|
||||
num_conditions += 1
|
||||
|
||||
return num_conditions
|
||||
return skip_start_step < self._step < skip_stop_step
|
||||
|
||||
@@ -635,7 +635,6 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
|
||||
crops_coords_top_left[0],
|
||||
)
|
||||
|
||||
noise_preds = []
|
||||
for batch_index, (latent, condition, original_size_c, target_size_c, crop_coord_c) in enumerate(
|
||||
zip(latents, prompt_embeds, original_size, target_size, crops_coords_top_left)
|
||||
):
|
||||
@@ -652,9 +651,10 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_preds.append(noise_pred)
|
||||
guidance.prepare_outputs(noise_pred)
|
||||
|
||||
noise_pred = guidance(*noise_preds)
|
||||
outputs = guidance.outputs
|
||||
noise_pred = guidance(**outputs)
|
||||
latents = self.scheduler.step(noise_pred, t, latents[0], return_dict=False)[0]
|
||||
guidance.cleanup_models(self.transformer)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user