1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

WanI2V encode_image (#11164)

* WanI2V encode_image
This commit is contained in:
hlky
2025-03-28 13:35:34 +01:00
committed by GitHub
parent de6a88c2d7
commit 5d970a4aa9

View File

@@ -220,8 +220,13 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
return prompt_embeds
def encode_image(self, image: PipelineImageInput):
image = self.image_processor(images=image, return_tensors="pt").to(self.device)
def encode_image(
self,
image: PipelineImageInput,
device: Optional[torch.device] = None,
):
device = device or self._execution_device
image = self.image_processor(images=image, return_tensors="pt").to(device)
image_embeds = self.image_encoder(**image, output_hidden_states=True)
return image_embeds.hidden_states[-2]
@@ -587,7 +592,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
if negative_prompt_embeds is not None:
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
image_embeds = self.encode_image(image)
image_embeds = self.encode_image(image, device)
image_embeds = image_embeds.repeat(batch_size, 1, 1)
image_embeds = image_embeds.to(transformer_dtype)