
address review comments
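
Renames the Wan modular pipeline intermediates: concatenated_latents becomes
latent_model_inputs and encoder_hidden_states_image becomes image_embeds, keeping
the producer and consumer blocks in sync, and fixes the copy-pasted
model_name = "stable-diffusion-xl" on WanI2VLoopBeforeDenoiser to "wan".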

Author: Aryan
Date:   2025-08-01 06:49:06 +02:00
parent 22f3273a82
commit 2ff808d26c
2 changed files with 9 additions and 9 deletions


@@ -35,7 +35,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 class WanI2VLoopBeforeDenoiser(PipelineBlock):
-    model_name = "stable-diffusion-xl"
+    model_name = "wan"

     @property
     def expected_components(self) -> List[ComponentSpec]:
@@ -72,7 +72,7 @@ class WanI2VLoopBeforeDenoiser(PipelineBlock):
     def intermediate_outputs(self) -> List[OutputParam]:
         return [
             OutputParam(
-                "concatenated_latents",
+                "latent_model_inputs",
                 type_hint=torch.Tensor,
                 description="The concatenated noisy and conditioning latents to use for the denoising process.",
             ),
@@ -80,7 +80,7 @@ class WanI2VLoopBeforeDenoiser(PipelineBlock):
     @torch.no_grad()
     def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: int):
-        block_state.concatenated_latents = torch.cat([block_state.latents, block_state.latent_condition], dim=1)
+        block_state.latent_model_inputs = torch.cat([block_state.latents, block_state.latent_condition], dim=1)
         return components, block_state
@@ -215,13 +215,13 @@ class WanI2VLoopDenoiser(PipelineBlock):
     def intermediate_inputs(self) -> List[str]:
         return [
             InputParam(
-                "concatenated_latents",
+                "latent_model_inputs",
                 required=True,
                 type_hint=torch.Tensor,
                 description="The initial latents to use for the denoising process.",
             ),
             InputParam(
-                "encoder_hidden_states_image",
+                "image_embeds",
                 required=True,
                 type_hint=torch.Tensor,
                 description="The encoder hidden states for the image inputs.",
@@ -272,10 +272,10 @@ class WanI2VLoopDenoiser(PipelineBlock):
             # Predict the noise residual
             # store the noise_pred in guider_state_batch so that we can apply guidance across all batches
             guider_state_batch.noise_pred = components.transformer(
-                hidden_states=block_state.concatenated_latents.to(transformer_dtype),
+                hidden_states=block_state.latent_model_inputs.to(transformer_dtype),
                 timestep=t.flatten(),
                 encoder_hidden_states=prompt_embeds.to(transformer_dtype),
-                encoder_hidden_states_image=block_state.encoder_hidden_states_image.to(transformer_dtype),
+                encoder_hidden_states_image=block_state.image_embeds.to(transformer_dtype),
                 attention_kwargs=block_state.attention_kwargs,
                 return_dict=False,
             )[0]
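
For reference, a minimal sketch of the data flow these renames keep consistent
inside the denoise loop; the SimpleNamespace stand-in and the latent shapes are
illustrative assumptions, only the key names and the torch.cat call come from
the diff:

import torch
from types import SimpleNamespace

# Stand-in for the shared block state; shapes are hypothetical Wan-style
# video latents laid out as (batch, channels, frames, height, width).
block_state = SimpleNamespace(
    latents=torch.randn(1, 16, 21, 60, 104),           # noisy latents being denoised
    latent_condition=torch.randn(1, 20, 21, 60, 104),  # image-conditioning latents
)

# WanI2VLoopBeforeDenoiser: concatenate along the channel dimension and store
# the result under the renamed key ...
block_state.latent_model_inputs = torch.cat(
    [block_state.latents, block_state.latent_condition], dim=1
)

# ... which WanI2VLoopDenoiser later reads back and feeds to the transformer
# as hidden_states=block_state.latent_model_inputs.to(transformer_dtype).
assert block_state.latent_model_inputs.shape == (1, 36, 21, 60, 104)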


@@ -296,7 +296,7 @@ class WanImageEncoderStep(PipelineBlock):
     def intermediate_outputs(self) -> List[OutputParam]:
         return [
             OutputParam(
-                "encoder_hidden_states_image",
+                "image_embeds",
                 type_hint=torch.Tensor,
                 description="image embeddings used to guide the image generation",
             ),
@@ -335,7 +335,7 @@ class WanImageEncoderStep(PipelineBlock):
         if block_state.last_image is not None:
             image = [block_state.image, block_state.last_image]
-        block_state.encoder_hidden_states_image = self.encode_image(components, image, block_state.device)
+        block_state.image_embeds = self.encode_image(components, image, block_state.device)

         # Add outputs
         self.set_block_state(state, block_state)
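
Similarly for the image pathway, a hedged sketch of how the renamed image_embeds
key links the encoder step to the denoiser; the embedding shape is a hypothetical
CLIP-style value, not taken from the diff:

import torch
from types import SimpleNamespace

block_state = SimpleNamespace()

# WanImageEncoderStep: the encoded image is stored under the renamed key.
# The (1, 257, 1280) shape is only an illustrative CLIP-style embedding.
block_state.image_embeds = torch.randn(1, 257, 1280)

# WanI2VLoopDenoiser: the transformer consumes the same key as its
# encoder_hidden_states_image argument, so the rename stays consistent
# end to end.
encoder_hidden_states_image = block_state.image_embeds.to(torch.bfloat16)
print(encoder_hidden_states_image.shape)  # torch.Size([1, 257, 1280])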