1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

guiders support for wan

This commit is contained in:
Aryan
2025-04-05 00:09:09 +02:00
parent 357f4f056b
commit 74e34e5f69
4 changed files with 123 additions and 24 deletions

View File

@@ -30,7 +30,7 @@ from ..models.transformers.transformer_hunyuan_video import (
)
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
from ..models.transformers.transformer_mochi import MochiTransformerBlock
from ..models.transformers.transformer_wan import WanTransformerBlock
from ..models.transformers.transformer_wan import WanAttnProcessor2_0, WanPAGAttnProcessor2_0, WanTransformerBlock
@dataclass
@@ -186,6 +186,14 @@ def _register_guidance_metadata():
),
)
# Wan
GuidanceMetadataRegistry.register(
model_class=WanAttnProcessor2_0,
metadata=GuidanceMetadata(
perturbed_attention_guidance_processor_cls=WanPAGAttnProcessor2_0,
),
)
# fmt: off
def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs):

View File

@@ -467,3 +467,85 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
return (output,)
return Transformer2DModelOutput(sample=output)
### ===== Custom attention processors for guidance methods ===== ###
class WanPAGAttnProcessor2_0:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
is_encoder_hidden_states_provided = encoder_hidden_states is not None
encoder_hidden_states_img = None
if attn.add_k_proj is not None:
encoder_hidden_states_img = encoder_hidden_states[:, :257]
encoder_hidden_states = encoder_hidden_states[:, 257:]
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
if rotary_emb is not None:
def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
return x_out.type_as(hidden_states)
query = apply_rotary_emb(query, rotary_emb)
key = apply_rotary_emb(key, rotary_emb)
# I2V task
hidden_states_img = None
if encoder_hidden_states_img is not None:
key_img = attn.add_k_proj(encoder_hidden_states_img)
key_img = attn.norm_added_k(key_img)
value_img = attn.add_v_proj(encoder_hidden_states_img)
key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
hidden_states_img = F.scaled_dot_product_attention(
query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
)
hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
hidden_states_img = hidden_states_img.type_as(query)
if is_encoder_hidden_states_provided:
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
else:
# Perturbed attention applied only when self-attention
hidden_states = value
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.type_as(query)
if hidden_states_img is not None:
hidden_states = hidden_states + hidden_states_img
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states

View File

@@ -617,7 +617,7 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
conds = [prompt_embeds, negative_prompt_embeds, original_size, target_size, crops_coords_top_left]
prompt_embeds, negative_prompt_embeds, original_size, target_size, crops_coords_top_left = [[v] for v in conds]
prompt_embeds, negative_prompt_embeds, original_size, target_size, crops_coords_top_left = [[c] for c in conds]
with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
for i, t in enumerate(timesteps):

View File

@@ -21,6 +21,7 @@ import torch
from transformers import AutoTokenizer, UMT5EncoderModel
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...guiders import ClassifierFreeGuidance, GuidanceMixin, _raise_guidance_deprecation_warning
from ...loaders import WanLoraLoaderMixin
from ...models import AutoencoderKLWan, WanTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -380,6 +381,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
guidance: Optional[GuidanceMixin] = None,
):
r"""
The call function to the pipeline for generation.
@@ -444,6 +446,10 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
"""
_raise_guidance_deprecation_warning(guidance_scale=guidance_scale)
if guidance is None:
guidance = ClassifierFreeGuidance(guidance_scale=guidance_scale)
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
@@ -519,37 +525,38 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
conds = [prompt_embeds, negative_prompt_embeds]
prompt_embeds, negative_prompt_embeds = [[c] for c in conds]
with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
for i, t in enumerate(timesteps):
self._current_timestep = t
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = latents.to(transformer_dtype)
timestep = t.expand(latents.shape[0])
guidance.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
guidance.prepare_models(self.transformer)
latents, prompt_embeds = guidance.prepare_inputs(
latents, (prompt_embeds[0], negative_prompt_embeds[0])
)
cc.mark_state("cond")
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if self.do_classifier_free_guidance:
cc.mark_state("uncond")
noise_uncond = self.transformer(
hidden_states=latent_model_input,
for batch_index, (latent, condition) in enumerate(zip(latents, prompt_embeds)):
cc.mark_state(f"batch_{batch_index}")
latent = latent.to(transformer_dtype)
timestep = t.expand(latent.shape[0])
noise_pred = self.transformer(
hidden_states=latent,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
encoder_hidden_states=condition,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
guidance.prepare_outputs(noise_pred)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
outputs = guidance.outputs
noise_pred = guidance(**outputs)
latents = self.scheduler.step(noise_pred, t, latents[0], return_dict=False)[0]
guidance.cleanup_models(self.transformer)
if callback_on_step_end is not None:
callback_kwargs = {}
@@ -558,8 +565,10 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
prompt_embeds = [callback_outputs.pop("prompt_embeds", prompt_embeds[0])]
negative_prompt_embeds = [
callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds[0])
]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):