From 589330595dfa82b2a8db22094427ab10978d2de7 Mon Sep 17 00:00:00 2001 From: Lukas Struppek Date: Mon, 12 Dec 2022 13:45:27 +0100 Subject: [PATCH] VersatileDiffusion: fix input processing (#1568) * fix versatile diffusion input * merge main * `make fix-copies` Co-authored-by: anton- --- .../paint_by_example/pipeline_paint_by_example.py | 3 ++- .../pipeline_stable_diffusion_image_variation.py | 3 ++- ...ipeline_versatile_diffusion_image_variation.py | 15 +++++++++++++-- 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py index 18d024a749..cae3c4febc 100644 --- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py @@ -271,7 +271,8 @@ class PaintByExamplePipeline(DiffusionPipeline): and not isinstance(image, list) ): raise ValueError( - f"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `list` but is {type(image)}" + "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" + f" {type(image)}" ) if height % 8 != 0 or width % 8 != 0: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py index 1d34280d34..71222f4afb 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py @@ -240,7 +240,8 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline): and not isinstance(image, list) ): raise ValueError( - f"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `list` but is {type(image)}" + "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" + f" {type(image)}" ) if height % 8 != 0 or width % 8 != 0: diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py index 87924fdff8..7419d2f37d 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py @@ -134,6 +134,9 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline): embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True) return embeds + if isinstance(prompt, torch.Tensor) and len(prompt.shape) == 4: + prompt = [p for p in prompt] + batch_size = len(prompt) if isinstance(prompt, list) else 1 # get prompt text embeddings @@ -212,9 +215,17 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline): extra_step_kwargs["generator"] = generator return extra_step_kwargs + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_image_variation.StableDiffusionImageVariationPipeline.check_inputs def check_inputs(self, image, height, width, callback_steps): - if not isinstance(image, PIL.Image.Image) and not isinstance(image, torch.Tensor): - raise ValueError(f"`image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(image)}") + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" + f" {type(image)}" + ) if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")