diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 0f309625ae..9db646af09 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -28,6 +28,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
     def __call__(
         self,
         prompt: Union[str, List[str]],
+        height: Optional[int] = 512,
+        width: Optional[int] = 512,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 1.0,
         eta: Optional[float] = 0.0,
@@ -45,6 +47,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
         else:
             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
         self.unet.to(torch_device)
         self.vae.to(torch_device)
         self.text_encoder.to(torch_device)
@@ -72,7 +77,7 @@ class StableDiffusionPipeline(DiffusionPipeline):

         # get the intial random noise
         latents = torch.randn(
-            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+            (batch_size, self.unet.in_channels, height // 8, width // 8),
             generator=generator,
         )
         latents = latents.to(torch_device)
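
For context, a minimal usage sketch of the new arguments written against this version of the pipeline; the checkpoint id, seeding, and output handling below are illustrative assumptions and not part of this diff (the `generator` keyword is only shown being consumed internally in the hunks above).

import torch
from diffusers import StableDiffusionPipeline

# Illustrative checkpoint id; substitute whichever Stable Diffusion weights you use.
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

# A non-square resolution; both values are divisible by 8, so the new check passes
# and the initial latents get shape (batch_size, unet.in_channels, 512 // 8, 768 // 8).
output = pipe(
    "a photograph of an astronaut riding a horse",
    height=512,
    width=768,
    num_inference_steps=50,
    generator=torch.manual_seed(0),
)

Both dimensions are divided by 8 because the VAE downsamples spatially by that factor, so the UNet denoises latents of shape (batch_size, in_channels, height // 8, width // 8) rather than full-resolution images; values not divisible by 8 would not round-trip through the decoder, which is what the new ValueError guards against.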