Commit to https://github.com/huggingface/diffusers.git
fix a bit more, remove print lines
@@ -482,6 +482,7 @@ else:
         "HunyuanVideoFramepackPipeline",
         "HunyuanVideoImageToVideoPipeline",
         "HunyuanVideoPipeline",
+        "HunyuanVideo15Pipeline",
         "I2VGenXLPipeline",
         "IFImg2ImgPipeline",
         "IFImg2ImgSuperResolutionPipeline",
@@ -1168,6 +1169,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         HunyuanVideoFramepackPipeline,
         HunyuanVideoImageToVideoPipeline,
         HunyuanVideoPipeline,
+        HunyuanVideo15Pipeline,
         I2VGenXLPipeline,
         IFImg2ImgPipeline,
         IFImg2ImgSuperResolutionPipeline,
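Once both registration hunks above are in place, the new pipeline is importable from the top-level package. A minimal usage sketch; the checkpoint id is a placeholder (not an official repo name) and the output attribute is assumed to follow the other video pipelines:

    import torch
    from diffusers import HunyuanVideo15Pipeline

    # "some-org/HunyuanVideo-1.5" is a placeholder model id used for illustration only.
    pipe = HunyuanVideo15Pipeline.from_pretrained("some-org/HunyuanVideo-1.5", torch_dtype=torch.bfloat16)
    pipe.to("cuda")
    # `.frames` is assumed by analogy with the other diffusers video pipelines.
    video = pipe(prompt="a cat walking through tall grass").frames[0]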
@@ -140,10 +140,7 @@ class HunyuanVideo15AttnProcessor2_0:


         batch_size, seq_len, heads, dim = query.shape
-        print(f" query.shape: {query.shape}")
-        print(f" attention_mask.shape: {attention_mask.shape}")
         attention_mask = F.pad(attention_mask, (seq_len - attention_mask.shape[1], 0), value=True)
-        print(f" attention_mask.shape: {attention_mask.shape}")
         attention_mask = attention_mask.bool()
         self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
         self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
@@ -160,8 +157,6 @@ class HunyuanVideo15AttnProcessor2_0:
             backend=self._attention_backend,
             parallel_config=self._parallel_config,
         )
-        print(f" hidden_states.shape: {hidden_states.shape}")
-        print(f" hidden_states[0,:10,:3]: {hidden_states[0,:10,:3]}")

         hidden_states = hidden_states.flatten(2, 3)
         hidden_states = hidden_states.to(query.dtype)
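For context, a standalone sketch of the mask construction that the removed prints were wrapped around: the text mask is left-padded with True up to the full joint sequence length, then expanded toward a 2D self-attention mask. Shapes are toy values, and the final combination of the two expanded masks is an assumption (the hunk above cuts off before that step):

    import torch
    import torch.nn.functional as F

    batch_size, seq_len = 2, 10          # full joint sequence length (toy value)
    attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # mask for the text part only

    # Left-pad with True so the tokens in front of the text (e.g. the video tokens)
    # are always attendable.
    attention_mask = F.pad(attention_mask, (seq_len - attention_mask.shape[1], 0), value=True)
    attention_mask = attention_mask.bool()

    # Expand to (batch, 1, seq, seq) along the query rows and the key columns.
    self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
    self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
    # Assumed combination: a position participates only if both row and column are valid.
    joint_mask = self_attn_mask_1 & self_attn_mask_2
    print(joint_mask.shape)  # torch.Size([2, 1, 10, 10])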
@@ -407,14 +402,8 @@ class HunyuanVideoTokenRefiner(nn.Module):
         pooled_projections = pooled_projections.to(original_dtype)

         temb = self.time_text_embed(timestep, pooled_projections)
-        print(f" temb(time_text_embed).shape: {temb.shape}, {temb[0,:10]}")
         hidden_states = self.proj_in(hidden_states)
-        print(f" hidden_states: {hidden_states.shape}, {hidden_states[0,:3,:3]}")
-        print(f" temb: {temb.shape}, {temb[0,:10]}")
-        print(f" attention_mask: {attention_mask.shape}, {attention_mask[0,:3]}, {attention_mask.abs().sum()}")
-        print(f" -> token_refiner")
         hidden_states = self.token_refiner(hidden_states, temb, attention_mask)
-        print(f" hidden_states(token_refiner) {hidden_states.shape}, {hidden_states[0,:3,:3]}")

         return hidden_states

@@ -537,11 +526,9 @@ class HunyuanVideoTransformerBlock(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # 1. Input normalization
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
-        print(f" norm_hidden_states(norm1).shape: {norm_hidden_states.shape}, {norm_hidden_states[0,:10,:3]}")
         norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
             encoder_hidden_states, emb=temb
         )
-        print(f" norm_encoder_hidden_states(norm1_context).shape: {norm_encoder_hidden_states.shape}, {norm_encoder_hidden_states[0,:10,:3]}")

         # 2. Joint attention
         attn_output, context_attn_output = self.attn(
@@ -550,8 +537,6 @@ class HunyuanVideoTransformerBlock(nn.Module):
             attention_mask=attention_mask,
             image_rotary_emb=freqs_cis,
         )
-        print(f" attn_output.shape: {attn_output.shape}, {attn_output[0,:10,:3]}")
-        print(f" context_attn_output.shape: {context_attn_output.shape}, {context_attn_output[0,:10,:3]}")


         # 3. Modulation and residual connection
@@ -570,8 +555,6 @@ class HunyuanVideoTransformerBlock(nn.Module):

         hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
         encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
-        print(f" hidden_states(ff): {hidden_states.shape}, {hidden_states[0,:10,:3]}")
-        print(f" encoder_hidden_states(ff): {encoder_hidden_states.shape}, {encoder_hidden_states[0,:10,:3]}")

         return hidden_states, encoder_hidden_states

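For reference, a minimal standalone sketch of the gated residual pattern seen in this hunk; names and shapes are illustrative, not taken from the model:

    import torch

    # Toy shapes: batch of 2, sequence of 4 tokens, hidden size 8.
    hidden_states = torch.randn(2, 4, 8)
    ff_output = torch.randn(2, 4, 8)   # output of the feed-forward block
    gate_mlp = torch.randn(2, 8)       # per-sample gate derived from the conditioning embedding

    # unsqueeze(1) broadcasts the per-sample gate over the sequence dimension,
    # so the feed-forward update is scaled before being added back.
    hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
    print(hidden_states.shape)  # torch.Size([2, 4, 8])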
@@ -791,31 +774,23 @@ class HunyuanVideo15Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin
         hidden_states = self.x_embedder(hidden_states)

         # qwen text embedding
-        print(f" encoder_hidden_states(qwen).shape: {encoder_hidden_states.shape}, {encoder_hidden_states[0,:10,:3]}")
-        print(f" timestep: {timestep}, {timestep[:10]}")
-        print(f" encoder_attention_mask: {encoder_attention_mask.shape}, {encoder_attention_mask[0,:10]}, {encoder_attention_mask.abs().sum()}")
         encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
-        print(f" encoder_hidden_states(token_refiner).shape: {encoder_hidden_states.shape}, {encoder_hidden_states[0,:10,:3]}")

         encoder_hidden_states_cond_emb = self.cond_type_embed(
             torch.zeros_like(encoder_hidden_states[:, :, 0], dtype=torch.long)
         )
         encoder_hidden_states = encoder_hidden_states + encoder_hidden_states_cond_emb
-        print(f" encoder_hidden_states(+ cond_emb).shape: {encoder_hidden_states.shape}, {encoder_hidden_states[0,:10,:3]}")

         # byt5 text embedding
         encoder_hidden_states_2 = self.context_embedder_2(encoder_hidden_states_2)
-        print(f" encoder_hidden_states_2(byt5).shape: {encoder_hidden_states_2.shape}, {encoder_hidden_states_2[0,:10,:3]}")

         encoder_hidden_states_2_cond_emb = self.cond_type_embed(
             torch.ones_like(encoder_hidden_states_2[:, :, 0], dtype=torch.long)
         )
         encoder_hidden_states_2 = encoder_hidden_states_2 + encoder_hidden_states_2_cond_emb
-        print(f" encoder_hidden_states_2(+ cond_emb).shape: {encoder_hidden_states_2.shape}, {encoder_hidden_states_2[0,:10,:3]}")

         # image embed
         encoder_hidden_states_3 = self.image_embedder(image_embeds)
-        print(f" encoder_hidden_states_3(image).shape: {encoder_hidden_states_3.shape}, {encoder_hidden_states_3[0,:10,:3]}")
         is_t2v = torch.all(image_embeds == 0)
         if is_t2v:
             encoder_hidden_states_3 = encoder_hidden_states_3 * 0.0
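As an aside, a minimal sketch of the condition-type embedding idea visible above: a small embedding table tags tokens with their source (index 0 for the qwen text stream, 1 for the byt5 stream) and the result is added to the token features. Shapes and the table size are illustrative assumptions, not the model's config:

    import torch
    import torch.nn as nn

    batch, text_len, dim = 2, 6, 16          # toy sizes (assumption)
    cond_type_embed = nn.Embedding(3, dim)   # one row per token source; 3 is an assumed table size

    encoder_hidden_states = torch.randn(batch, text_len, dim)    # qwen text tokens
    encoder_hidden_states_2 = torch.randn(batch, text_len, dim)  # byt5 text tokens

    # Index 0 tags the first text stream, index 1 the second; the looked-up embedding
    # is added to every token of that stream.
    type_ids_1 = torch.zeros_like(encoder_hidden_states[:, :, 0], dtype=torch.long)
    type_ids_2 = torch.ones_like(encoder_hidden_states_2[:, :, 0], dtype=torch.long)
    encoder_hidden_states = encoder_hidden_states + cond_type_embed(type_ids_1)
    encoder_hidden_states_2 = encoder_hidden_states_2 + cond_type_embed(type_ids_2)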
@@ -824,8 +799,6 @@ class HunyuanVideo15Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin
                 dtype=encoder_attention_mask.dtype,
                 device=encoder_attention_mask.device,
             )
-            print(f" encoder_hidden_states_3(image).shape: {encoder_hidden_states_3.shape}, {encoder_hidden_states_3[0,:10,:3]}")
-            print(f" encoder_attention_mask_3: {encoder_attention_mask_3.shape}, {encoder_attention_mask_3[0,:10]}, {encoder_attention_mask_3.abs().sum()}")
         else:
             encoder_attention_mask_3 = torch.ones(
                 (batch_size, encoder_hidden_states_3.shape[1]),
@@ -840,9 +813,6 @@ class HunyuanVideo15Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin
         )
         encoder_hidden_states_3 = encoder_hidden_states_3 + encoder_hidden_states_3_cond_emb

-        print(f" encoder_hidden_states_3(+ cond_emb).shape: {encoder_hidden_states_3.shape}, {encoder_hidden_states_3[0,:10,:3]}")
-
-
         # reorder and combine text tokens: combine valid tokens first, then padding
         encoder_attention_mask = encoder_attention_mask.bool()
         encoder_attention_mask_2 = encoder_attention_mask_2.bool()
@@ -891,12 +861,6 @@ class HunyuanVideo15Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin
         encoder_hidden_states = torch.stack(new_encoder_hidden_states)
         encoder_attention_mask = torch.stack(new_encoder_attention_mask)

-        print(f" hidden_states.shape: {hidden_states.shape}, {hidden_states[0,:3,:3]}")
-        print(f" encoder_hidden_states.shape: {encoder_hidden_states.shape}, {encoder_hidden_states[0,:10,:3]}")
-        print(f" encoder_attention_mask.shape: {encoder_attention_mask.shape}, {encoder_attention_mask[0,:10]}, {encoder_attention_mask.dtype}, {encoder_attention_mask.sum()}")
-        print(f" image_rotary_emb: {image_rotary_emb[0].shape}, {image_rotary_emb[1].shape}, {image_rotary_emb[0][:3,:10]}, {image_rotary_emb[1][:3,:10]}")
-        print(f" temb.shape: {temb.shape}, {temb[0,:10]}")
-

         # 4. Transformer blocks
         if torch.is_grad_enabled() and self.gradient_checkpointing:
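The "combine valid tokens first, then padding" step above can be illustrated with a small standalone sketch (per-sample loop over toy tensors; variable names are illustrative):

    import torch

    batch, dim = 2, 8
    tokens_a = torch.randn(batch, 5, dim)   # first text stream
    mask_a = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 0, 0, 0]]).bool()
    tokens_b = torch.randn(batch, 3, dim)   # second text stream
    mask_b = torch.tensor([[1, 1, 0], [1, 0, 0]]).bool()

    new_tokens, new_mask = [], []
    for i in range(batch):
        # Keep the valid tokens of both streams first, then the padded positions,
        # so all real tokens form one contiguous prefix per sample.
        valid = torch.cat([tokens_a[i][mask_a[i]], tokens_b[i][mask_b[i]]], dim=0)
        padded = torch.cat([tokens_a[i][~mask_a[i]], tokens_b[i][~mask_b[i]]], dim=0)
        new_tokens.append(torch.cat([valid, padded], dim=0))
        new_mask.append(torch.cat([torch.ones(len(valid)), torch.zeros(len(padded))]).bool())

    encoder_hidden_states = torch.stack(new_tokens)   # (batch, 8, dim)
    encoder_attention_mask = torch.stack(new_mask)    # (batch, 8)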
@@ -242,6 +242,7 @@ else:
         "HunyuanVideoImageToVideoPipeline",
         "HunyuanVideoFramepackPipeline",
     ]
+    _import_structure["hunyuan_video1_5"] = ["HunyuanVideo15Pipeline"]
     _import_structure["hunyuan_image"] = ["HunyuanImagePipeline", "HunyuanImageRefinerPipeline"]
     _import_structure["kandinsky"] = [
         "KandinskyCombinedPipeline",
@@ -662,6 +663,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         HunyuanVideoImageToVideoPipeline,
         HunyuanVideoPipeline,
     )
+    from .hunyuan_video1_5 import HunyuanVideo15Pipeline
     from .hunyuandit import HunyuanDiTPipeline
     from .i2vgen_xl import I2VGenXLPipeline
     from .kandinsky import (
@@ -27,6 +27,7 @@ from .image_processor import HunyuanVideo15ImageProcessor
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import HunyuanVideo15PipelineOutput
 from ...guiders import ClassifierFreeGuidance
 from ...utils.torch_utils import randn_tensor


 if is_torch_xla_available():
@@ -225,7 +226,7 @@ class HunyuanVideo15Pipeline(DiffusionPipeline):
         self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 16
         self.video_processor = HunyuanVideo15ImageProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
         self.target_size = self.transformer.config.target_size if getattr(self, "transformer", None) else 640
-        self.vision_states_dim = self.transformer.config.vision_states_dim if getattr(self, "transformer", None) else 729
+        self.vision_states_dim = self.transformer.config.image_embed_dim if getattr(self, "transformer", None) else 1152
         # fmt: off
         self.system_message = "You are a helpful assistant. Describe the video by detailing the following aspects: \
             1. The main content and theme of the video. \
@@ -236,8 +237,9 @@ class HunyuanVideo15Pipeline(DiffusionPipeline):
         # fmt: on
         self.prompt_template_encode_start_idx = 108
         self.tokenizer_max_length = 1000
-        self.text_encoder_2_max_length = 256
+        self.tokenizer_2_max_length = 256
         self.vision_num_semantic_tokens = 729
+        self.default_aspect_ratio = (16, 9)  # (width: height)


     @staticmethod
@@ -282,7 +284,7 @@ class HunyuanVideo15Pipeline(DiffusionPipeline):
         prompt_embeds = text_encoder(
             input_ids=text_input_ids,
             attention_mask=prompt_attention_mask,
-            output_hidden_states=False,
+            output_hidden_states=True,
         ).hidden_states[-(num_hidden_layers_to_skip + 1)]
         prompt_embeds = prompt_embeds.to(dtype=dtype)

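This fix matters because `hidden_states` is only populated when `output_hidden_states=True`. A small sketch against a generic transformers text encoder; the checkpoint name and skip count are placeholders:

    import torch
    from transformers import AutoModel, AutoTokenizer

    # Placeholder checkpoint; any encoder that returns hidden_states behaves the same way.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModel.from_pretrained("bert-base-uncased")

    inputs = tokenizer("a cat walking through tall grass", return_tensors="pt")
    num_hidden_layers_to_skip = 2  # illustrative value

    with torch.no_grad():
        out = model(**inputs, output_hidden_states=True)

    # hidden_states holds the embedding layer plus every transformer layer;
    # indexing from the end skips the last `num_hidden_layers_to_skip` layers.
    prompt_embeds = out.hidden_states[-(num_hidden_layers_to_skip + 1)]
    print(prompt_embeds.shape)  # (batch, seq_len, hidden_size)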
@@ -521,7 +523,7 @@ class HunyuanVideo15Pipeline(DiffusionPipeline):
         return latents


-    def prepare_cond_latents_and_mask(self, latents):
+    def prepare_cond_latents_and_mask(self, latents, dtype: Optional[torch.dtype], device: Optional[torch.device]):
         """
         Prepare conditional latents and mask for t2v generation.

@@ -535,13 +537,14 @@ class HunyuanVideo15Pipeline(DiffusionPipeline):

         cond_latents_concat = torch.zeros(
             batch, channels, frames, height, width,
-            device=latents.device,
-            dtype=latents.dtype
+            dtype=dtype,
+            device=device
         )

         mask_concat = torch.zeros(
             batch, 1, frames, height, width,
-            device=latents.device
+            dtype=dtype,
+            device=device
         )

         return cond_latents_concat, mask_concat
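Putting the two hunks together, the updated helper would look roughly like this. A sketch only: the shape unpacking is an assumption based on the variable names in the hunk:

    from typing import Optional, Tuple
    import torch

    def prepare_cond_latents_and_mask(
        latents: torch.Tensor,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Prepare all-zero conditional latents and mask for t2v generation."""
        batch, channels, frames, height, width = latents.shape

        # For text-to-video there is no conditioning frame, so both the conditioning
        # latents and the mask are zeros created on the requested dtype and device.
        cond_latents_concat = torch.zeros(batch, channels, frames, height, width, dtype=dtype, device=device)
        mask_concat = torch.zeros(batch, 1, frames, height, width, dtype=dtype, device=device)
        return cond_latents_concat, mask_concat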
@@ -702,7 +705,7 @@ class HunyuanVideo15Pipeline(DiffusionPipeline):
         )

         if height is None and width is None:
-            height, width = self.video_processor.calculate_default_height_width(height, width, self.target_size)
+            height, width = self.video_processor.calculate_default_height_width(self.default_aspect_ratio[1], self.default_aspect_ratio[0], self.target_size)

         self._attention_kwargs = attention_kwargs
         self._current_timestep = None
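Inside that branch both `height` and `width` are None, so the old call fed Nones into the helper; the new call supplies the default 16:9 aspect ratio instead. A rough illustration of the intended arithmetic follows; this is not the library's implementation, only the idea of mapping an aspect ratio plus a target size to concrete dimensions:

    def aspect_to_height_width(aspect_h: int, aspect_w: int, target_size: int, multiple: int = 16):
        # Hypothetical helper: scale the aspect ratio so the longer side is close to
        # target_size, then snap both sides to a multiple the VAE can compress cleanly.
        scale = target_size / max(aspect_h, aspect_w)
        height = int(round(aspect_h * scale / multiple)) * multiple
        width = int(round(aspect_w * scale / multiple)) * multiple
        return height, width

    print(aspect_to_height_width(9, 16, 640))  # (352, 640) for a 16:9 default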
@@ -761,8 +764,19 @@ class HunyuanVideo15Pipeline(DiffusionPipeline):
             generator,
             latents,
         )
-        cond_latents_concat, mask_concat = self.prepare_cond_latents_and_mask(latents)
-        vision_states = torch.zeros(batch_size, self.vision_num_semantic_tokens, self.vision_states_dim).to(latents.device)
+        cond_latents_concat, mask_concat = self.prepare_cond_latents_and_mask(latents, torch.float32, device)
+        image_embeds = torch.zeros(
+            batch_size,
+            self.vision_num_semantic_tokens,
+            self.vision_states_dim,
+            dtype=torch.float32,
+            device=device
+        )
+
+        image_embeds = image_embeds.to(self.transformer.dtype)
+        latents = latents.to(self.transformer.dtype)
+        cond_latents_concat = cond_latents_concat.to(self.transformer.dtype)
+        mask_concat = mask_concat.to(self.transformer.dtype)


         # 7. Denoising loop
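The all-zero `image_embeds` created above is how the text-to-video path is signalled: the transformer (see the `is_t2v` check in the earlier hunk) treats an all-zero image embedding as "no image conditioning". A toy illustration:

    import torch

    batch_size, vision_num_semantic_tokens, vision_states_dim = 1, 729, 1152
    image_embeds = torch.zeros(batch_size, vision_num_semantic_tokens, vision_states_dim)

    # Mirrors the transformer-side check: an all-zero embedding means t2v,
    # so the projected image tokens are zeroed out and their mask disabled.
    is_t2v = torch.all(image_embeds == 0)
    print(bool(is_t2v))  # True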
@@ -817,8 +831,8 @@ class HunyuanVideo15Pipeline(DiffusionPipeline):
                 with self.transformer.cache_context(context_name):
                     # Run denoiser and store noise prediction in this batch
                     guider_state_batch.noise_pred = self.transformer(
-                        hidden_states=latents,
-                        image_embeds=vision_states,
+                        hidden_states=latent_model_input,
+                        image_embeds=image_embeds,
                         timestep=timestep,
                         attention_kwargs=self.attention_kwargs,
                         return_dict=False,
@@ -863,9 +877,7 @@ class HunyuanVideo15Pipeline(DiffusionPipeline):

         if not output_type == "latent":
             latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
-            self.vae.enable_tiling()
             video = self.vae.decode(latents, return_dict=False, generator=generator)[0]
-            self.vae.disable_tiling()
             video = self.video_processor.postprocess_video(video, output_type=output_type)
         else:
             video = latents
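With the forced tiling removed, callers who need it for memory reasons can still opt in on the VAE before running the pipeline. A minimal sketch, reusing the `pipe` variable from the earlier import example and the same output-attribute assumption:

    # Optional: enable tiled VAE decoding yourself if a full-resolution decode
    # does not fit in memory; it is no longer toggled inside the pipeline.
    pipe.vae.enable_tiling()
    video = pipe(prompt="a cat walking through tall grass").frames[0]
    pipe.vae.disable_tiling()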