1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[StableDiffusionimg2img] validating input type (#1913)

* [StableDiffusionimg2img] validating input type

* fixing tests

* running make style

* make fix-copies
This commit is contained in:
Shubhamai
2023-01-05 03:01:00 -08:00
committed by GitHub
parent f6f4176294
commit 7101c7316b
3 changed files with 15 additions and 0 deletions

View File

@@ -424,6 +424,11 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
return timesteps, num_inference_steps - t_start
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
raise ValueError(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
)
image = image.to(device=device, dtype=dtype)
batch_size = batch_size * num_images_per_prompt

View File

@@ -343,6 +343,11 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
raise ValueError(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
)
image = image.to(device=device, dtype=dtype)
batch_size = batch_size * num_images_per_prompt

View File

@@ -434,6 +434,11 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
return timesteps, num_inference_steps - t_start
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
raise ValueError(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
)
image = image.to(device=device, dtype=dtype)
batch_size = batch_size * num_images_per_prompt