From 169bb0df9ce724f2adc5d11c76454541a515a685 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Wed, 16 Apr 2025 14:24:08 +0200
Subject: [PATCH] cache context refactor; address review pt. 3

---
 src/diffusers/hooks/hooks.py                  |  8 ++--
 src/diffusers/models/cache_utils.py           | 22 ++++------
 .../pipelines/cogview4/pipeline_cogview4.py   | 35 +++++++--------
 src/diffusers/pipelines/flux/pipeline_flux.py | 44 +++++++++----------
 .../hunyuan_video/pipeline_hunyuan_video.py   | 38 ++++++++--------
 src/diffusers/pipelines/ltx/pipeline_ltx.py   | 28 ++++++------
 .../pipelines/ltx/pipeline_ltx_condition.py   | 22 +++++-----
 .../pipelines/ltx/pipeline_ltx_image2video.py | 28 ++++++------
 src/diffusers/pipelines/wan/pipeline_wan.py   | 28 ++++++------
 9 files changed, 123 insertions(+), 130 deletions(-)

diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py
index 16e80add84..4ca5761f75 100644
--- a/src/diffusers/hooks/hooks.py
+++ b/src/diffusers/hooks/hooks.py
@@ -160,7 +160,7 @@ class ModelHook:
             raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
         return module
 
-    def _mark_state(self, module: torch.nn.Module, name: str) -> None:
+    def _set_context(self, module: torch.nn.Module, name: str) -> None:
         # Iterate over all attributes of the hook to see if any of them have the type `ContextAwareState`. If so, call `set_context` on them.
         for attr_name in dir(self):
             attr = getattr(self, attr_name)
@@ -293,18 +293,18 @@ class HookRegistry:
             module._diffusers_hook = cls(module)
         return module._diffusers_hook
 
-    def _mark_state(self, name: str) -> None:
+    def _set_context(self, name: Optional[str] = None) -> None:
         for hook_name in reversed(self._hook_order):
             hook = self.hooks[hook_name]
             if hook._is_stateful:
-                hook._mark_state(self._module_ref, name)
+                hook._set_context(self._module_ref, name)
 
         for module_name, module in unwrap_module(self._module_ref).named_modules():
             if module_name == "":
                 continue
             module = unwrap_module(module)
             if hasattr(module, "_diffusers_hook"):
-                module._diffusers_hook._mark_state(name)
+                module._diffusers_hook._set_context(name)
 
     def __repr__(self) -> str:
         registry_repr = ""
diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py
index 7ff9f6b84c..b251850ced 100644
--- a/src/diffusers/models/cache_utils.py
+++ b/src/diffusers/models/cache_utils.py
@@ -118,19 +118,13 @@ class CacheMixin:
         HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse)
 
     @contextmanager
-    def _cache_context(self):
-        r"""Context manager that provides additional methods for cache management."""
-        cache_context = _CacheContextManager(self)
-        yield cache_context
-
-
-class _CacheContextManager:
-    def __init__(self, model: CacheMixin):
-        self.model = model
-
-    def set_context(self, name: str) -> None:
+    def cache_context(self, name: str):
+        r"""Context manager that sets the named cache context for the enclosed forward passes and clears it on exit."""
         from ..hooks import HookRegistry
 
-        if self.model.is_cache_enabled:
-            registry = HookRegistry.check_if_exists_or_initialize(self.model)
-            registry._mark_state(name)
+        if self.is_cache_enabled:
+            registry = HookRegistry.check_if_exists_or_initialize(self)
+            registry._set_context(name)
+        yield
+        if self.is_cache_enabled:
+            registry._set_context(None)
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index 46cb39e2a4..c3a6d7991b 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -608,7 +608,7 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
         transformer_dtype = self.transformer.dtype
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
 
-        with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
@@ -619,24 +619,10 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0])
 
-                cc.set_context("cond")
-                noise_pred_cond = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    original_size=original_size,
-                    target_size=target_size,
-                    crop_coords=crops_coords_top_left,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                # perform guidance
-                if self.do_classifier_free_guidance:
-                    cc.set_context("uncond")
-                    noise_pred_uncond = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred_cond = self.transformer(
                         hidden_states=latent_model_input,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
                         timestep=timestep,
                         original_size=original_size,
                         target_size=target_size,
@@ -645,6 +631,19 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
                         return_dict=False,
                     )[0]
 
+                # perform guidance
+                if self.do_classifier_free_guidance:
+                    with self.transformer.cache_context("uncond"):
+                        noise_pred_uncond = self.transformer(
+                            hidden_states=latent_model_input,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            timestep=timestep,
+                            original_size=original_size,
+                            target_size=target_size,
+                            crop_coords=crops_coords_top_left,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
                 else:
                     noise_pred = noise_pred_cond
diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py
index e9155fd640..cfd0eb2715 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux.py
@@ -906,7 +906,7 @@ class FluxPipeline(
             )
 
         # 6. Denoising loop
-        with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
@@ -917,35 +917,35 @@ class FluxPipeline(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
 
-                cc.set_context("cond")
-                noise_pred = self.transformer(
-                    hidden_states=latents,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds,
-                    encoder_hidden_states=prompt_embeds,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if do_true_cfg:
-                    if negative_image_embeds is not None:
-                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
-
-                    cc.set_context("uncond")
-                    neg_noise_pred = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
                         hidden_states=latents,
                         timestep=timestep / 1000,
                         guidance=guidance,
-                        pooled_projections=negative_pooled_prompt_embeds,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        pooled_projections=pooled_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
                         txt_ids=text_ids,
                         img_ids=latent_image_ids,
                         joint_attention_kwargs=self.joint_attention_kwargs,
                         return_dict=False,
                     )[0]
+
+                if do_true_cfg:
+                    if negative_image_embeds is not None:
+                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+
+                    with self.transformer.cache_context("uncond"):
+                        neg_noise_pred = self.transformer(
+                            hidden_states=latents,
+                            timestep=timestep / 1000,
+                            guidance=guidance,
+                            pooled_projections=negative_pooled_prompt_embeds,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            txt_ids=text_ids,
+                            img_ids=latent_image_ids,
+                            joint_attention_kwargs=self.joint_attention_kwargs,
+                            return_dict=False,
+                        )[0]
                     noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
 
                 # compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
index 2355b90100..5e60b29c31 100644
--- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
+++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
@@ -683,7 +683,7 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
 
-        with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
@@ -693,30 +693,30 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
                 latent_model_input = latents.to(transformer_dtype)
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
-                cc.set_context("cond")
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    encoder_attention_mask=prompt_attention_mask,
-                    pooled_projections=pooled_prompt_embeds,
-                    guidance=guidance,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if do_true_cfg:
-                    cc.set_context("uncond")
-                    neg_noise_pred = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
                         hidden_states=latent_model_input,
                         timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
-                        encoder_attention_mask=negative_prompt_attention_mask,
-                        pooled_projections=negative_pooled_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
+                        encoder_attention_mask=prompt_attention_mask,
+                        pooled_projections=pooled_prompt_embeds,
                         guidance=guidance,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
+
+                if do_true_cfg:
+                    with self.transformer.cache_context("uncond"):
+                        neg_noise_pred = self.transformer(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            encoder_attention_mask=negative_prompt_attention_mask,
+                            pooled_projections=negative_pooled_prompt_embeds,
+                            guidance=guidance,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
                     noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
 
                 # compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py
index 3f3881e49f..81df4ca938 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py
@@ -706,7 +706,7 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
         )
 
         # 7. Denoising loop
-        with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
@@ -719,19 +719,19 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0])
 
-                cc.set_context("cond_uncond")
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    encoder_attention_mask=prompt_attention_mask,
-                    num_frames=latent_num_frames,
-                    height=latent_height,
-                    width=latent_width,
-                    rope_interpolation_scale=rope_interpolation_scale,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond_uncond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep=timestep,
+                        encoder_attention_mask=prompt_attention_mask,
+                        num_frames=latent_num_frames,
+                        height=latent_height,
+                        width=latent_width,
+                        rope_interpolation_scale=rope_interpolation_scale,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
+                    )[0]
                 noise_pred = noise_pred.float()
 
                 if self.do_classifier_free_guidance:
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
index 5458b473b5..481ed0fd55 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
@@ -1072,7 +1072,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
         self._num_timesteps = len(timesteps)
 
         # 6. Denoising loop
-        with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
@@ -1105,16 +1105,16 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
                 if is_conditioning_image_or_video:
                     timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
 
-                cc.set_context("cond_uncond")
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    encoder_attention_mask=prompt_attention_mask,
-                    video_coords=video_coords,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond_uncond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep=timestep,
+                        encoder_attention_mask=prompt_attention_mask,
+                        video_coords=video_coords,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
+                    )[0]
 
                 if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
index 1317acd8ba..acd500f9fb 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
@@ -778,7 +778,7 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
         )
 
         # 7. Denoising loop
-        with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
@@ -792,19 +792,19 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
                 timestep = t.expand(latent_model_input.shape[0])
                 timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
 
-                cc.set_context("cond_uncond")
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    encoder_attention_mask=prompt_attention_mask,
-                    num_frames=latent_num_frames,
-                    height=latent_height,
-                    width=latent_width,
-                    rope_interpolation_scale=rope_interpolation_scale,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond_uncond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep=timestep,
+                        encoder_attention_mask=prompt_attention_mask,
+                        num_frames=latent_num_frames,
+                        height=latent_height,
+                        width=latent_width,
+                        rope_interpolation_scale=rope_interpolation_scale,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
+                    )[0]
                 noise_pred = noise_pred.float()
 
                 if self.do_classifier_free_guidance:
diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py
index 78de45f4ab..59d07fa24f 100644
--- a/src/diffusers/pipelines/wan/pipeline_wan.py
+++ b/src/diffusers/pipelines/wan/pipeline_wan.py
@@ -521,7 +521,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
 
-        with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
@@ -530,24 +530,24 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                 latent_model_input = latents.to(transformer_dtype)
                 timestep = t.expand(latents.shape[0])
 
-                cc.set_context("cond")
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if self.do_classifier_free_guidance:
-                    cc.set_context("uncond")
-                    noise_uncond = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
                         hidden_states=latent_model_input,
                         timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
+
+                if self.do_classifier_free_guidance:
+                    with self.transformer.cache_context("uncond"):
+                        noise_uncond = self.transformer(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
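
Note for reviewers: below is a minimal, self-contained sketch of the semantics
`CacheMixin.cache_context` has after this patch. The `ToyCachedTransformer`
class is hypothetical and only stands in for a hooked diffusers model so the
behavior is observable without loading a pipeline; the real implementation
forwards to `HookRegistry._set_context`, as shown in the cache_utils.py hunk.

    from contextlib import contextmanager

    class ToyCachedTransformer:
        def __init__(self):
            self.is_cache_enabled = True
            self.current_context = None  # stands in for the context name stored on the hooks

        @contextmanager
        def cache_context(self, name):
            # Mirrors the patched CacheMixin.cache_context: mark the active
            # context on entry, clear it again (set to None) on exit.
            if self.is_cache_enabled:
                self.current_context = name  # real code: registry._set_context(name)
            yield
            if self.is_cache_enabled:
                self.current_context = None  # real code: registry._set_context(None)

    transformer = ToyCachedTransformer()

    # One context per guidance branch, matching the pipeline call sites above.
    with transformer.cache_context("cond"):
        assert transformer.current_context == "cond"
    with transformer.cache_context("uncond"):
        assert transformer.current_context == "uncond"
    assert transformer.current_context is None  # cleared after each block

Compared to the old `cc.set_context(...)` calls, the context manager pairs every
set with a reset, so a cache context can no longer leak past the block that set it.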