From f731664773d4dd79471b2f6befd4c3aaa3b4bb85 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 16 Apr 2025 14:03:01 +0200 Subject: [PATCH] address review comments pt. 2 --- src/diffusers/hooks/first_block_cache.py | 4 +- src/diffusers/hooks/hooks.py | 63 +++++++++---------- src/diffusers/models/cache_utils.py | 2 +- .../pipelines/cogview4/pipeline_cogview4.py | 4 +- src/diffusers/pipelines/flux/pipeline_flux.py | 4 +- .../hunyuan_video/pipeline_hunyuan_video.py | 4 +- src/diffusers/pipelines/ltx/pipeline_ltx.py | 2 +- .../pipelines/ltx/pipeline_ltx_condition.py | 2 +- .../pipelines/ltx/pipeline_ltx_image2video.py | 2 +- src/diffusers/pipelines/wan/pipeline_wan.py | 4 +- 10 files changed, 43 insertions(+), 48 deletions(-) diff --git a/src/diffusers/hooks/first_block_cache.py b/src/diffusers/hooks/first_block_cache.py index 6ce4015b63..b232e6465c 100644 --- a/src/diffusers/hooks/first_block_cache.py +++ b/src/diffusers/hooks/first_block_cache.py @@ -21,7 +21,7 @@ from ..utils import get_logger from ..utils.torch_utils import unwrap_module from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS from ._helpers import TransformerBlockRegistry -from .hooks import BaseMarkedState, HookRegistry, ModelHook +from .hooks import ContextAwareState, HookRegistry, ModelHook logger = get_logger(__name__) # pylint: disable=invalid-name @@ -49,7 +49,7 @@ class FirstBlockCacheConfig: threshold: float = 0.05 -class FBCSharedBlockState(BaseMarkedState): +class FBCSharedBlockState(ContextAwareState): def __init__(self) -> None: super().__init__() diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index 814529d0b2..16e80add84 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -31,61 +31,56 @@ class BaseState: ) -class BaseMarkedState(BaseState): +class ContextAwareState(BaseState): def __init__(self, init_args=None, init_kwargs=None): super().__init__() self._init_args = init_args if init_args is not None else () self._init_kwargs = init_kwargs if init_kwargs is not None else {} - self._mark_name = None + self._current_context = None self._state_cache = {} - def get_current_state(self) -> "BaseMarkedState": - if self._mark_name is None: - # If no mark name is set, simply return a dummy object since we're not going to be using it + def get_state(self) -> "ContextAwareState": + if self._current_context is None: + # If no context is set, simply return a dummy object since we're not going to be using it return self - if self._mark_name not in self._state_cache.keys(): - self._state_cache[self._mark_name] = self.__class__(*self._init_args, **self._init_kwargs) - return self._state_cache[self._mark_name] + if self._current_context not in self._state_cache.keys(): + self._state_cache[self._current_context] = ContextAwareState._create_state( + self.__class__, self._init_args, self._init_kwargs + ) + return self._state_cache[self._current_context] - def mark_state(self, name: str) -> None: - self._mark_name = name + def set_context(self, name: str) -> None: + self._current_context = name def reset(self, *args, **kwargs) -> None: for name, state in list(self._state_cache.items()): state.reset(*args, **kwargs) self._state_cache.pop(name) - self._mark_name = None + self._current_context = None + + @staticmethod + def _create_state(cls, init_args, init_kwargs) -> "ContextAwareState": + return cls(*init_args, **init_kwargs) def __getattribute__(self, name): - direct_attrs = ( - "get_current_state", - "mark_state", - "reset", - "_init_args", - "_init_kwargs", - "_mark_name", - "_state_cache", - ) + # fmt: off + direct_attrs = ("get_state", "set_context", "reset", "_init_args", "_init_kwargs", "_current_context", "_state_cache", "_create_state") + # fmt: on if name in direct_attrs or _is_dunder_method(name): return object.__getattribute__(self, name) else: - current_state = BaseMarkedState.get_current_state(self) + current_state = ContextAwareState.get_state(self) return object.__getattribute__(current_state, name) def __setattr__(self, name, value): - if name in ( - "get_current_state", - "mark_state", - "reset", - "_init_args", - "_init_kwargs", - "_mark_name", - "_state_cache", - ) or _is_dunder_method(name): + # fmt: off + direct_attrs = ("get_state", "set_context", "reset", "_init_args", "_init_kwargs", "_current_context", "_state_cache", "_create_state") + # fmt: on + if name in direct_attrs or _is_dunder_method(name): object.__setattr__(self, name, value) else: - current_state = BaseMarkedState.get_current_state(self) + current_state = ContextAwareState.get_state(self) object.__setattr__(current_state, name, value) @@ -166,11 +161,11 @@ class ModelHook: return module def _mark_state(self, module: torch.nn.Module, name: str) -> None: - # Iterate over all attributes of the hook to see if any of them have the type `BaseMarkedState`. If so, call `mark_state` on them. + # Iterate over all attributes of the hook to see if any of them have the type `ContextAwareState`. If so, call `set_context` on them. for attr_name in dir(self): attr = getattr(self, attr_name) - if isinstance(attr, BaseMarkedState): - attr.mark_state(name) + if isinstance(attr, ContextAwareState): + attr.set_context(name) return module diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index 6c4bcb301d..7ff9f6b84c 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -128,7 +128,7 @@ class _CacheContextManager: def __init__(self, model: CacheMixin): self.model = model - def mark_state(self, name: str) -> None: + def set_context(self, name: str) -> None: from ..hooks import HookRegistry if self.model.is_cache_enabled: diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py index 876ca922a6..46cb39e2a4 100644 --- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py +++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py @@ -619,7 +619,7 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin): # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]) - cc.mark_state("cond") + cc.set_context("cond") noise_pred_cond = self.transformer( hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, @@ -633,7 +633,7 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin): # perform guidance if self.do_classifier_free_guidance: - cc.mark_state("uncond") + cc.set_context("uncond") noise_pred_uncond = self.transformer( hidden_states=latent_model_input, encoder_hidden_states=negative_prompt_embeds, diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index a7195d3a67..e9155fd640 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -917,7 +917,7 @@ class FluxPipeline( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - cc.mark_state("cond") + cc.set_context("cond") noise_pred = self.transformer( hidden_states=latents, timestep=timestep / 1000, @@ -934,7 +934,7 @@ class FluxPipeline( if negative_image_embeds is not None: self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds - cc.mark_state("uncond") + cc.set_context("uncond") neg_noise_pred = self.transformer( hidden_states=latents, timestep=timestep / 1000, diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index b36de61c02..2355b90100 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -693,7 +693,7 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin): # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - cc.mark_state("cond") + cc.set_context("cond") noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, @@ -706,7 +706,7 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin): )[0] if do_true_cfg: - cc.mark_state("uncond") + cc.set_context("uncond") neg_noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index 2fa9fa53e8..3f3881e49f 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -719,7 +719,7 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) - cc.mark_state("cond_uncond") + cc.set_context("cond_uncond") noise_pred = self.transformer( hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py index aa7e8eb559..5458b473b5 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py @@ -1105,7 +1105,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL if is_conditioning_image_or_video: timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0) - cc.mark_state("cond_uncond") + cc.set_context("cond_uncond") noise_pred = self.transformer( hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py index e9d2566a9b..1317acd8ba 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -792,7 +792,7 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo timestep = t.expand(latent_model_input.shape[0]) timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask) - cc.mark_state("cond_uncond") + cc.set_context("cond_uncond") noise_pred = self.transformer( hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py index e38cb34a66..78de45f4ab 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan.py +++ b/src/diffusers/pipelines/wan/pipeline_wan.py @@ -530,7 +530,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin): latent_model_input = latents.to(transformer_dtype) timestep = t.expand(latents.shape[0]) - cc.mark_state("cond") + cc.set_context("cond") noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, @@ -540,7 +540,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin): )[0] if self.do_classifier_free_guidance: - cc.mark_state("uncond") + cc.set_context("uncond") noise_uncond = self.transformer( hidden_states=latent_model_input, timestep=timestep,