refactor: route WanVACE blocks through `control_hidden_states` and normalize VACE latents

The VACE transformer blocks now run self-attention, cross-attention, and the
feed-forward entirely on `control_hidden_states`, and return the projected hint
as a separate `conditioning_states` tensor instead of mutating the trunk
`hidden_states`. On the pipeline side: the transformer is typed as
`WanVACETransformer3DModel`, `prepare_video_latents` accepts a `device` and
normalizes encoded latents with the VAE's `latents_mean`/`latents_std`,
`reference_images` batches are validated to have equal lengths, a warning is
emitted when conditioning and target latents disagree on frame count, and stray
debug `print` calls are removed.
@@ -106,35 +106,38 @@ class WanVACETransformerBlock(nn.Module):
     ) -> torch.Tensor:
         if self.proj_in is not None:
             control_hidden_states = self.proj_in(control_hidden_states)
-            hidden_states = hidden_states + control_hidden_states
-        else:
-            hidden_states = control_hidden_states
+            control_hidden_states = control_hidden_states + hidden_states

         shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
             self.scale_shift_table + temb.float()
         ).chunk(6, dim=1)

         # 1. Self-attention
-        norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
+        norm_hidden_states = (self.norm1(control_hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(
+            control_hidden_states
+        )
         attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
-        hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
+        control_hidden_states = (control_hidden_states.float() + attn_output * gate_msa).type_as(control_hidden_states)

         # 2. Cross-attention
-        norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
+        norm_hidden_states = self.norm2(control_hidden_states.float()).type_as(control_hidden_states)
         attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
-        hidden_states = hidden_states + attn_output
+        control_hidden_states = control_hidden_states + attn_output

         # 3. Feed-forward
-        norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
-            hidden_states
+        norm_hidden_states = (self.norm3(control_hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
+            control_hidden_states
         )
         ff_output = self.ffn(norm_hidden_states)
-        hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
+        control_hidden_states = (control_hidden_states.float() + ff_output.float() * c_gate_msa).type_as(
+            control_hidden_states
+        )

+        conditioning_states = None
         if self.proj_out is not None:
-            control_hidden_states = self.proj_out(hidden_states)
+            conditioning_states = self.proj_out(control_hidden_states)

-        return hidden_states, control_hidden_states
+        return conditioning_states, control_hidden_states


 class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
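
Note on the hunk above: each `WanVACETransformerBlock` now keeps the trunk
`hidden_states` read-only. The control stream absorbs the trunk once (after
`proj_in`), runs all three sub-layers on itself, and hands back the projected
hint as `conditioning_states`. A minimal, runnable sketch of that dataflow
(`ToyVACEBlock`, `dim`, and the `first_block` flag are illustrative stand-ins,
not names from the diff):

    import torch
    import torch.nn as nn

    class ToyVACEBlock(nn.Module):
        # Stand-in for WanVACETransformerBlock: `body` replaces attn1/attn2/ffn.
        def __init__(self, dim):
            super().__init__()
            self.proj_in = nn.Linear(dim, dim)
            self.body = nn.Linear(dim, dim)
            self.proj_out = nn.Linear(dim, dim)

        def forward(self, hidden_states, control_hidden_states, first_block):
            if first_block:  # mirrors `self.proj_in is not None`
                control_hidden_states = self.proj_in(control_hidden_states)
                control_hidden_states = control_hidden_states + hidden_states
            # All sub-layers update the control stream; the trunk is untouched.
            control_hidden_states = control_hidden_states + self.body(control_hidden_states)
            conditioning_states = self.proj_out(control_hidden_states)
            return conditioning_states, control_hidden_states

    trunk = torch.randn(1, 16, 32)
    control = torch.randn(1, 16, 32)
    hint, control = ToyVACEBlock(32)(trunk, control, first_block=True)
    assert hint.shape == trunk.shape  # the hint is added to the trunk later
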
@@ -309,11 +312,9 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
         # 2. Patch embedding
         hidden_states = self.patch_embedding(hidden_states)
         hidden_states = hidden_states.flatten(2).transpose(1, 2)
-        print("hidden_states", hidden_states.shape)

         control_hidden_states = self.vace_patch_embedding(control_hidden_states)
         control_hidden_states = control_hidden_states.flatten(2).transpose(1, 2)
-        print("control_hidden_states", control_hidden_states.shape)
         control_hidden_states_padding = control_hidden_states.new_zeros(
             batch_size, hidden_states.size(1) - control_hidden_states.size(1), control_hidden_states.size(2)
         )
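
The `new_zeros` above right-pads the control token sequence to the trunk's
length, so the residual additions stay shape-compatible when the two streams
carry different token counts (the frame-count warning added at the end of this
commit hints at when that happens). The concatenation itself sits just outside
this hunk; presumably it runs along the sequence axis, roughly like this
(sizes are made up):

    import torch

    batch_size, trunk_len, ctrl_len, dim = 1, 20, 16, 8  # illustrative sizes
    hidden_states = torch.zeros(batch_size, trunk_len, dim)
    control_hidden_states = torch.zeros(batch_size, ctrl_len, dim)
    padding = control_hidden_states.new_zeros(
        batch_size, hidden_states.size(1) - control_hidden_states.size(1), control_hidden_states.size(2)
    )
    control_hidden_states = torch.cat([control_hidden_states, padding], dim=1)
    assert control_hidden_states.shape == hidden_states.shape
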
@@ -333,12 +334,11 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
         if torch.is_grad_enabled() and self.gradient_checkpointing:
             # Prepare VACE hints
             control_hidden_states_list = []
-            vace_hidden_states = hidden_states
             for i, block in enumerate(self.vace_blocks):
-                vace_hidden_states, control_hidden_states = self._gradient_checkpointing_func(
-                    block, vace_hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb
+                conditioning_states, control_hidden_states = self._gradient_checkpointing_func(
+                    block, hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb
                 )
-                control_hidden_states_list.append((control_hidden_states, control_hidden_states_scale[i]))
+                control_hidden_states_list.append((conditioning_states, control_hidden_states_scale[i]))
             control_hidden_states_list = control_hidden_states_list[::-1]

             for i, block in enumerate(self.blocks):
@@ -351,12 +351,11 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
         else:
             # Prepare VACE hints
             control_hidden_states_list = []
-            vace_hidden_states = hidden_states
             for i, block in enumerate(self.vace_blocks):
-                vace_hidden_states, control_hidden_states = block(
-                    vace_hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb
+                conditioning_states, control_hidden_states = block(
+                    hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb
                 )
-                control_hidden_states_list.append((control_hidden_states, control_hidden_states_scale[i]))
+                control_hidden_states_list.append((conditioning_states, control_hidden_states_scale[i]))
             control_hidden_states_list = control_hidden_states_list[::-1]

             for i, block in enumerate(self.blocks):
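
With `vace_hidden_states` gone, every VACE block in both branches now reads the
same pristine trunk `hidden_states` while threading `control_hidden_states`
forward, and the list collects the per-block `conditioning_states` hints with
their scales, reversed for consumption. A runnable sketch of the pop-and-add
pattern a consumer of this list might use (the main-loop indexing lives outside
this diff, so treat the positions as an assumption):

    import torch

    dim = 8
    hidden_states = torch.zeros(1, 4, dim)
    # Pretend two VACE blocks produced hints, each with its own scale:
    control_hidden_states_list = [(torch.ones(1, 4, dim), 0.5), (torch.ones(1, 4, dim), 1.0)]
    control_hidden_states_list = control_hidden_states_list[::-1]  # as in the hunk

    for i in range(8):  # stand-in for `for i, block in enumerate(self.blocks)`
        # hidden_states = block(hidden_states, ...) would run here
        if i in (3, 7):  # illustrative positions of blocks that consume a hint
            conditioning_states, scale = control_hidden_states_list.pop()
            hidden_states = hidden_states + conditioning_states * scale
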
@@ -23,7 +23,7 @@ from transformers import AutoTokenizer, UMT5EncoderModel
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput
 from ...loaders import WanLoraLoaderMixin
-from ...models import AutoencoderKLWan, WanTransformer3DModel
+from ...models import AutoencoderKLWan, WanVACETransformer3DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
@@ -137,7 +137,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         self,
         tokenizer: AutoTokenizer,
         text_encoder: UMT5EncoderModel,
-        transformer: WanTransformer3DModel,
+        transformer: WanVACETransformer3DModel,
         vae: AutoencoderKLWan,
         scheduler: FlowMatchEulerDiscreteScheduler,
     ):
@@ -421,6 +421,13 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                     f"Batch size of `video` {video.shape[0]} and length of `reference_images` {len(reference_images)} does not match."
                 )

+            ref_images_lengths = [len(reference_images_batch) for reference_images_batch in reference_images]
+            if any(l != ref_images_lengths[0] for l in ref_images_lengths):
+                raise ValueError(
+                    f"All batches of `reference_images` should have the same length, but got {ref_images_lengths}. Support for this "
+                    "may be added in the future."
+                )
+
         reference_images_preprocessed = []
         for i, reference_images_batch in enumerate(reference_images):
             preprocessed_images = []
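
The added guard makes the nesting contract for `reference_images` explicit:
one list per batch item, all lists the same length. A quick illustration of
what now raises (placeholder objects stand in for PIL images or tensors):

    img = object()  # placeholder for a PIL.Image.Image or torch.Tensor
    reference_images = [[img, img], [img]]  # batch 0 has 2 refs, batch 1 has 1

    ref_images_lengths = [len(batch) for batch in reference_images]
    if any(l != ref_images_lengths[0] for l in ref_images_lengths):
        # Mirrors the new pipeline check: lengths [2, 1] are rejected.
        raise ValueError(f"All batches of `reference_images` should have the same length, but got {ref_images_lengths}.")
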
@@ -449,7 +456,10 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         mask: torch.Tensor,
         reference_images: Optional[List[List[torch.Tensor]]] = None,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        device: Optional[torch.device] = None,
     ) -> torch.Tensor:
+        device = device or self._execution_device
+
         if isinstance(generator, list):
             # TODO: support this
             raise ValueError("Passing a list of generators is not yet supported. This may be supported in the future.")
@@ -473,8 +483,16 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         vae_dtype = self.vae.dtype
         video = video.to(dtype=vae_dtype)

+        latents_mean = torch.tensor(self.vae.config.latents_mean, device=device, dtype=torch.float32).view(
+            1, self.vae.config.z_dim, 1, 1, 1
+        )
+        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std, device=device, dtype=torch.float32).view(
+            1, self.vae.config.z_dim, 1, 1, 1
+        )
+
         if mask is None:
             latents = retrieve_latents(self.vae.encode(video), generator, sample_mode="argmax")
+            latents = ((latents.float() - latents_mean) * latents_std).to(vae_dtype)
         else:
             mask = mask.to(dtype=vae_dtype)
             mask = torch.where(mask > 0.5, 1.0, 0.0)
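
The two new buffers implement the usual per-channel latent normalization
z_norm = (z - mean) / std; since `latents_std` is stored as the reciprocal
`1.0 / std`, the multiplications in this and the following hunks are divisions
in disguise. A toy version with stand-in statistics (the real values come from
`vae.config.latents_mean` / `vae.config.latents_std`):

    import torch

    z_dim = 16
    z = torch.randn(1, z_dim, 2, 4, 4)  # toy latent, [B, C, T, H, W]
    latents_mean = torch.zeros(1, z_dim, 1, 1, 1)             # stand-in for the config values
    latents_std = 1.0 / torch.full((1, z_dim, 1, 1, 1), 2.0)  # reciprocal, as in the hunk

    z_norm = (z.float() - latents_mean) * latents_std         # i.e. (z - mean) / 2.0
    assert torch.allclose(z_norm, (z - latents_mean) / 2.0)
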
@@ -482,6 +500,8 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             reactive = video * mask
             inactive = retrieve_latents(self.vae.encode(inactive), generator, sample_mode="argmax")
             reactive = retrieve_latents(self.vae.encode(reactive), generator, sample_mode="argmax")
+            inactive = ((inactive.float() - latents_mean) * latents_std).to(vae_dtype)
+            reactive = ((reactive.float() - latents_mean) * latents_std).to(vae_dtype)
             latents = torch.cat([inactive, reactive], dim=1)

         latent_list = []
@@ -491,6 +511,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                 reference_image = reference_image.to(dtype=vae_dtype)
                 reference_image = reference_image[None, :, None, :, :]  # [1, C, 1, H, W]
                 reference_latent = retrieve_latents(self.vae.encode(reference_image), generator, sample_mode="argmax")
+                reference_latent = ((reference_latent.float() - latents_mean) * latents_std).to(vae_dtype)
                 reference_latent = torch.cat([reference_latent, torch.zeros_like(reference_latent)], dim=1)
                 latent = torch.cat([reference_latent.squeeze(0), latent], dim=1)  # Concat across frame dimension
             latent_list.append(latent)
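
Each reference image becomes a single-frame latent that is now normalized
exactly like the video latents, zero-padded along channels to match the
inactive/reactive split, and prepended on the frame axis. Shape bookkeeping
with illustrative sizes:

    import torch

    z_dim, frames, h, w = 16, 4, 4, 4  # illustrative sizes
    reference_latent = torch.randn(1, z_dim, 1, h, w)  # one encoded reference frame
    reference_latent = torch.cat([reference_latent, torch.zeros_like(reference_latent)], dim=1)
    latent = torch.randn(2 * z_dim, frames, h, w)      # video latent, batch dim already dropped
    latent = torch.cat([reference_latent.squeeze(0), latent], dim=1)  # prepend along frames
    assert latent.shape == (2 * z_dim, frames + 1, h, w)
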
@@ -790,7 +811,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             device,
         )

-        conditioning_latents = self.prepare_video_latents(video, mask, reference_images, generator)
+        conditioning_latents = self.prepare_video_latents(video, mask, reference_images, generator, device)
         mask = self.prepare_masks(mask, reference_images, generator)
         conditioning_latents = torch.cat([conditioning_latents, mask], dim=1)
         conditioning_latents = conditioning_latents.to(transformer_dtype)
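
After this call the conditioning tensor stacks three things along the channel
axis: the inactive latents, the reactive latents, and the mask, then casts to
the transformer dtype. Channel accounting with made-up sizes (the real mask
channel count depends on the VAE and patching setup, which this diff does not
show):

    import torch

    z_dim, mask_channels, t, h, w = 16, 4, 5, 4, 4  # illustrative
    inactive = torch.randn(1, z_dim, t, h, w)
    reactive = torch.randn(1, z_dim, t, h, w)
    mask = torch.randn(1, mask_channels, t, h, w)

    conditioning_latents = torch.cat([inactive, reactive], dim=1)  # from prepare_video_latents
    conditioning_latents = torch.cat([conditioning_latents, mask], dim=1)
    assert conditioning_latents.shape[1] == 2 * z_dim + mask_channels
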
@@ -808,6 +829,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             latents,
         )

+        if conditioning_latents.shape[2] != latents.shape[2]:
+            logger.warning(
+                "The number of frames in the conditioning latents does not match the number of frames to be generated. Generation quality may be affected."
+            )
+
         # 6. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
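
The new warning flags a frame-count mismatch on axis 2 of the [B, C, T, H, W]
latents. Reference images add latent frames on the conditioning side only, so
the two tensors can disagree when `num_frames` was not chosen accordingly.
Illustrative arithmetic, assuming the 4x temporal compression typical of
Wan-style VAEs (an assumption, not stated in this diff):

    # One reference image adds one latent frame to the conditioning stream.
    num_frames, num_reference_images = 81, 1
    video_latent_frames = (num_frames - 1) // 4 + 1                    # 21
    conditioning_frames = video_latent_frames + num_reference_images   # 22
    target_frames = video_latent_frames                                # 21 -> warning fires
    assert conditioning_frames != target_frames
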