1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

address review comments pt. 2

This commit is contained in:
Aryan
2025-04-16 14:03:01 +02:00
parent 3dde07a647
commit f731664773
10 changed files with 43 additions and 48 deletions

View File

@@ -21,7 +21,7 @@ from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
from ._helpers import TransformerBlockRegistry
from .hooks import BaseMarkedState, HookRegistry, ModelHook
from .hooks import ContextAwareState, HookRegistry, ModelHook
logger = get_logger(__name__) # pylint: disable=invalid-name
@@ -49,7 +49,7 @@ class FirstBlockCacheConfig:
threshold: float = 0.05
class FBCSharedBlockState(BaseMarkedState):
class FBCSharedBlockState(ContextAwareState):
def __init__(self) -> None:
super().__init__()

View File

@@ -31,61 +31,56 @@ class BaseState:
)
class BaseMarkedState(BaseState):
class ContextAwareState(BaseState):
def __init__(self, init_args=None, init_kwargs=None):
super().__init__()
self._init_args = init_args if init_args is not None else ()
self._init_kwargs = init_kwargs if init_kwargs is not None else {}
self._mark_name = None
self._current_context = None
self._state_cache = {}
def get_current_state(self) -> "BaseMarkedState":
if self._mark_name is None:
# If no mark name is set, simply return a dummy object since we're not going to be using it
def get_state(self) -> "ContextAwareState":
if self._current_context is None:
# If no context is set, return this object itself as a placeholder, since no per-context state will be used
return self
if self._mark_name not in self._state_cache.keys():
self._state_cache[self._mark_name] = self.__class__(*self._init_args, **self._init_kwargs)
return self._state_cache[self._mark_name]
if self._current_context not in self._state_cache.keys():
self._state_cache[self._current_context] = ContextAwareState._create_state(
self.__class__, self._init_args, self._init_kwargs
)
return self._state_cache[self._current_context]
def mark_state(self, name: str) -> None:
self._mark_name = name
def set_context(self, name: str) -> None:
self._current_context = name
def reset(self, *args, **kwargs) -> None:
for name, state in list(self._state_cache.items()):
state.reset(*args, **kwargs)
self._state_cache.pop(name)
self._mark_name = None
self._current_context = None
@staticmethod
def _create_state(cls, init_args, init_kwargs) -> "ContextAwareState":
return cls(*init_args, **init_kwargs)
def __getattribute__(self, name):
direct_attrs = (
"get_current_state",
"mark_state",
"reset",
"_init_args",
"_init_kwargs",
"_mark_name",
"_state_cache",
)
# fmt: off
direct_attrs = ("get_state", "set_context", "reset", "_init_args", "_init_kwargs", "_current_context", "_state_cache", "_create_state")
# fmt: on
if name in direct_attrs or _is_dunder_method(name):
return object.__getattribute__(self, name)
else:
current_state = BaseMarkedState.get_current_state(self)
current_state = ContextAwareState.get_state(self)
return object.__getattribute__(current_state, name)
def __setattr__(self, name, value):
if name in (
"get_current_state",
"mark_state",
"reset",
"_init_args",
"_init_kwargs",
"_mark_name",
"_state_cache",
) or _is_dunder_method(name):
# fmt: off
direct_attrs = ("get_state", "set_context", "reset", "_init_args", "_init_kwargs", "_current_context", "_state_cache", "_create_state")
# fmt: on
if name in direct_attrs or _is_dunder_method(name):
object.__setattr__(self, name, value)
else:
current_state = BaseMarkedState.get_current_state(self)
current_state = ContextAwareState.get_state(self)
object.__setattr__(current_state, name, value)
@@ -166,11 +161,11 @@ class ModelHook:
return module
def _mark_state(self, module: torch.nn.Module, name: str) -> None:
# Iterate over all attributes of the hook to see if any of them have the type `BaseMarkedState`. If so, call `mark_state` on them.
# Iterate over all attributes of the hook and call `set_context` on every attribute that is an instance of `ContextAwareState`.
for attr_name in dir(self):
attr = getattr(self, attr_name)
if isinstance(attr, BaseMarkedState):
attr.mark_state(name)
if isinstance(attr, ContextAwareState):
attr.set_context(name)
return module

View File

@@ -128,7 +128,7 @@ class _CacheContextManager:
def __init__(self, model: CacheMixin):
self.model = model
def mark_state(self, name: str) -> None:
def set_context(self, name: str) -> None:
from ..hooks import HookRegistry
if self.model.is_cache_enabled:

View File

@@ -619,7 +619,7 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0])
cc.mark_state("cond")
cc.set_context("cond")
noise_pred_cond = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
@@ -633,7 +633,7 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
# perform guidance
if self.do_classifier_free_guidance:
cc.mark_state("uncond")
cc.set_context("uncond")
noise_pred_uncond = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=negative_prompt_embeds,

View File

@@ -917,7 +917,7 @@ class FluxPipeline(
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
cc.mark_state("cond")
cc.set_context("cond")
noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
@@ -934,7 +934,7 @@ class FluxPipeline(
if negative_image_embeds is not None:
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
cc.mark_state("uncond")
cc.set_context("uncond")
neg_noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,

View File

@@ -693,7 +693,7 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
cc.mark_state("cond")
cc.set_context("cond")
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
@@ -706,7 +706,7 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
)[0]
if do_true_cfg:
cc.mark_state("uncond")
cc.set_context("uncond")
neg_noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,

View File

@@ -719,7 +719,7 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
cc.mark_state("cond_uncond")
cc.set_context("cond_uncond")
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,

View File

@@ -1105,7 +1105,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
if is_conditioning_image_or_video:
timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
cc.mark_state("cond_uncond")
cc.set_context("cond_uncond")
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,

View File

@@ -792,7 +792,7 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
timestep = t.expand(latent_model_input.shape[0])
timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
cc.mark_state("cond_uncond")
cc.set_context("cond_uncond")
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,

View File

@@ -530,7 +530,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
latent_model_input = latents.to(transformer_dtype)
timestep = t.expand(latents.shape[0])
cc.mark_state("cond")
cc.set_context("cond")
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
@@ -540,7 +540,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
)[0]
if self.do_classifier_free_guidance:
cc.mark_state("uncond")
cc.set_context("uncond")
noise_uncond = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,