1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
Aryan
2025-05-21 00:52:15 +02:00
parent 4b14ddd80b
commit 694dcf2cfd
2 changed files with 50 additions and 25 deletions

View File

@@ -106,35 +106,38 @@ class WanVACETransformerBlock(nn.Module):
) -> torch.Tensor:
if self.proj_in is not None:
control_hidden_states = self.proj_in(control_hidden_states)
hidden_states = hidden_states + control_hidden_states
else:
hidden_states = control_hidden_states
control_hidden_states = control_hidden_states + hidden_states
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
self.scale_shift_table + temb.float()
).chunk(6, dim=1)
# 1. Self-attention
norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
norm_hidden_states = (self.norm1(control_hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(
control_hidden_states
)
attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
control_hidden_states = (control_hidden_states.float() + attn_output * gate_msa).type_as(control_hidden_states)
# 2. Cross-attention
norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
norm_hidden_states = self.norm2(control_hidden_states.float()).type_as(control_hidden_states)
attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
hidden_states = hidden_states + attn_output
control_hidden_states = control_hidden_states + attn_output
# 3. Feed-forward
norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
hidden_states
norm_hidden_states = (self.norm3(control_hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
control_hidden_states
)
ff_output = self.ffn(norm_hidden_states)
hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
control_hidden_states = (control_hidden_states.float() + ff_output.float() * c_gate_msa).type_as(
control_hidden_states
)
conditioning_states = None
if self.proj_out is not None:
control_hidden_states = self.proj_out(hidden_states)
conditioning_states = self.proj_out(control_hidden_states)
return hidden_states, control_hidden_states
return conditioning_states, control_hidden_states
class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
@@ -309,11 +312,9 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
# 2. Patch embedding
hidden_states = self.patch_embedding(hidden_states)
hidden_states = hidden_states.flatten(2).transpose(1, 2)
print("hidden_states", hidden_states.shape)
control_hidden_states = self.vace_patch_embedding(control_hidden_states)
control_hidden_states = control_hidden_states.flatten(2).transpose(1, 2)
print("control_hidden_states", control_hidden_states.shape)
control_hidden_states_padding = control_hidden_states.new_zeros(
batch_size, hidden_states.size(1) - control_hidden_states.size(1), control_hidden_states.size(2)
)
@@ -333,12 +334,11 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
if torch.is_grad_enabled() and self.gradient_checkpointing:
# Prepare VACE hints
control_hidden_states_list = []
vace_hidden_states = hidden_states
for i, block in enumerate(self.vace_blocks):
vace_hidden_states, control_hidden_states = self._gradient_checkpointing_func(
block, vace_hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb
conditioning_states, control_hidden_states = self._gradient_checkpointing_func(
block, hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb
)
control_hidden_states_list.append((control_hidden_states, control_hidden_states_scale[i]))
control_hidden_states_list.append((conditioning_states, control_hidden_states_scale[i]))
control_hidden_states_list = control_hidden_states_list[::-1]
for i, block in enumerate(self.blocks):
@@ -351,12 +351,11 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
else:
# Prepare VACE hints
control_hidden_states_list = []
vace_hidden_states = hidden_states
for i, block in enumerate(self.vace_blocks):
vace_hidden_states, control_hidden_states = block(
vace_hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb
conditioning_states, control_hidden_states = block(
hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb
)
control_hidden_states_list.append((control_hidden_states, control_hidden_states_scale[i]))
control_hidden_states_list.append((conditioning_states, control_hidden_states_scale[i]))
control_hidden_states_list = control_hidden_states_list[::-1]
for i, block in enumerate(self.blocks):

View File

@@ -23,7 +23,7 @@ from transformers import AutoTokenizer, UMT5EncoderModel
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput
from ...loaders import WanLoraLoaderMixin
from ...models import AutoencoderKLWan, WanTransformer3DModel
from ...models import AutoencoderKLWan, WanVACETransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
@@ -137,7 +137,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
self,
tokenizer: AutoTokenizer,
text_encoder: UMT5EncoderModel,
transformer: WanTransformer3DModel,
transformer: WanVACETransformer3DModel,
vae: AutoencoderKLWan,
scheduler: FlowMatchEulerDiscreteScheduler,
):
@@ -421,6 +421,13 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
f"Batch size of `video` {video.shape[0]} and length of `reference_images` {len(reference_images)} does not match."
)
ref_images_lengths = [len(reference_images_batch) for reference_images_batch in reference_images]
if any(l != ref_images_lengths[0] for l in ref_images_lengths):
raise ValueError(
f"All batches of `reference_images` should have the same length, but got {ref_images_lengths}. Support for this "
"may be added in the future."
)
reference_images_preprocessed = []
for i, reference_images_batch in enumerate(reference_images):
preprocessed_images = []
@@ -449,7 +456,10 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
mask: torch.Tensor,
reference_images: Optional[List[List[torch.Tensor]]] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
device: Optional[torch.device] = None,
) -> torch.Tensor:
device = device or self._execution_device
if isinstance(generator, list):
# TODO: support this
raise ValueError("Passing a list of generators is not yet supported. This may be supported in the future.")
@@ -473,8 +483,16 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
vae_dtype = self.vae.dtype
video = video.to(dtype=vae_dtype)
latents_mean = torch.tensor(self.vae.config.latents_mean, device=device, dtype=torch.float32).view(
1, self.vae.config.z_dim, 1, 1, 1
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std, device=device, dtype=torch.float32).view(
1, self.vae.config.z_dim, 1, 1, 1
)
if mask is None:
latents = retrieve_latents(self.vae.encode(video), generator, sample_mode="argmax").unbind(0)
latents = ((latents.float() - latents_mean) * latents_std).to(vae_dtype)
else:
mask = mask.to(dtype=vae_dtype)
mask = torch.where(mask > 0.5, 1.0, 0.0)
@@ -482,6 +500,8 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
reactive = video * mask
inactive = retrieve_latents(self.vae.encode(inactive), generator, sample_mode="argmax")
reactive = retrieve_latents(self.vae.encode(reactive), generator, sample_mode="argmax")
inactive = ((inactive.float() - latents_mean) * latents_std).to(vae_dtype)
reactive = ((reactive.float() - latents_mean) * latents_std).to(vae_dtype)
latents = torch.cat([inactive, reactive], dim=1)
latent_list = []
@@ -491,6 +511,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
reference_image = reference_image.to(dtype=vae_dtype)
reference_image = reference_image[None, :, None, :, :] # [1, C, 1, H, W]
reference_latent = retrieve_latents(self.vae.encode(reference_image), generator, sample_mode="argmax")
reference_latent = ((reference_latent.float() - latents_mean) * latents_std).to(vae_dtype)
reference_latent = torch.cat([reference_latent, torch.zeros_like(reference_latent)], dim=1)
latent = torch.cat([reference_latent.squeeze(0), latent], dim=1) # Concat across frame dimension
latent_list.append(latent)
@@ -790,7 +811,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
device,
)
conditioning_latents = self.prepare_video_latents(video, mask, reference_images, generator)
conditioning_latents = self.prepare_video_latents(video, mask, reference_images, generator, device)
mask = self.prepare_masks(mask, reference_images, generator)
conditioning_latents = torch.cat([conditioning_latents, mask], dim=1)
conditioning_latents = conditioning_latents.to(transformer_dtype)
@@ -808,6 +829,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
latents,
)
if conditioning_latents.shape[2] != latents.shape[2]:
logger.warning(
"The number of frames in the conditioning latents does not match the number of frames to be generated. Generation quality may be affected."
)
# 6. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)