mirror of https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
address review comments
@@ -35,7 +35,7 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
 class WanI2VLoopBeforeDenoiser(PipelineBlock):
-    model_name = "stable-diffusion-xl"
+    model_name = "wan"
 
     @property
     def expected_components(self) -> List[ComponentSpec]:
@@ -72,7 +72,7 @@ class WanI2VLoopBeforeDenoiser(PipelineBlock):
     def intermediate_outputs(self) -> List[OutputParam]:
         return [
             OutputParam(
-                "concatenated_latents",
+                "latent_model_inputs",
                 type_hint=torch.Tensor,
                 description="The concatenated noisy and conditioning latents to use for the denoising process.",
             ),
@@ -80,7 +80,7 @@ class WanI2VLoopBeforeDenoiser(PipelineBlock):
 
     @torch.no_grad()
     def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: int):
-        block_state.concatenated_latents = torch.cat([block_state.latents, block_state.latent_condition], dim=1)
+        block_state.latent_model_inputs = torch.cat([block_state.latents, block_state.latent_condition], dim=1)
         return components, block_state
 
 
@@ -215,13 +215,13 @@ class WanI2VLoopDenoiser(PipelineBlock):
     def intermediate_inputs(self) -> List[str]:
         return [
             InputParam(
-                "concatenated_latents",
+                "latent_model_inputs",
                 required=True,
                 type_hint=torch.Tensor,
                 description="The initial latents to use for the denoising process.",
             ),
             InputParam(
-                "encoder_hidden_states_image",
+                "image_embeds",
                 required=True,
                 type_hint=torch.Tensor,
                 description="The encoder hidden states for the image inputs.",
@@ -272,10 +272,10 @@ class WanI2VLoopDenoiser(PipelineBlock):
                 # Predict the noise residual
                 # store the noise_pred in guider_state_batch so that we can apply guidance across all batches
                 guider_state_batch.noise_pred = components.transformer(
-                    hidden_states=block_state.concatenated_latents.to(transformer_dtype),
+                    hidden_states=block_state.latent_model_inputs.to(transformer_dtype),
                     timestep=t.flatten(),
                     encoder_hidden_states=prompt_embeds.to(transformer_dtype),
-                    encoder_hidden_states_image=block_state.encoder_hidden_states_image.to(transformer_dtype),
+                    encoder_hidden_states_image=block_state.image_embeds.to(transformer_dtype),
                     attention_kwargs=block_state.attention_kwargs,
                     return_dict=False,
                 )[0]
@@ -296,7 +296,7 @@ class WanImageEncoderStep(PipelineBlock):
     def intermediate_outputs(self) -> List[OutputParam]:
         return [
             OutputParam(
-                "encoder_hidden_states_image",
+                "image_embeds",
                 type_hint=torch.Tensor,
                 description="image embeddings used to guide the image generation",
             ),
@@ -335,7 +335,7 @@ class WanImageEncoderStep(PipelineBlock):
         if block_state.last_image is not None:
            image = [block_state.image, block_state.last_image]
 
-        block_state.encoder_hidden_states_image = self.encode_image(components, image, block_state.device)
+        block_state.image_embeds = self.encode_image(components, image, block_state.device)
 
         # Add outputs
         self.set_block_state(state, block_state)
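
Note: the renames make the intermediate names line up across blocks. WanI2VLoopBeforeDenoiser now writes latent_model_inputs and WanImageEncoderStep writes image_embeds, matching the InputParam names that WanI2VLoopDenoiser declares. Below is a minimal, self-contained sketch of that name-matching contract; the SimpleNamespace stand-in for BlockState and the tensor shapes are illustrative assumptions, not the diffusers API.

    # Sketch of the intermediate-name contract the rename enforces.
    # SimpleNamespace stands in for diffusers' BlockState; shapes are made up.
    from types import SimpleNamespace

    import torch

    state = SimpleNamespace()

    # WanI2VLoopBeforeDenoiser: concatenate the noisy latents with the
    # conditioning latents along the channel dim, exposed as latent_model_inputs.
    state.latents = torch.randn(1, 16, 4, 8, 8)
    state.latent_condition = torch.randn(1, 20, 4, 8, 8)
    state.latent_model_inputs = torch.cat([state.latents, state.latent_condition], dim=1)

    # WanImageEncoderStep: expose the image embeddings as image_embeds.
    state.image_embeds = torch.randn(1, 257, 1280)

    # WanI2VLoopDenoiser: reads exactly the names the upstream blocks wrote;
    # a mismatch between OutputParam and InputParam names is what the rename fixes.
    for name in ("latent_model_inputs", "image_embeds"):
        assert hasattr(state, name), f"missing intermediate: {name}"

    print(state.latent_model_inputs.shape, state.image_embeds.shape)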