mirror of https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
address review comments pt. 2
@@ -21,7 +21,7 @@ from ..utils import get_logger
 from ..utils.torch_utils import unwrap_module
 from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
 from ._helpers import TransformerBlockRegistry
-from .hooks import BaseMarkedState, HookRegistry, ModelHook
+from .hooks import ContextAwareState, HookRegistry, ModelHook
 
 
 logger = get_logger(__name__)  # pylint: disable=invalid-name
@@ -49,7 +49,7 @@ class FirstBlockCacheConfig:
     threshold: float = 0.05
 
 
-class FBCSharedBlockState(BaseMarkedState):
+class FBCSharedBlockState(ContextAwareState):
     def __init__(self) -> None:
         super().__init__()
 
@@ -31,61 +31,56 @@ class BaseState:
         )
 
 
-class BaseMarkedState(BaseState):
+class ContextAwareState(BaseState):
     def __init__(self, init_args=None, init_kwargs=None):
         super().__init__()
 
         self._init_args = init_args if init_args is not None else ()
         self._init_kwargs = init_kwargs if init_kwargs is not None else {}
-        self._mark_name = None
+        self._current_context = None
         self._state_cache = {}
 
-    def get_current_state(self) -> "BaseMarkedState":
-        if self._mark_name is None:
-            # If no mark name is set, simply return a dummy object since we're not going to be using it
+    def get_state(self) -> "ContextAwareState":
+        if self._current_context is None:
+            # If no context is set, simply return a dummy object since we're not going to be using it
             return self
-        if self._mark_name not in self._state_cache.keys():
-            self._state_cache[self._mark_name] = self.__class__(*self._init_args, **self._init_kwargs)
-        return self._state_cache[self._mark_name]
+        if self._current_context not in self._state_cache.keys():
+            self._state_cache[self._current_context] = ContextAwareState._create_state(
+                self.__class__, self._init_args, self._init_kwargs
+            )
+        return self._state_cache[self._current_context]
 
-    def mark_state(self, name: str) -> None:
-        self._mark_name = name
+    def set_context(self, name: str) -> None:
+        self._current_context = name
 
     def reset(self, *args, **kwargs) -> None:
        for name, state in list(self._state_cache.items()):
             state.reset(*args, **kwargs)
             self._state_cache.pop(name)
-        self._mark_name = None
+        self._current_context = None
+
+    @staticmethod
+    def _create_state(cls, init_args, init_kwargs) -> "ContextAwareState":
+        return cls(*init_args, **init_kwargs)
 
     def __getattribute__(self, name):
-        direct_attrs = (
-            "get_current_state",
-            "mark_state",
-            "reset",
-            "_init_args",
-            "_init_kwargs",
-            "_mark_name",
-            "_state_cache",
-        )
+        # fmt: off
+        direct_attrs = ("get_state", "set_context", "reset", "_init_args", "_init_kwargs", "_current_context", "_state_cache", "_create_state")
+        # fmt: on
         if name in direct_attrs or _is_dunder_method(name):
             return object.__getattribute__(self, name)
         else:
-            current_state = BaseMarkedState.get_current_state(self)
+            current_state = ContextAwareState.get_state(self)
             return object.__getattribute__(current_state, name)
 
     def __setattr__(self, name, value):
-        if name in (
-            "get_current_state",
-            "mark_state",
-            "reset",
-            "_init_args",
-            "_init_kwargs",
-            "_mark_name",
-            "_state_cache",
-        ) or _is_dunder_method(name):
+        # fmt: off
+        direct_attrs = ("get_state", "set_context", "reset", "_init_args", "_init_kwargs", "_current_context", "_state_cache", "_create_state")
+        # fmt: on
+        if name in direct_attrs or _is_dunder_method(name):
             object.__setattr__(self, name, value)
         else:
-            current_state = BaseMarkedState.get_current_state(self)
+            current_state = ContextAwareState.get_state(self)
             object.__setattr__(current_state, name, value)
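Note: to make the renamed class's behavior concrete, here is a minimal, self-contained sketch of the same mechanics. All names below (ContextAwareCounter, _is_dunder, _DIRECT, hits) are illustrative toys, not diffusers API; the real class above additionally threads subclass init args through _create_state and clears copies in reset.

def _is_dunder(name):
    # simplified stand-in for the module's _is_dunder_method helper
    return name.startswith("__") and name.endswith("__")


# attributes that live on the outer object itself, never on a per-context copy
_DIRECT = ("get_state", "set_context", "_current_context", "_state_cache")


class ContextAwareCounter:
    """Toy analogue of ContextAwareState: plain attribute reads and writes
    are transparently redirected to a per-context copy of the object."""

    def __init__(self):
        self._current_context = None  # which branch ("cond", "uncond", ...) is active
        self._state_cache = {}        # context name -> dedicated state instance
        self.hits = 0                 # example payload attribute

    def get_state(self):
        # No context set: fall back to self, exactly like the hunk above.
        if self._current_context is None:
            return self
        if self._current_context not in self._state_cache:
            self._state_cache[self._current_context] = type(self)()
        return self._state_cache[self._current_context]

    def set_context(self, name):
        self._current_context = name

    def __getattribute__(self, name):
        if name in _DIRECT or _is_dunder(name):
            return object.__getattribute__(self, name)
        return object.__getattribute__(ContextAwareCounter.get_state(self), name)

    def __setattr__(self, name, value):
        if name in _DIRECT or _is_dunder(name):
            object.__setattr__(self, name, value)
        else:
            object.__setattr__(ContextAwareCounter.get_state(self), name, value)


state = ContextAwareCounter()
state.set_context("cond")
state.hits += 1
state.set_context("uncond")
assert state.hits == 0  # the "uncond" branch gets a fresh copy
state.set_context("cond")
assert state.hits == 1  # the "cond" branch's value survived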
@@ -166,11 +161,11 @@ class ModelHook:
         return module
 
     def _mark_state(self, module: torch.nn.Module, name: str) -> None:
-        # Iterate over all attributes of the hook to see if any of them have the type `BaseMarkedState`. If so, call `mark_state` on them.
+        # Iterate over all attributes of the hook to see if any of them have the type `ContextAwareState`. If so, call `set_context` on them.
         for attr_name in dir(self):
             attr = getattr(self, attr_name)
-            if isinstance(attr, BaseMarkedState):
-                attr.mark_state(name)
+            if isinstance(attr, ContextAwareState):
+                attr.set_context(name)
         return module
 
 
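Continuing the toy from the previous note, the propagation step in _mark_state amounts to the following. ToyHook is hypothetical; the real ModelHook scans dir(self) the same way but checks for ContextAwareState.

class ToyHook:
    """Hypothetical hook holding one context-aware state."""

    def __init__(self):
        self.shared_state = ContextAwareCounter()

    def _mark_state(self, name):
        # Same scan as the hunk above: find every context-aware attribute
        # of the hook and forward the context name to it.
        for attr_name in dir(self):
            attr = getattr(self, attr_name)
            if isinstance(attr, ContextAwareCounter):
                attr.set_context(name)


hook = ToyHook()
hook._mark_state("cond")
assert hook.shared_state._current_context == "cond"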
@@ -128,7 +128,7 @@ class _CacheContextManager:
     def __init__(self, model: CacheMixin):
         self.model = model
 
-    def mark_state(self, name: str) -> None:
+    def set_context(self, name: str) -> None:
        from ..hooks import HookRegistry
 
         if self.model.is_cache_enabled:
@@ -619,7 +619,7 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0])
 
-                cc.mark_state("cond")
+                cc.set_context("cond")
                 noise_pred_cond = self.transformer(
                     hidden_states=latent_model_input,
                     encoder_hidden_states=prompt_embeds,
@@ -633,7 +633,7 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
 
                 # perform guidance
                 if self.do_classifier_free_guidance:
-                    cc.mark_state("uncond")
+                    cc.set_context("uncond")
                     noise_pred_uncond = self.transformer(
                         hidden_states=latent_model_input,
                         encoder_hidden_states=negative_prompt_embeds,
@@ -917,7 +917,7 @@ class FluxPipeline(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
 
-                cc.mark_state("cond")
+                cc.set_context("cond")
                 noise_pred = self.transformer(
                     hidden_states=latents,
                     timestep=timestep / 1000,
@@ -934,7 +934,7 @@ class FluxPipeline(
                     if negative_image_embeds is not None:
                         self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
 
-                    cc.mark_state("uncond")
+                    cc.set_context("uncond")
                     neg_noise_pred = self.transformer(
                         hidden_states=latents,
                         timestep=timestep / 1000,
@@ -693,7 +693,7 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
 
-                cc.mark_state("cond")
+                cc.set_context("cond")
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
                     timestep=timestep,
@@ -706,7 +706,7 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
                 )[0]
 
                 if do_true_cfg:
-                    cc.mark_state("uncond")
+                    cc.set_context("uncond")
                     neg_noise_pred = self.transformer(
                         hidden_states=latent_model_input,
                         timestep=timestep,
@@ -719,7 +719,7 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0])
 
-                cc.mark_state("cond_uncond")
+                cc.set_context("cond_uncond")
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
                     encoder_hidden_states=prompt_embeds,
@@ -1105,7 +1105,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
                 if is_conditioning_image_or_video:
                     timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
 
-                cc.mark_state("cond_uncond")
+                cc.set_context("cond_uncond")
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
                     encoder_hidden_states=prompt_embeds,
@@ -792,7 +792,7 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
                 timestep = t.expand(latent_model_input.shape[0])
                 timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
 
-                cc.mark_state("cond_uncond")
+                cc.set_context("cond_uncond")
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
                     encoder_hidden_states=prompt_embeds,
@@ -530,7 +530,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                 latent_model_input = latents.to(transformer_dtype)
                 timestep = t.expand(latents.shape[0])
 
-                cc.mark_state("cond")
+                cc.set_context("cond")
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
                     timestep=timestep,
@@ -540,7 +540,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                 )[0]
 
                 if self.do_classifier_free_guidance:
-                    cc.mark_state("uncond")
+                    cc.set_context("uncond")
                     noise_uncond = self.transformer(
                         hidden_states=latent_model_input,
                         timestep=timestep,
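Taken together, the pipeline hunks all follow one pattern: flip the cache context before each guidance branch so that cached per-branch state (e.g. the first-block residuals) for "cond" and "uncond" never clobber each other. The hunks do not show where cc is constructed; presumably it is the _CacheContextManager from the cache_utils hunk above, held open around the denoising loop. A toy rendition of the pattern, reusing the sketch classes from the earlier notes:

hook = ToyHook()

for _step in range(2):           # stand-in for the denoising loop
    hook._mark_state("cond")     # cc.set_context("cond") in the real pipelines
    hook.shared_state.hits += 1  # stand-in for updating the cond branch's cache

    hook._mark_state("uncond")   # cc.set_context("uncond")
    hook.shared_state.hits += 1  # the uncond branch updates its own copy

state = hook.shared_state
state.set_context("cond")
assert state.hits == 2  # cond branch accumulated independently of uncond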