diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index eb9929cf2c..eaaac1838d 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -482,6 +482,7 @@ else:
             "HunyuanVideoFramepackPipeline",
             "HunyuanVideoImageToVideoPipeline",
             "HunyuanVideoPipeline",
+            "HunyuanVideo15Pipeline",
             "I2VGenXLPipeline",
             "IFImg2ImgPipeline",
             "IFImg2ImgSuperResolutionPipeline",
@@ -1168,6 +1169,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
             HunyuanVideoFramepackPipeline,
             HunyuanVideoImageToVideoPipeline,
             HunyuanVideoPipeline,
+            HunyuanVideo15Pipeline,
             I2VGenXLPipeline,
             IFImg2ImgPipeline,
             IFImg2ImgSuperResolutionPipeline,
diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video15.py b/src/diffusers/models/transformers/transformer_hunyuan_video15.py
index 86c3a3565f..c26b43e19c 100644
--- a/src/diffusers/models/transformers/transformer_hunyuan_video15.py
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video15.py
@@ -140,10 +140,7 @@ class HunyuanVideo15AttnProcessor2_0:
 
         batch_size, seq_len, heads, dim = query.shape
-        print(f" query.shape: {query.shape}")
-        print(f" attention_mask.shape: {attention_mask.shape}")
         attention_mask = F.pad(attention_mask, (seq_len - attention_mask.shape[1], 0), value=True)
-        print(f" attention_mask.shape: {attention_mask.shape}")
         attention_mask = attention_mask.bool()
         self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
         self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
@@ -160,8 +157,6 @@ class HunyuanVideo15AttnProcessor2_0:
             backend=self._attention_backend,
             parallel_config=self._parallel_config,
         )
-        print(f" hidden_states.shape: {hidden_states.shape}")
-        print(f" hidden_states[0,:10,:3]: {hidden_states[0,:10,:3]}")
         hidden_states = hidden_states.flatten(2, 3)
         hidden_states = hidden_states.to(query.dtype)
@@ -407,14 +402,8 @@ class HunyuanVideoTokenRefiner(nn.Module):
         pooled_projections = pooled_projections.to(original_dtype)
 
         temb = self.time_text_embed(timestep, pooled_projections)
-        print(f" temb(time_text_embed).shape: {temb.shape}, {temb[0,:10]}")
         hidden_states = self.proj_in(hidden_states)
-        print(f" hidden_states: {hidden_states.shape}, {hidden_states[0,:3,:3]}")
-        print(f" temb: {temb.shape}, {temb[0,:10]}")
-        print(f" attention_mask: {attention_mask.shape}, {attention_mask[0,:3]}, {attention_mask.abs().sum()}")
-        print(f" -> token_refiner")
         hidden_states = self.token_refiner(hidden_states, temb, attention_mask)
-        print(f" hidden_states(token_refiner) {hidden_states.shape}, {hidden_states[0,:3,:3]}")
 
         return hidden_states
 
@@ -537,11 +526,9 @@ class HunyuanVideoTransformerBlock(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # 1. Input normalization
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
-        print(f" norm_hidden_states(norm1).shape: {norm_hidden_states.shape}, {norm_hidden_states[0,:10,:3]}")
         norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
             encoder_hidden_states, emb=temb
         )
-        print(f" norm_encoder_hidden_states(norm1_context).shape: {norm_encoder_hidden_states.shape}, {norm_encoder_hidden_states[0,:10,:3]}")
 
         # 2. Joint attention
         attn_output, context_attn_output = self.attn(
@@ -550,8 +537,6 @@ class HunyuanVideoTransformerBlock(nn.Module):
             attention_mask=attention_mask,
             image_rotary_emb=freqs_cis,
         )
-        print(f" attn_output.shape: {attn_output.shape}, {attn_output[0,:10,:3]}")
-        print(f" context_attn_output.shape: {context_attn_output.shape}, {context_attn_output[0,:10,:3]}")
 
         # 3. Modulation and residual connection
@@ -570,8 +555,6 @@ class HunyuanVideoTransformerBlock(nn.Module):
 
         hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
         encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
-        print(f" hidden_states(ff): {hidden_states.shape}, {hidden_states[0,:10,:3]}")
-        print(f" encoder_hidden_states(ff): {encoder_hidden_states.shape}, {encoder_hidden_states[0,:10,:3]}")
 
         return hidden_states, encoder_hidden_states
 
@@ -791,31 +774,23 @@ class HunyuanVideo15Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin
         hidden_states = self.x_embedder(hidden_states)
 
         # qwen text embedding
-        print(f" encoder_hidden_states(qwen).shape: {encoder_hidden_states.shape}, {encoder_hidden_states[0,:10,:3]}")
-        print(f" timestep: {timestep}, {timestep[:10]}")
-        print(f" encoder_attention_mask: {encoder_attention_mask.shape}, {encoder_attention_mask[0,:10]}, {encoder_attention_mask.abs().sum()}")
         encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
-        print(f" encoder_hidden_states(token_refiner).shape: {encoder_hidden_states.shape}, {encoder_hidden_states[0,:10,:3]}")
         encoder_hidden_states_cond_emb = self.cond_type_embed(
             torch.zeros_like(encoder_hidden_states[:, :, 0], dtype=torch.long)
         )
         encoder_hidden_states = encoder_hidden_states + encoder_hidden_states_cond_emb
-        print(f" encoder_hidden_states(+ cond_emb).shape: {encoder_hidden_states.shape}, {encoder_hidden_states[0,:10,:3]}")
 
         # byt5 text embedding
         encoder_hidden_states_2 = self.context_embedder_2(encoder_hidden_states_2)
-        print(f" encoder_hidden_states_2(byt5).shape: {encoder_hidden_states_2.shape}, {encoder_hidden_states_2[0,:10,:3]}")
         encoder_hidden_states_2_cond_emb = self.cond_type_embed(
             torch.ones_like(encoder_hidden_states_2[:, :, 0], dtype=torch.long)
         )
         encoder_hidden_states_2 = encoder_hidden_states_2 + encoder_hidden_states_2_cond_emb
-        print(f" encoder_hidden_states_2(+ cond_emb).shape: {encoder_hidden_states_2.shape}, {encoder_hidden_states_2[0,:10,:3]}")
 
         # image embed
         encoder_hidden_states_3 = self.image_embedder(image_embeds)
-        print(f" encoder_hidden_states_3(image).shape: {encoder_hidden_states_3.shape}, {encoder_hidden_states_3[0,:10,:3]}")
         is_t2v = torch.all(image_embeds == 0)
         if is_t2v:
             encoder_hidden_states_3 = encoder_hidden_states_3 * 0.0
@@ -824,8 +799,6 @@ class HunyuanVideo15Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin
                 dtype=encoder_attention_mask.dtype,
                 device=encoder_attention_mask.device,
             )
-            print(f" encoder_hidden_states_3(image).shape: {encoder_hidden_states_3.shape}, {encoder_hidden_states_3[0,:10,:3]}")
-            print(f" encoder_attention_mask_3: {encoder_attention_mask_3.shape}, {encoder_attention_mask_3[0,:10]}, {encoder_attention_mask_3.abs().sum()}")
         else:
             encoder_attention_mask_3 = torch.ones(
                 (batch_size, encoder_hidden_states_3.shape[1]),
@@ -840,9 +813,6 @@ class HunyuanVideo15Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin
         )
         encoder_hidden_states_3 = encoder_hidden_states_3 + encoder_hidden_states_3_cond_emb
-        print(f" encoder_hidden_states_3(+ cond_emb).shape: {encoder_hidden_states_3.shape}, {encoder_hidden_states_3[0,:10,:3]}")
-
-
         # reorder and combine text tokens: combine valid tokens first, then padding
         encoder_attention_mask = encoder_attention_mask.bool()
         encoder_attention_mask_2 = encoder_attention_mask_2.bool()
@@ -891,12 +861,6 @@ class HunyuanVideo15Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin
             encoder_hidden_states = torch.stack(new_encoder_hidden_states)
             encoder_attention_mask = torch.stack(new_encoder_attention_mask)
 
-        print(f" hidden_states.shape: {hidden_states.shape}, {hidden_states[0,:3,:3]}")
-        print(f" encoder_hidden_states.shape: {encoder_hidden_states.shape}, {encoder_hidden_states[0,:10,:3]}")
-        print(f" encoder_attention_mask.shape: {encoder_attention_mask.shape}, {encoder_attention_mask[0,:10]}, {encoder_attention_mask.dtype}, {encoder_attention_mask.sum()}")
-        print(f" image_rotary_emb: {image_rotary_emb[0].shape}, {image_rotary_emb[1].shape}, {image_rotary_emb[0][:3,:10]}, {image_rotary_emb[1][:3,:10]}")
-        print(f" temb.shape: {temb.shape}, {temb[0,:10]}")
-
         # 4. Transformer blocks
         if torch.is_grad_enabled() and self.gradient_checkpointing:
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 69bb14b98e..fe84f5c7ca 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -242,6 +242,7 @@ else:
         "HunyuanVideoImageToVideoPipeline",
         "HunyuanVideoFramepackPipeline",
     ]
+    _import_structure["hunyuan_video1_5"] = ["HunyuanVideo15Pipeline"]
     _import_structure["hunyuan_image"] = ["HunyuanImagePipeline", "HunyuanImageRefinerPipeline"]
     _import_structure["kandinsky"] = [
         "KandinskyCombinedPipeline",
@@ -662,6 +663,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
             HunyuanVideoImageToVideoPipeline,
             HunyuanVideoPipeline,
         )
+        from .hunyuan_video1_5 import HunyuanVideo15Pipeline
         from .hunyuandit import HunyuanDiTPipeline
         from .i2vgen_xl import I2VGenXLPipeline
         from .kandinsky import (
diff --git a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py
index 1a2aa51792..378f557023 100644
--- a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py
+++ b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py
@@ -27,6 +27,7 @@ from .image_processor import HunyuanVideo15ImageProcessor
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import HunyuanVideo15PipelineOutput
 from ...guiders import ClassifierFreeGuidance
+from ...utils.torch_utils import randn_tensor
 
 
 if is_torch_xla_available():
@@ -225,7 +226,7 @@ class HunyuanVideo15Pipeline(DiffusionPipeline):
         self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 16
         self.video_processor = HunyuanVideo15ImageProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
         self.target_size = self.transformer.config.target_size if getattr(self, "transformer", None) else 640
-        self.vision_states_dim = self.transformer.config.vision_states_dim if getattr(self, "transformer", None) else 729
+        self.vision_states_dim = self.transformer.config.image_embed_dim if getattr(self, "transformer", None) else 1152
         # fmt: off
         self.system_message ="You are a helpful assistant. Describe the video by detailing the following aspects: \
         1. The main content and theme of the video. \
@@ -236,8 +237,9 @@ class HunyuanVideo15Pipeline(DiffusionPipeline):
         # fmt: on
         self.prompt_template_encode_start_idx = 108
         self.tokenizer_max_length = 1000
-        self.text_encoder_2_max_length = 256
+        self.tokenizer_2_max_length = 256
         self.vision_num_semantic_tokens = 729
+        self.default_aspect_ratio = (16, 9)  # (width: height)
 
     @staticmethod
@@ -282,7 +284,7 @@ class HunyuanVideo15Pipeline(DiffusionPipeline):
             prompt_embeds = text_encoder(
                 input_ids=text_input_ids,
                 attention_mask=prompt_attention_mask,
-                output_hidden_states=False,
+                output_hidden_states=True,
             ).hidden_states[-(num_hidden_layers_to_skip + 1)]
 
             prompt_embeds = prompt_embeds.to(dtype=dtype)
@@ -521,7 +523,7 @@ class HunyuanVideo15Pipeline(DiffusionPipeline):
 
         return latents
 
-    def prepare_cond_latents_and_mask(self, latents):
+    def prepare_cond_latents_and_mask(self, latents, dtype: Optional[torch.dtype], device: Optional[torch.device]):
         """
         Prepare conditional latents and mask for t2v generation.
 
@@ -535,13 +537,14 @@ class HunyuanVideo15Pipeline(DiffusionPipeline):
         cond_latents_concat = torch.zeros(
             batch, channels, frames, height, width,
-            device=latents.device,
-            dtype=latents.dtype
+            dtype=dtype,
+            device=device
         )
         mask_concat = torch.zeros(
             batch, 1, frames, height, width,
-            device=latents.device
+            dtype=dtype,
+            device=device
         )
 
         return cond_latents_concat, mask_concat
@@ -702,7 +705,7 @@ class HunyuanVideo15Pipeline(DiffusionPipeline):
         )
 
         if height is None and width is None:
-            height, width = self.video_processor.calculate_default_height_width(height, width, self.target_size)
+            height, width = self.video_processor.calculate_default_height_width(self.default_aspect_ratio[1], self.default_aspect_ratio[0], self.target_size)
 
         self._attention_kwargs = attention_kwargs
         self._current_timestep = None
@@ -761,8 +764,19 @@ class HunyuanVideo15Pipeline(DiffusionPipeline):
             generator,
             latents,
         )
-        cond_latents_concat, mask_concat = self.prepare_cond_latents_and_mask(latents)
-        vision_states = torch.zeros(batch_size, self.vision_num_semantic_tokens, self.vision_states_dim).to(latents.device)
+        cond_latents_concat, mask_concat = self.prepare_cond_latents_and_mask(latents, torch.float32, device)
+        image_embeds = torch.zeros(
+            batch_size,
+            self.vision_num_semantic_tokens,
+            self.vision_states_dim,
+            dtype=torch.float32,
+            device=device
+        )
+
+        image_embeds = image_embeds.to(self.transformer.dtype)
+        latents = latents.to(self.transformer.dtype)
+        cond_latents_concat = cond_latents_concat.to(self.transformer.dtype)
+        mask_concat = mask_concat.to(self.transformer.dtype)
 
 
         # 7. Denoising loop
@@ -817,8 +831,8 @@ class HunyuanVideo15Pipeline(DiffusionPipeline):
                    with self.transformer.cache_context(context_name):
                        # Run denoiser and store noise prediction in this batch
                        guider_state_batch.noise_pred = self.transformer(
-                           hidden_states=latents,
-                           image_embeds=vision_states,
+                           hidden_states=latent_model_input,
+                           image_embeds=image_embeds,
                            timestep=timestep,
                            attention_kwargs=self.attention_kwargs,
                            return_dict=False,
@@ -863,9 +877,7 @@ class HunyuanVideo15Pipeline(DiffusionPipeline):
         if not output_type == "latent":
             latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
-            self.vae.enable_tiling()
             video = self.vae.decode(latents, return_dict=False, generator=generator)[0]
-            self.vae.disable_tiling()
             video = self.video_processor.postprocess_video(video, output_type=output_type)
         else:
             video = latents
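For reviewers, a minimal text-to-video smoke test of the newly registered `HunyuanVideo15Pipeline` is sketched below. The checkpoint id, frame count, and output attribute are assumptions for illustration only (this diff does not pin down a published repo id or defaults); loading and export use the standard `DiffusionPipeline.from_pretrained` and `diffusers.utils.export_to_video` APIs.

```python
import torch

from diffusers import HunyuanVideo15Pipeline
from diffusers.utils import export_to_video

# NOTE: hypothetical checkpoint id, used only for illustration.
model_id = "hunyuanvideo-community/HunyuanVideo-1.5"

pipe = HunyuanVideo15Pipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Text-to-video path: the pipeline zero-fills image_embeds and the conditional
# latents/mask internally (see prepare_cond_latents_and_mask above), so only a
# prompt is needed. height/width fall back to the default 16:9 aspect ratio.
output = pipe(
    prompt="A corgi running along a beach at sunset, cinematic lighting",
    num_inference_steps=30,
    num_frames=61,  # assumed value; check the pipeline's actual default
    generator=torch.Generator(device="cuda").manual_seed(0),
)

# Assumes the output object exposes decoded frames as `.frames`, like other
# diffusers video pipelines.
export_to_video(output.frames[0], "hunyuan_video15_t2v.mp4", fps=16)
```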