
commit 169bb0df9c (parent f731664773)
Author: Aryan
Date:   2025-04-16 14:24:08 +02:00

    cache context refactor; address review pt. 3

9 changed files with 122 additions and 129 deletions

src/diffusers/hooks/hooks.py

@@ -160,7 +160,7 @@ class ModelHook:
             raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
         return module
 
-    def _mark_state(self, module: torch.nn.Module, name: str) -> None:
+    def _set_context(self, module: torch.nn.Module, name: str) -> None:
         # Iterate over all attributes of the hook to see if any of them have the type `ContextAwareState`. If so, call `set_context` on them.
         for attr_name in dir(self):
             attr = getattr(self, attr_name)
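
The hunk cuts off mid-loop; for orientation, a plausible completion following the comment above (a sketch only — `ContextAwareState` and its `set_context` method are taken from that comment, not shown in this diff):

    def _set_context(self, module: torch.nn.Module, name: str) -> None:
        # Forward the context name to every attribute that carries per-context state.
        for attr_name in dir(self):
            attr = getattr(self, attr_name)
            if isinstance(attr, ContextAwareState):  # assumed type, per the comment above
                attr.set_context(name)  # assumed per-state setter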
@@ -293,18 +293,18 @@ class HookRegistry:
             module._diffusers_hook = cls(module)
         return module._diffusers_hook
 
-    def _mark_state(self, name: str) -> None:
+    def _set_context(self, name: Optional[str] = None) -> None:
         for hook_name in reversed(self._hook_order):
             hook = self.hooks[hook_name]
             if hook._is_stateful:
-                hook._mark_state(self._module_ref, name)
+                hook._set_context(self._module_ref, name)
 
         for module_name, module in unwrap_module(self._module_ref).named_modules():
             if module_name == "":
                 continue
             module = unwrap_module(module)
             if hasattr(module, "_diffusers_hook"):
-                module._diffusers_hook._mark_state(name)
+                module._diffusers_hook._set_context(name)
 
     def __repr__(self) -> str:
         registry_repr = ""
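
A minimal usage sketch of the renamed entry point (the `transformer` variable is illustrative; `check_if_exists_or_initialize` and `_set_context` are from the hunks above):

    registry = HookRegistry.check_if_exists_or_initialize(transformer)
    registry._set_context("cond")  # forwarded to stateful hooks on this module and all submodules
    registry._set_context(None)    # the new Optional default clears the context again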

src/diffusers/models/cache_utils.py

@@ -118,19 +118,13 @@ class CacheMixin:
         HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse)
 
     @contextmanager
-    def _cache_context(self):
+    def cache_context(self, name: str):
         r"""Context manager that provides additional methods for cache management."""
-        cache_context = _CacheContextManager(self)
-        yield cache_context
-
-
-class _CacheContextManager:
-    def __init__(self, model: CacheMixin):
-        self.model = model
-
-    def set_context(self, name: str) -> None:
-        from ..hooks import HookRegistry
-
-        if self.model.is_cache_enabled:
-            registry = HookRegistry.check_if_exists_or_initialize(self.model)
-            registry._mark_state(name)
+        if self.is_cache_enabled:
+            registry = HookRegistry.check_if_exists_or_initialize(self)
+            registry._set_context(name)
+
+        yield
+
+        if self.is_cache_enabled:
+            registry._set_context(None)
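
A minimal sketch of the resulting call pattern, mirroring the pipeline changes below (`transformer` stands for any model that mixes in `CacheMixin` and has caching enabled; the transformer call arguments are trimmed):

    with transformer.cache_context("cond"):  # registry._set_context("cond") on entry
        noise_pred_cond = transformer(hidden_states=latents, timestep=timestep, return_dict=False)[0]
    # on exit, cache_context resets the context via registry._set_context(None)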

src/diffusers/pipelines/cogview4/pipeline_cogview4.py

@@ -608,7 +608,7 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
         transformer_dtype = self.transformer.dtype
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
 
-        with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
@@ -619,24 +619,10 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0])
 
-                cc.set_context("cond")
-                noise_pred_cond = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    original_size=original_size,
-                    target_size=target_size,
-                    crop_coords=crops_coords_top_left,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                # perform guidance
-                if self.do_classifier_free_guidance:
-                    cc.set_context("uncond")
-                    noise_pred_uncond = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred_cond = self.transformer(
                         hidden_states=latent_model_input,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
                         timestep=timestep,
                         original_size=original_size,
                         target_size=target_size,
@@ -645,6 +631,19 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
                         return_dict=False,
                     )[0]
 
+                # perform guidance
+                if self.do_classifier_free_guidance:
+                    with self.transformer.cache_context("uncond"):
+                        noise_pred_uncond = self.transformer(
+                            hidden_states=latent_model_input,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            timestep=timestep,
+                            original_size=original_size,
+                            target_size=target_size,
+                            crop_coords=crops_coords_top_left,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
                 else:
                     noise_pred = noise_pred_cond

src/diffusers/pipelines/flux/pipeline_flux.py

@@ -906,7 +906,7 @@ class FluxPipeline(
         )
 
         # 6. Denoising loop
-        with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
@@ -917,35 +917,35 @@ class FluxPipeline(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
 
-                cc.set_context("cond")
-                noise_pred = self.transformer(
-                    hidden_states=latents,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds,
-                    encoder_hidden_states=prompt_embeds,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if do_true_cfg:
-                    if negative_image_embeds is not None:
-                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
-                    cc.set_context("uncond")
-                    neg_noise_pred = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
                         hidden_states=latents,
                         timestep=timestep / 1000,
                         guidance=guidance,
-                        pooled_projections=negative_pooled_prompt_embeds,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        pooled_projections=pooled_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
                         txt_ids=text_ids,
                         img_ids=latent_image_ids,
                         joint_attention_kwargs=self.joint_attention_kwargs,
                         return_dict=False,
                     )[0]
+
+                if do_true_cfg:
+                    if negative_image_embeds is not None:
+                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+                    with self.transformer.cache_context("uncond"):
+                        neg_noise_pred = self.transformer(
+                            hidden_states=latents,
+                            timestep=timestep / 1000,
+                            guidance=guidance,
+                            pooled_projections=negative_pooled_prompt_embeds,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            txt_ids=text_ids,
+                            img_ids=latent_image_ids,
+                            joint_attention_kwargs=self.joint_attention_kwargs,
+                            return_dict=False,
+                        )[0]
 
                     noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
 
                 # compute the previous noisy sample x_t -> x_t-1

src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py

@@ -683,7 +683,7 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
 
-        with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
@@ -693,30 +693,30 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
 
-                cc.set_context("cond")
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    encoder_attention_mask=prompt_attention_mask,
-                    pooled_projections=pooled_prompt_embeds,
-                    guidance=guidance,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if do_true_cfg:
-                    cc.set_context("uncond")
-                    neg_noise_pred = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
                         hidden_states=latent_model_input,
                         timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
-                        encoder_attention_mask=negative_prompt_attention_mask,
-                        pooled_projections=negative_pooled_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
+                        encoder_attention_mask=prompt_attention_mask,
+                        pooled_projections=pooled_prompt_embeds,
                         guidance=guidance,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
+
+                if do_true_cfg:
+                    with self.transformer.cache_context("uncond"):
+                        neg_noise_pred = self.transformer(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            encoder_attention_mask=negative_prompt_attention_mask,
+                            pooled_projections=negative_pooled_prompt_embeds,
+                            guidance=guidance,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
                     noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
 
                 # compute the previous noisy sample x_t -> x_t-1

src/diffusers/pipelines/ltx/pipeline_ltx.py

@@ -706,7 +706,7 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
         )
 
         # 7. Denoising loop
-        with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
@@ -719,19 +719,19 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0])
 
-                cc.set_context("cond_uncond")
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    encoder_attention_mask=prompt_attention_mask,
-                    num_frames=latent_num_frames,
-                    height=latent_height,
-                    width=latent_width,
-                    rope_interpolation_scale=rope_interpolation_scale,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond_uncond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep=timestep,
+                        encoder_attention_mask=prompt_attention_mask,
+                        num_frames=latent_num_frames,
+                        height=latent_height,
+                        width=latent_width,
+                        rope_interpolation_scale=rope_interpolation_scale,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
+                    )[0]
                 noise_pred = noise_pred.float()
 
                 if self.do_classifier_free_guidance:

src/diffusers/pipelines/ltx/pipeline_ltx_condition.py

@@ -1072,7 +1072,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
         self._num_timesteps = len(timesteps)
 
         # 6. Denoising loop
-        with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
@@ -1105,16 +1105,16 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
                 if is_conditioning_image_or_video:
                     timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
 
-                cc.set_context("cond_uncond")
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    encoder_attention_mask=prompt_attention_mask,
-                    video_coords=video_coords,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond_uncond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep=timestep,
+                        encoder_attention_mask=prompt_attention_mask,
+                        video_coords=video_coords,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
+                    )[0]
 
                 if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py

@@ -778,7 +778,7 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
         )
 
         # 7. Denoising loop
-        with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
@@ -792,19 +792,19 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
                 timestep = t.expand(latent_model_input.shape[0])
                 timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
 
-                cc.set_context("cond_uncond")
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    encoder_attention_mask=prompt_attention_mask,
-                    num_frames=latent_num_frames,
-                    height=latent_height,
-                    width=latent_width,
-                    rope_interpolation_scale=rope_interpolation_scale,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond_uncond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep=timestep,
+                        encoder_attention_mask=prompt_attention_mask,
+                        num_frames=latent_num_frames,
+                        height=latent_height,
+                        width=latent_width,
+                        rope_interpolation_scale=rope_interpolation_scale,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
+                    )[0]
                 noise_pred = noise_pred.float()
 
                 if self.do_classifier_free_guidance:

src/diffusers/pipelines/wan/pipeline_wan.py

@@ -521,7 +521,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
 
-        with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
@@ -530,24 +530,24 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                 latent_model_input = latents.to(transformer_dtype)
                 timestep = t.expand(latents.shape[0])
 
-                cc.set_context("cond")
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if self.do_classifier_free_guidance:
-                    cc.set_context("uncond")
-                    noise_uncond = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
                         hidden_states=latent_model_input,
                         timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
+
+                if self.do_classifier_free_guidance:
+                    with self.transformer.cache_context("uncond"):
+                        noise_uncond = self.transformer(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
                     noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1