From 2ff808d26c97376e2a91c69d4ebbae9383a9473e Mon Sep 17 00:00:00 2001
From: Aryan
Date: Fri, 1 Aug 2025 06:49:06 +0200
Subject: [PATCH] address review comments

---
 src/diffusers/modular_pipelines/wan/denoise.py  | 14 +++++++-------
 src/diffusers/modular_pipelines/wan/encoders.py |  4 ++--
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/src/diffusers/modular_pipelines/wan/denoise.py b/src/diffusers/modular_pipelines/wan/denoise.py
index 2acbe8c8a8..afddb9a1ac 100644
--- a/src/diffusers/modular_pipelines/wan/denoise.py
+++ b/src/diffusers/modular_pipelines/wan/denoise.py
@@ -35,7 +35,7 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
 class WanI2VLoopBeforeDenoiser(PipelineBlock):
-    model_name = "stable-diffusion-xl"
+    model_name = "wan"
 
     @property
     def expected_components(self) -> List[ComponentSpec]:
@@ -72,7 +72,7 @@ def intermediate_outputs(self) -> List[OutputParam]:
         return [
             OutputParam(
-                "concatenated_latents",
+                "latent_model_inputs",
                 type_hint=torch.Tensor,
                 description="The concatenated noisy and conditioning latents to use for the denoising process.",
             ),
         ]
@@ -80,7 +80,7 @@
 
     @torch.no_grad()
     def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: int):
-        block_state.concatenated_latents = torch.cat([block_state.latents, block_state.latent_condition], dim=1)
+        block_state.latent_model_inputs = torch.cat([block_state.latents, block_state.latent_condition], dim=1)
 
         return components, block_state
 
@@ -215,13 +215,13 @@ def intermediate_inputs(self) -> List[str]:
         return [
             InputParam(
-                "concatenated_latents",
+                "latent_model_inputs",
                 required=True,
                 type_hint=torch.Tensor,
                 description="The initial latents to use for the denoising process.",
             ),
             InputParam(
-                "encoder_hidden_states_image",
+                "image_embeds",
                 required=True,
                 type_hint=torch.Tensor,
                 description="The encoder hidden states for the image inputs.",
             ),
@@ -272,10 +272,10 @@
             # Predict the noise residual
             # store the noise_pred in guider_state_batch so that we can apply guidance across all batches
             guider_state_batch.noise_pred = components.transformer(
-                hidden_states=block_state.concatenated_latents.to(transformer_dtype),
+                hidden_states=block_state.latent_model_inputs.to(transformer_dtype),
                 timestep=t.flatten(),
                 encoder_hidden_states=prompt_embeds.to(transformer_dtype),
-                encoder_hidden_states_image=block_state.encoder_hidden_states_image.to(transformer_dtype),
+                encoder_hidden_states_image=block_state.image_embeds.to(transformer_dtype),
                 attention_kwargs=block_state.attention_kwargs,
                 return_dict=False,
             )[0]
diff --git a/src/diffusers/modular_pipelines/wan/encoders.py b/src/diffusers/modular_pipelines/wan/encoders.py
index d25e96414d..425fb7c748 100644
--- a/src/diffusers/modular_pipelines/wan/encoders.py
+++ b/src/diffusers/modular_pipelines/wan/encoders.py
@@ -296,7 +296,7 @@
     def intermediate_outputs(self) -> List[OutputParam]:
         return [
             OutputParam(
-                "encoder_hidden_states_image",
+                "image_embeds",
                 type_hint=torch.Tensor,
                 description="image embeddings used to guide the image generation",
             ),
@@ -335,7 +335,7 @@
         if block_state.last_image is not None:
             image = [block_state.image, block_state.last_image]
 
-        block_state.encoder_hidden_states_image = self.encode_image(components, image, block_state.device)
+        block_state.image_embeds = self.encode_image(components, image, block_state.device)
 
         # Add outputs
         self.set_block_state(state, block_state)