guiders support for wan
@@ -30,7 +30,7 @@ from ..models.transformers.transformer_hunyuan_video import (
 )
 from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
 from ..models.transformers.transformer_mochi import MochiTransformerBlock
-from ..models.transformers.transformer_wan import WanTransformerBlock
+from ..models.transformers.transformer_wan import WanAttnProcessor2_0, WanPAGAttnProcessor2_0, WanTransformerBlock


 @dataclass
@@ -186,6 +186,14 @@ def _register_guidance_metadata():
         ),
     )

+    # Wan
+    GuidanceMetadataRegistry.register(
+        model_class=WanAttnProcessor2_0,
+        metadata=GuidanceMetadata(
+            perturbed_attention_guidance_processor_cls=WanPAGAttnProcessor2_0,
+        ),
+    )
+

 # fmt: off
 def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs):
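For orientation, this registration ties the default Wan attention processor to its perturbed-attention counterpart so that guidance code can swap processors at runtime. Below is a minimal, self-contained sketch of that registry pattern; the `register`/`get` helpers and the empty processor classes are illustrative assumptions, not the actual diffusers internals.

```python
# Illustrative sketch of a class-keyed guidance metadata registry (assumed API).
from dataclasses import dataclass
from typing import Dict, Optional, Type


@dataclass
class GuidanceMetadata:
    perturbed_attention_guidance_processor_cls: Optional[Type] = None


class GuidanceMetadataRegistry:
    _registry: Dict[Type, GuidanceMetadata] = {}

    @classmethod
    def register(cls, model_class: Type, metadata: GuidanceMetadata) -> None:
        cls._registry[model_class] = metadata

    @classmethod
    def get(cls, model_class: Type) -> GuidanceMetadata:
        return cls._registry[model_class]


class WanAttnProcessor2_0: ...


class WanPAGAttnProcessor2_0: ...


# Register the Wan pair, then look up the PAG variant when perturbed attention
# guidance needs to replace the default processor on a transformer block.
GuidanceMetadataRegistry.register(
    model_class=WanAttnProcessor2_0,
    metadata=GuidanceMetadata(perturbed_attention_guidance_processor_cls=WanPAGAttnProcessor2_0),
)
pag_cls = GuidanceMetadataRegistry.get(WanAttnProcessor2_0).perturbed_attention_guidance_processor_cls
assert pag_cls is WanPAGAttnProcessor2_0
```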
@@ -467,3 +467,85 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
             return (output,)

         return Transformer2DModelOutput(sample=output)
+
+
+### ===== Custom attention processors for guidance methods ===== ###
+
+
+class WanPAGAttnProcessor2_0:
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("WanPAGAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        is_encoder_hidden_states_provided = encoder_hidden_states is not None
+        encoder_hidden_states_img = None
+        if attn.add_k_proj is not None:
+            # Split off the image-conditioning tokens used by the I2V variant.
+            encoder_hidden_states_img = encoder_hidden_states[:, :257]
+            encoder_hidden_states = encoder_hidden_states[:, 257:]
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+        if rotary_emb is not None:
+
+            def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
+                x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
+                x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
+                return x_out.type_as(hidden_states)
+
+            query = apply_rotary_emb(query, rotary_emb)
+            key = apply_rotary_emb(key, rotary_emb)
+
+        # I2V task
+        hidden_states_img = None
+        if encoder_hidden_states_img is not None:
+            key_img = attn.add_k_proj(encoder_hidden_states_img)
+            key_img = attn.norm_added_k(key_img)
+            value_img = attn.add_v_proj(encoder_hidden_states_img)
+
+            key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+            value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+            hidden_states_img = F.scaled_dot_product_attention(
+                query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
+            )
+            hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
+            hidden_states_img = hidden_states_img.type_as(query)
+
+        if is_encoder_hidden_states_provided:
+            hidden_states = F.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            )
+        else:
+            # Perturbed attention is applied only to self-attention: replace the
+            # attention output with the value projection (identity attention map).
+            hidden_states = value
+
+        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.type_as(query)
+
+        if hidden_states_img is not None:
+            hidden_states = hidden_states + hidden_states_img
+
+        hidden_states = attn.to_out[0](hidden_states)
+        hidden_states = attn.to_out[1](hidden_states)
+        return hidden_states
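The only behavioral difference from the standard Wan processor is the self-attention branch: instead of computing softmax(QKᵀ/√d)·V, the PAG processor returns the value projection directly, which is exactly what an identity attention map would produce. A small standalone check of that equivalence (shapes are arbitrary placeholders):

```python
import torch

batch, heads, seq, dim = 2, 4, 8, 16
value = torch.randn(batch, heads, seq, dim)

# Perturbed attention guidance replaces the softmax attention map with the
# identity matrix, so the attention output collapses to the value tensor.
identity_attn = torch.eye(seq).expand(batch, heads, seq, seq)
out_identity = identity_attn @ value

assert torch.allclose(out_identity, value, atol=1e-6)
```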
@@ -617,7 +617,7 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

         conds = [prompt_embeds, negative_prompt_embeds, original_size, target_size, crops_coords_top_left]
-        prompt_embeds, negative_prompt_embeds, original_size, target_size, crops_coords_top_left = [[v] for v in conds]
+        prompt_embeds, negative_prompt_embeds, original_size, target_size, crops_coords_top_left = [[c] for c in conds]

         with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
             for i, t in enumerate(timesteps):
@@ -21,6 +21,7 @@ import torch
 from transformers import AutoTokenizer, UMT5EncoderModel

 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...guiders import ClassifierFreeGuidance, GuidanceMixin, _raise_guidance_deprecation_warning
 from ...loaders import WanLoraLoaderMixin
 from ...models import AutoencoderKLWan, WanTransformer3DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -380,6 +381,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         ] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
+        guidance: Optional[GuidanceMixin] = None,
     ):
         r"""
         The call function to the pipeline for generation.
@@ -444,6 +446,10 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                 indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
         """

+        _raise_guidance_deprecation_warning(guidance_scale=guidance_scale)
+        if guidance is None:
+            guidance = ClassifierFreeGuidance(guidance_scale=guidance_scale)
+
         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
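With the new `guidance` argument, a caller can pass a guider object explicitly instead of relying only on `guidance_scale`. A hedged usage sketch follows; the checkpoint id and generation parameters are illustrative, and the `diffusers.guiders` import path is assumed from the relative import added above.

```python
import torch
from diffusers import WanPipeline
from diffusers.guiders import ClassifierFreeGuidance  # path assumed from `from ...guiders import ...`

# Illustrative checkpoint; any Wan T2V checkpoint in diffusers format should work.
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# If `guidance` is omitted, the pipeline falls back to
# ClassifierFreeGuidance(guidance_scale=guidance_scale) as shown in the hunk above.
guider = ClassifierFreeGuidance(guidance_scale=5.0)
frames = pipe(
    prompt="a cat surfing a small wave at sunset",
    num_inference_steps=30,
    guidance=guider,
).frames[0]
```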
@@ -519,37 +525,38 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)

+        conds = [prompt_embeds, negative_prompt_embeds]
+        prompt_embeds, negative_prompt_embeds = [[c] for c in conds]
+
         with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
             for i, t in enumerate(timesteps):
+                self._current_timestep = t
                 if self.interrupt:
                     continue

-                self._current_timestep = t
-                latent_model_input = latents.to(transformer_dtype)
-                timestep = t.expand(latents.shape[0])
+                guidance.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
+                guidance.prepare_models(self.transformer)
+                latents, prompt_embeds = guidance.prepare_inputs(
+                    latents, (prompt_embeds[0], negative_prompt_embeds[0])
+                )

-                cc.mark_state("cond")
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if self.do_classifier_free_guidance:
-                    cc.mark_state("uncond")
-                    noise_uncond = self.transformer(
-                        hidden_states=latent_model_input,
+                for batch_index, (latent, condition) in enumerate(zip(latents, prompt_embeds)):
+                    cc.mark_state(f"batch_{batch_index}")
+                    latent = latent.to(transformer_dtype)
+                    timestep = t.expand(latent.shape[0])
+                    noise_pred = self.transformer(
+                        hidden_states=latent,
                         timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=condition,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
-                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+                    guidance.prepare_outputs(noise_pred)

                 # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+                outputs = guidance.outputs
+                noise_pred = guidance(**outputs)
+                latents = self.scheduler.step(noise_pred, t, latents[0], return_dict=False)[0]
+                guidance.cleanup_models(self.transformer)

                 if callback_on_step_end is not None:
                     callback_kwargs = {}
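The reworked loop delegates guidance batching to the guider: `prepare_inputs` yields one (latent, condition) pair per branch, each branch gets its own transformer forward pass, the predictions are collected with `prepare_outputs`, and the guider combines them into a single `noise_pred`. The toy stand-in below mirrors that call sequence with a classifier-free-guidance combine step; its internals are simplified assumptions, not the real `GuidanceMixin`.

```python
from typing import List, Tuple

import torch


class ToyCFGGuider:
    """Simplified stand-in for the guider protocol used in the loop above."""

    def __init__(self, guidance_scale: float):
        self.guidance_scale = guidance_scale
        self._preds: List[torch.Tensor] = []

    def prepare_inputs(self, latents: torch.Tensor, conds: Tuple[torch.Tensor, torch.Tensor]):
        # One (latent, condition) pair per guidance branch: conditional, then unconditional.
        cond, uncond = conds
        return [latents, latents], [cond, uncond]

    def prepare_outputs(self, noise_pred: torch.Tensor) -> None:
        self._preds.append(noise_pred)

    @property
    def outputs(self):
        return {"pred_cond": self._preds[0], "pred_uncond": self._preds[1]}

    def __call__(self, pred_cond: torch.Tensor, pred_uncond: torch.Tensor) -> torch.Tensor:
        # Standard classifier-free guidance combination.
        return pred_uncond + self.guidance_scale * (pred_cond - pred_uncond)


# Walk through one denoising step the way the hunk above does.
guider = ToyCFGGuider(guidance_scale=5.0)
latents = torch.randn(1, 16, 4, 4)
prompt_embeds, negative_prompt_embeds = torch.randn(1, 8, 32), torch.randn(1, 8, 32)

branch_latents, branch_conds = guider.prepare_inputs(latents, (prompt_embeds, negative_prompt_embeds))
for latent, condition in zip(branch_latents, branch_conds):
    noise_pred = latent * 0.1  # stand-in for self.transformer(...)
    guider.prepare_outputs(noise_pred)
noise_pred = guider(**guider.outputs)
```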
@@ -558,8 +565,10 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                     callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                     latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    prompt_embeds = [callback_outputs.pop("prompt_embeds", prompt_embeds[0])]
+                    negative_prompt_embeds = [
+                        callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds[0])
+                    ]

                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
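Because the pipeline now carries `prompt_embeds` and `negative_prompt_embeds` as single-element lists, the callback still exchanges plain tensors and the loop re-wraps them, as shown above. A minimal sketch of a compatible `callback_on_step_end` (the four-argument signature follows the existing diffusers callback convention; the body is a placeholder):

```python
import torch


def on_step_end(pipe, step_index, timestep, callback_kwargs):
    # Tensors listed in callback_on_step_end_tensor_inputs arrive here as plain tensors.
    latents = callback_kwargs["latents"]
    # Placeholder edit: clamp extreme latent values before the next step.
    callback_kwargs["latents"] = torch.clamp(latents, -10.0, 10.0)
    return callback_kwargs


# Assumed usage with the pipeline call sketched earlier:
# pipe(prompt="...", callback_on_step_end=on_step_end,
#      callback_on_step_end_tensor_inputs=["latents"])
```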