refactor
@@ -175,7 +175,7 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
     """

     _supports_gradient_checkpointing = True
-    _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
+    _skip_layerwise_casting_patterns = ["patch_embedding", "vace_patch_embedding", "condition_embedder", "norm"]
     _no_split_modules = ["WanTransformerBlock", "WanVACETransformerBlock"]
     _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
     _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
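Illustrative sketch (not part of the diff): `_skip_layerwise_casting_patterns` lists module-name patterns that `enable_layerwise_casting` leaves in their load-time precision, so adding `"vace_patch_embedding"` keeps the VACE patch embedding out of low-precision storage. A minimal usage example, assuming the `Wan-AI/Wan2.1-VACE-1.3B-diffusers` checkpoint id:

import torch
from diffusers import WanVACETransformer3DModel

# Modules matching the skip patterns above (patch_embedding, vace_patch_embedding,
# condition_embedder, norm) keep their load-time dtype; the rest is stored in fp8.
transformer = WanVACETransformer3DModel.from_pretrained(
    "Wan-AI/Wan2.1-VACE-1.3B-diffusers", subfolder="transformer", torch_dtype=torch.bfloat16
)
transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)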
@@ -273,9 +273,6 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
         return_dict: bool = True,
         attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
-        if control_hidden_states is None:
-            raise ValueError("Control hidden states must be provided for VACE models.")
-
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)
@@ -299,6 +296,12 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO

         if control_hidden_states_scale is None:
             control_hidden_states_scale = control_hidden_states.new_ones(len(self.config.vace_layers))
+        control_hidden_states_scale = torch.unbind(control_hidden_states_scale)
+        if len(control_hidden_states_scale) != len(self.config.vace_layers):
+            raise ValueError(
+                f"Length of `control_hidden_states_scale` {len(control_hidden_states_scale)} should be "
+                f"equal to {len(self.config.vace_layers)}."
+            )

         # 1. Rotary position embedding
         rotary_emb = self.rope(hidden_states)
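Illustrative sketch (not part of the diff) of what the added validation enforces: `control_hidden_states_scale` must carry one scalar per entry in `config.vace_layers`; it defaults to ones and is unbound into a tuple of per-layer scales. Layer indices and shapes below are made up:

import torch

vace_layers = (0, 5, 10, 15)             # hypothetical config.vace_layers
control_hidden_states = torch.randn(1, 1024, 96)

control_hidden_states_scale = None       # user passed nothing
if control_hidden_states_scale is None:
    control_hidden_states_scale = control_hidden_states.new_ones(len(vace_layers))
control_hidden_states_scale = torch.unbind(control_hidden_states_scale)  # one 0-dim tensor per VACE layer
assert len(control_hidden_states_scale) == len(vace_layers)  # otherwise the new check raises ValueError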
@@ -306,9 +309,11 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
         # 2. Patch embedding
         hidden_states = self.patch_embedding(hidden_states)
         hidden_states = hidden_states.flatten(2).transpose(1, 2)
+        print("hidden_states", hidden_states.shape)

         control_hidden_states = self.vace_patch_embedding(control_hidden_states)
         control_hidden_states = control_hidden_states.flatten(2).transpose(1, 2)
+        print("control_hidden_states", control_hidden_states.shape)
         control_hidden_states_padding = control_hidden_states.new_zeros(
             batch_size, hidden_states.size(1) - control_hidden_states.size(1), control_hidden_states.size(2)
         )
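Illustrative sketch (not part of the diff): the zero padding computed above extends the control token stream so it matches the possibly longer video token stream, e.g. when reference-image tokens are prepended to the video latents but have no counterpart on the control branch. Shapes below are made up:

import torch

batch_size, dim = 2, 1536
hidden_states = torch.randn(batch_size, 520, dim)          # video tokens
control_hidden_states = torch.randn(batch_size, 512, dim)  # control tokens (shorter)

control_hidden_states_padding = control_hidden_states.new_zeros(
    batch_size, hidden_states.size(1) - control_hidden_states.size(1), control_hidden_states.size(2)
)
# Presumably concatenated along the sequence dimension so both streams align:
control_hidden_states = torch.cat([control_hidden_states, control_hidden_states_padding], dim=1)
assert control_hidden_states.shape == hidden_states.shape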
@@ -329,11 +334,11 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
             # Prepare VACE hints
             control_hidden_states_list = []
             vace_hidden_states = hidden_states
-            for block in self.vace_blocks:
+            for i, block in enumerate(self.vace_blocks):
                 vace_hidden_states, control_hidden_states = self._gradient_checkpointing_func(
                     block, vace_hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb
                 )
-                control_hidden_states_list.append(control_hidden_states)
+                control_hidden_states_list.append((control_hidden_states, control_hidden_states_scale[i]))
             control_hidden_states_list = control_hidden_states_list[::-1]

             for i, block in enumerate(self.blocks):
@@ -341,24 +346,24 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
                     block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
                 )
                 if i in self.config.vace_layers:
-                    control_hint = control_hidden_states_list.pop()
-                    hidden_states = hidden_states + control_hint * control_hidden_states_scale[i]
+                    control_hint, scale = control_hidden_states_list.pop()
+                    hidden_states = hidden_states + control_hint * scale
         else:
             # Prepare VACE hints
             control_hidden_states_list = []
             vace_hidden_states = hidden_states
-            for block in self.vace_blocks:
+            for i, block in enumerate(self.vace_blocks):
                 vace_hidden_states, control_hidden_states = block(
                     vace_hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb
                 )
-                control_hidden_states_list.append(control_hidden_states)
+                control_hidden_states_list.append((control_hidden_states, control_hidden_states_scale[i]))
             control_hidden_states_list = control_hidden_states_list[::-1]

             for i, block in enumerate(self.blocks):
                 hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
                 if i in self.config.vace_layers:
-                    control_hint = control_hidden_states_list.pop()
-                    hidden_states = hidden_states + control_hint * control_hidden_states_scale[i]
+                    control_hint, scale = control_hidden_states_list.pop()
+                    hidden_states = hidden_states + control_hint * scale

         # 6. Output norm, projection & unpatchify
         shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
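Illustrative sketch (not part of the diff) of the bookkeeping this refactor introduces: each VACE block's hint is stored together with its per-layer scale, the list is reversed, and the main blocks consume it with `pop()` at the configured layer indices. Everything below is a toy stand-in:

import torch

vace_layers = (0, 2)                                  # hypothetical config.vace_layers
scales = torch.unbind(torch.tensor([0.5, 1.0]))       # one scale per VACE layer
hints = [torch.randn(1, 8, 4) for _ in vace_layers]   # stand-ins for the VACE block outputs

hint_list = [(h, s) for h, s in zip(hints, scales)]   # pair each hint with its scale
hint_list = hint_list[::-1]                           # reverse so pop() yields them in order

hidden_states = torch.zeros(1, 8, 4)
for i in range(4):                                    # stand-in for the main transformer blocks
    if i in vace_layers:
        control_hint, scale = hint_list.pop()
        hidden_states = hidden_states + control_hint * scale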
@@ -292,7 +292,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         mask=None,
         reference_images=None,
     ):
-        base = self.vae_scale_factor_spatial * self.transformer.config.patch_size
+        base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
         if height % base != 0 or width % base != 0:
             raise ValueError(f"`height` and `width` have to be divisible by {base} but are {height} and {width}.")
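Why the new indexing matters (illustrative, with an assumed Wan default of `patch_size = (1, 2, 2)` over frames/height/width): multiplying an int by the whole tuple repeats the tuple instead of producing a divisor, so the old expression could not feed the modulo check below it.

vae_scale_factor_spatial = 8
patch_size = (1, 2, 2)                                # assumed Wan default

base_old = vae_scale_factor_spatial * patch_size      # (1, 2, 2, 1, 2, 2, ...) -- not a number
base_new = vae_scale_factor_spatial * patch_size[1]   # 16: height and width must be divisible by 16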
@@ -368,39 +368,78 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         device: Optional[torch.device] = None,
     ):
         if video is not None:
-            video = self.video_processor.preprocess_video(video, None, None)  # Use the height/width of video
-            image_size = tuple(video.shape[-2:])
+            base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
+            video_height, video_width = self.video_processor.get_default_height_width(video[0])
+
+            if video_height * video_width > height * width:
+                scale = min(width / video_width, height / video_height)
+                video_height, video_width = int(video_height * scale), int(video_width * scale)
+
+            if video_height % base != 0 or video_width % base != 0:
+                logger.warning(
+                    f"Video height and width should be divisible by {base}, but got {video_height} and {video_width}. "
+                )
+                video_height = (video_height // base) * base
+                video_width = (video_width // base) * base
+
+            assert video_height * video_width <= height * width
+
+            video = self.video_processor.preprocess_video(video, video_height, video_width)
+            image_size = (video_height, video_width)  # Use the height/width of video (with possible rescaling)
         else:
-            video = torch.zeros(batch_size, num_frames, 3, height, width, dtype=dtype, device=device)
+            video = torch.zeros(batch_size, 3, num_frames, height, width, dtype=dtype, device=device)
             image_size = (height, width)  # Use the height/width provider by user

         if mask is not None:
-            mask = self.video_processor.preprocess_video(mask, height, width)
+            mask = self.video_processor.preprocess_video(mask, image_size[0], image_size[1])
         else:
-            mask = torch.ones_like(video, dtype=dtype, device=device)
+            mask = torch.ones_like(video)
+
+        video = video.to(dtype=dtype, device=device)
+        mask = mask.to(dtype=dtype, device=device)

-        reference_images_preprocessed = []
-        if reference_images is not None:
-            if not isinstance(reference_images, list):
-                reference_images = [reference_images]
-            for i, image in enumerate(reference_images):
-                image = self.video_processor.preprocess(image, None, None)  # Use the height/width of image
+        # Make a list of list of images where the outer list corresponds to video batch size and the inner list
+        # corresponds to list of conditioning images per video
+        if reference_images is None or isinstance(reference_images, PIL.Image.Image):
+            reference_images = [[reference_images] for _ in range(video.shape[0])]
+        elif isinstance(reference_images, (list, tuple)) and isinstance(next(iter(reference_images)), PIL.Image.Image):
+            reference_images = [reference_images]
+        elif (
+            isinstance(reference_images, (list, tuple))
+            and isinstance(next(iter(reference_images)), list)
+            and isinstance(next(iter(reference_images[0])), PIL.Image.Image)
+        ):
+            reference_images = reference_images
+        else:
+            raise ValueError(
+                "`reference_images` has to be of type `PIL.Image.Image` or `list` of `PIL.Image.Image`, or "
+                f"`list` of `list` of `PIL.Image.Image`, but is {type(reference_images)}"
+            )
+
+        if video.shape[0] != len(reference_images):
+            raise ValueError(
+                f"Batch size of `video` {video.shape[0]} and length of `reference_images` {len(reference_images)} does not match."
+            )
+
+        reference_images_preprocessed = []
+        for i, reference_images_batch in enumerate(reference_images):
+            preprocessed_images = []
+            for j, image in enumerate(reference_images_batch):
+                if image is None:
+                    continue
+                image = self.video_processor.preprocess(image, None, None)
                 img_height, img_width = image.shape[-2:]
                 scale = min(image_size[0] / img_height, image_size[1] / img_width)
                 new_height, new_width = int(img_height * scale), int(img_width * scale)
                 resized_image = torch.nn.functional.interpolate(
-                    image.unsqueeze(1), size=(new_height, new_width), mode="bilinear", align_corners=False
-                ).squeeze(1)
+                    image, size=(new_height, new_width), mode="bilinear", align_corners=False
+                ).squeeze(0)  # [C, H, W]
                 top = (image_size[0] - new_height) // 2
                 left = (image_size[1] - new_width) // 2
-                canvas = torch.ones(batch_size, 1, 3, *image_size, device=device, dtype=dtype)
-                canvas[:, :, :, top : top + new_height, left : left + new_width] = resized_image
-                reference_images_preprocessed.append(canvas)
+                canvas = torch.ones(3, *image_size, device=device, dtype=dtype)
+                canvas[:, top : top + new_height, left : left + new_width] = resized_image
+                preprocessed_images.append(canvas)
+            reference_images_preprocessed.append(preprocessed_images)

         return video, mask, reference_images_preprocessed
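Illustrative sketch (not part of the diff) of the reference-image handling added above: each image is resized to fit inside the working resolution with its aspect ratio preserved, then pasted centered on an all-ones canvas (white for [-1, 1]-normalized images). A standalone toy version:

import torch
import torch.nn.functional as F

def letterbox(image: torch.Tensor, target_h: int, target_w: int) -> torch.Tensor:
    # image: [1, C, H, W]; returns [C, target_h, target_w] with the resized image centered on ones.
    img_h, img_w = image.shape[-2:]
    scale = min(target_h / img_h, target_w / img_w)
    new_h, new_w = int(img_h * scale), int(img_w * scale)
    resized = F.interpolate(image, size=(new_h, new_w), mode="bilinear", align_corners=False).squeeze(0)
    top, left = (target_h - new_h) // 2, (target_w - new_w) // 2
    canvas = torch.ones(image.shape[1], target_h, target_w, dtype=image.dtype)
    canvas[:, top : top + new_h, left : left + new_w] = resized
    return canvas

print(letterbox(torch.rand(1, 3, 480, 640), 480, 832).shape)  # torch.Size([3, 480, 832])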
@@ -408,7 +447,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         self,
         video: torch.Tensor,
         mask: torch.Tensor,
-        reference_images: Optional[List[torch.Tensor]] = None,
+        reference_images: Optional[List[List[torch.Tensor]]] = None,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
     ) -> torch.Tensor:
         if isinstance(generator, list):
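For reference (not part of the diff), the `List[List[torch.Tensor]]` annotation mirrors the normalization done earlier in `preprocess_conditions`; the accepted user-facing forms appear to be:

# img                    -> [[img], [img], ...]   one reference, repeated for every video in the batch
# [img1, img2]           -> [[img1, img2]]        several references for a single video
# [[img1], [img2, img3]] -> unchanged              one inner list per video in the batch
# None                   -> [[None], ...]          no reference for any video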
@@ -416,7 +455,8 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                 raise ValueError("Passing a list of generators is not yet supported. This may be supported in the future.")

         if reference_images is None:
+            # For each batch of video, we set no reference image (as one or more can be passed by user)
             reference_images = [[None] for _ in range(video.shape[0])]
         else:
             if video.shape[0] != len(reference_images):
@@ -437,22 +477,24 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             latents = retrieve_latents(self.vae.encode(video), generator, sample_mode="argmax").unbind(0)
         else:
             mask = mask.to(dtype=vae_dtype)
-            mask = [torch.where(m > 0.5, 1.0, 0.0) for m in mask]
-            inactive = [v * (1 - m) for v, m in zip(video, mask)]
-            reactive = [v * m for v, m in zip(video, mask)]
+            mask = torch.where(mask > 0.5, 1.0, 0.0)
+            inactive = video * (1 - mask)
+            reactive = video * mask
             inactive = retrieve_latents(self.vae.encode(inactive), generator, sample_mode="argmax")
             reactive = retrieve_latents(self.vae.encode(reactive), generator, sample_mode="argmax")
-            latents = [torch.cat([i, r], dim=0) for i, r in zip(inactive, reactive)]
+            latents = torch.cat([inactive, reactive], dim=1)

         latent_list = []
-        for latent, ref_images in zip(latents, reference_images):
-            if ref_images is not None:
-                ref_images = ref_images.to(dtype=vae_dtype)
-                ref_latents = retrieve_latents(self.vae.encode(ref_images), generator, sample_mode="argmax")
-                ref_latents = [torch.cat([r, torch.zeros_like(r)], dim=0) for r in ref_latents]
-                latent = torch.cat([*ref_latents, latent], dim=1)
+        for latent, reference_images_batch in zip(latents, reference_images):
+            for reference_image in reference_images_batch:
+                assert reference_image.ndim == 3
+                reference_image = reference_image.to(dtype=vae_dtype)
+                reference_image = reference_image[None, :, None, :, :]  # [1, C, 1, H, W]
+                reference_latent = retrieve_latents(self.vae.encode(reference_image), generator, sample_mode="argmax")
+                reference_latent = torch.cat([reference_latent, torch.zeros_like(reference_latent)], dim=1)
+                latent = torch.cat([reference_latent.squeeze(0), latent], dim=1)  # Concat across frame dimension
             latent_list.append(latent)
-        return latent_list
+        return torch.stack(latent_list)

     def prepare_masks(
         self,
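Shape-only sketch (not part of the diff, no VAE involved) of the latent layout built above: inactive and reactive video latents are concatenated on the channel dimension, and every encoded reference image contributes one zero-padded latent frame prepended along the frame dimension. Sizes are hypothetical:

import torch

z, t, h, w = 16, 21, 60, 104                           # latent channels / frames / spatial size
inactive = torch.randn(1, z, t, h, w)                   # latents of video * (1 - mask)
reactive = torch.randn(1, z, t, h, w)                   # latents of video * mask
latents = torch.cat([inactive, reactive], dim=1)        # [1, 2*z, t, h, w]

reference_latent = torch.randn(1, z, 1, h, w)           # one encoded reference image
reference_latent = torch.cat([reference_latent, torch.zeros_like(reference_latent)], dim=1)  # [1, 2*z, 1, h, w]

latent = latents[0]                                     # per-sample latent, as in the zip() loop
latent = torch.cat([reference_latent.squeeze(0), latent], dim=1)  # prepend along frames
print(latent.shape)                                     # torch.Size([32, 22, 60, 104])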
@@ -479,25 +521,28 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                     "Generating with more than one video is not yet supported. This may be supported in the future."
                 )

+        transformer_patch_size = self.transformer.config.patch_size[1]
+
         mask_list = []
-        transformer_patch_size = self.transformer.config.patch_size
-        for mask_, ref_images in zip(mask, reference_images):
-            num_frames, num_channels, height, width = mask_.shape
+        for mask_, reference_images_batch in zip(mask, reference_images):
+            num_channels, num_frames, height, width = mask_.shape
             new_num_frames = (num_frames + self.vae_scale_factor_temporal - 1) // self.vae_scale_factor_temporal
             new_height = height // (self.vae_scale_factor_spatial * transformer_patch_size) * transformer_patch_size
             new_width = width // (self.vae_scale_factor_spatial * transformer_patch_size) * transformer_patch_size
-            mask_ = mask_[:, 0, :, :]
-            mask_ = mask_.view(num_frames, height, self.vae_scale_factor_spatial, width, self.vae_scale_factor_spatial)
-            mask_ = mask_.permute(2, 4, 0, 1, 3).flatten(2, 4).flatten(0, 1)
+            mask_ = mask_[0, :, :, :]
+            mask_ = mask_.view(
+                num_frames, new_height, self.vae_scale_factor_spatial, new_width, self.vae_scale_factor_spatial
+            )
+            mask_ = mask_.permute(2, 4, 0, 1, 3).flatten(0, 1)  # [8x8, num_frames, new_height, new_width]
             mask_ = torch.nn.functional.interpolate(
                 mask_.unsqueeze(0), size=(new_num_frames, new_height, new_width), mode="nearest-exact"
             ).squeeze(0)
-            if ref_images is not None:
-                num_ref_images = ref_images.size(0)
-                mask_padding = torch.zeros_like(mask[:num_ref_images, :, :, :])
+            num_ref_images = len(reference_images_batch)
+            if num_ref_images > 0:
+                mask_padding = torch.zeros_like(mask_[:, :num_ref_images, :, :])
                 mask_ = torch.cat([mask_, mask_padding], dim=1)
             mask_list.append(mask_)
-        return mask_list
+        return torch.stack(mask_list)

     def prepare_latents(
         self,
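Shape-only sketch (not part of the diff) of the mask rearrangement above: the single-channel pixel-resolution mask is split into an 8x8 grid of latent-resolution channels, then resized along time to the latent frame count. The scale factors below are the usual Wan values but are assumptions here:

import torch

sf, patch = 8, 2                                        # vae_scale_factor_spatial, patch_size[1]
num_frames, height, width = 81, 480, 832
new_num_frames = (num_frames + 4 - 1) // 4              # vae_scale_factor_temporal assumed to be 4
new_height = height // (sf * patch) * patch             # 60
new_width = width // (sf * patch) * patch               # 104

mask_ = torch.ones(num_frames, height, width)           # one channel at pixel resolution
mask_ = mask_.view(num_frames, new_height, sf, new_width, sf)
mask_ = mask_.permute(2, 4, 0, 1, 3).flatten(0, 1)      # [sf*sf, num_frames, new_height, new_width]
mask_ = torch.nn.functional.interpolate(
    mask_.unsqueeze(0), size=(new_num_frames, new_height, new_width), mode="nearest-exact"
).squeeze(0)
print(mask_.shape)                                      # torch.Size([64, 21, 60, 104])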
@@ -746,12 +791,9 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         )

         conditioning_latents = self.prepare_video_latents(video, mask, reference_images, generator)
-        conditioning_latents = [c.to(transformer_dtype) for c in conditioning_latents]
-
         mask = self.prepare_masks(mask, reference_images, generator)
-        mask = [m.to(transformer_dtype) for m in mask]
-
-        conditioning_latents = [torch.cat([c, m], dim=1) for c, m in zip(conditioning_latents, mask)]
+        conditioning_latents = torch.cat([conditioning_latents, mask], dim=1)
+        conditioning_latents = conditioning_latents.to(transformer_dtype)

         num_channels_latents = self.transformer.config.in_channels
         latents = self.prepare_latents(
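Hedged arithmetic (not part of the diff) for the final channel concatenation, assuming a 16-channel Wan VAE and a spatial scale factor of 8:

#   inactive + reactive latents: 2 * 16 = 32 channels
#   rearranged mask:             8 * 8  = 64 channels
#   conditioning_latents:        32 + 64 = 96 channels fed to the transformer's VACE branch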