diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 246f2b8720..f9458b6bf8 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -390,8 +390,8 @@ class AltDiffusionPipeline(DiffusionPipeline): def __call__( self, prompt: Union[str, List[str]], - height: int = 512, - width: int = 512, + height: Optional[int] = None, + width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -411,9 +411,9 @@ class AltDiffusionPipeline(DiffusionPipeline): Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to 512): + height (`int`, *optional*, defaults to self.unet.config.sample_size * 8): The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): + width (`int`, *optional*, defaults to self.unet.config.sample_size * 8): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -459,6 +459,9 @@ class AltDiffusionPipeline(DiffusionPipeline): list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * 8 + width = width or self.unet.config.sample_size * 8 # 1. Check inputs. Raise error if not correct self.check_inputs(prompt, height, width, callback_steps) diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py index feb5b00d74..a52f19ca31 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -65,8 +65,8 @@ class LDMTextToImagePipeline(DiffusionPipeline): def __call__( self, prompt: Union[str, List[str]], - height: Optional[int] = 256, - width: Optional[int] = 256, + height: Optional[int] = None, + width: Optional[int] = None, num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 1.0, eta: Optional[float] = 0.0, @@ -79,9 +79,9 @@ class LDMTextToImagePipeline(DiffusionPipeline): Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to 256): + height (`int`, *optional*, defaults to self.unet.config.sample_size * 8): The height in pixels of the generated image. - width (`int`, *optional*, defaults to 256): + width (`int`, *optional*, defaults to self.unet.config.sample_size * 8): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -106,6 +106,9 @@ class LDMTextToImagePipeline(DiffusionPipeline): `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * 8 + width = width or self.unet.config.sample_size * 8 if isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index 9c668d5e51..8c2ed06e49 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -160,13 +160,17 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): params: Union[Dict, FrozenDict], prng_seed: jax.random.PRNGKey, num_inference_steps: int = 50, - height: int = 512, - width: int = 512, + height: Optional[int] = None, + width: Optional[int] = None, guidance_scale: float = 7.5, latents: Optional[jnp.array] = None, debug: bool = False, neg_prompt_ids: jnp.array = None, ): + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * 8 + width = width or self.unet.config.sample_size * 8 + if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -249,8 +253,8 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): params: Union[Dict, FrozenDict], prng_seed: jax.random.PRNGKey, num_inference_steps: int = 50, - height: int = 512, - width: int = 512, + height: Optional[int] = None, + width: Optional[int] = None, guidance_scale: float = 7.5, latents: jnp.array = None, return_dict: bool = True, @@ -265,9 +269,9 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to 512): + height (`int`, *optional*, defaults to self.unet.config.sample_size * 8): The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): + width (`int`, *optional*, defaults to self.unet.config.sample_size * 8): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -302,6 +306,10 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * 8 + width = width or self.unet.config.sample_size * 8 + if jit: images = _p_generate( self, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py index 9830ace6a1..f64bd4340b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py @@ -172,8 +172,8 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline): def __call__( self, prompt: Union[str, List[str]], - height: Optional[int] = 512, - width: Optional[int] = 512, + height: Optional[int] = None, + width: Optional[int] = None, num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -187,6 +187,10 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline): callback_steps: Optional[int] = 1, **kwargs, ): + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * 8 + width = width or self.unet.config.sample_size * 8 + if isinstance(prompt, str): batch_size = 1 elif isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py index b933c52bf6..bbb193a767 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py @@ -236,8 +236,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): prompt: Union[str, List[str]], image: PIL.Image.Image, mask_image: PIL.Image.Image, - height: int = 512, - width: int = 512, + height: Optional[int] = None, + width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -265,9 +265,9 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. - height (`int`, *optional*, defaults to 512): + height (`int`, *optional*, defaults to self.unet.config.sample_size * 8): The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): + width (`int`, *optional*, defaults to self.unet.config.sample_size * 8): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -312,6 +312,10 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * 8 + width = width or self.unet.config.sample_size * 8 + if isinstance(prompt, str): batch_size = 1 elif isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index fbfac6b5a0..5f3cc41b65 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -389,8 +389,8 @@ class StableDiffusionPipeline(DiffusionPipeline): def __call__( self, prompt: Union[str, List[str]], - height: int = 512, - width: int = 512, + height: Optional[int] = None, + width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -410,9 +410,9 @@ class StableDiffusionPipeline(DiffusionPipeline): Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to 512): + height (`int`, *optional*, defaults to self.unet.config.sample_size * 8): The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): + width (`int`, *optional*, defaults to self.unet.config.sample_size * 8): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -458,6 +458,9 @@ class StableDiffusionPipeline(DiffusionPipeline): list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * 8 + width = width or self.unet.config.sample_size * 8 # 1. Check inputs. Raise error if not correct self.check_inputs(prompt, height, width, callback_steps) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py index 4cfa5817af..f23161c51c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py @@ -292,8 +292,8 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline): def __call__( self, image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor], - height: int = 512, - width: int = 512, + height: Optional[int] = None, + width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, num_images_per_prompt: Optional[int] = 1, @@ -315,9 +315,9 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline): configuration of [this](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json) `CLIPFeatureExtractor` - height (`int`, *optional*, defaults to 512): + height (`int`, *optional*, defaults to self.unet.config.sample_size * 8): The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): + width (`int`, *optional*, defaults to self.unet.config.sample_size * 8): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -360,6 +360,9 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline): list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * 8 + width = width or self.unet.config.sample_size * 8 # 1. Check inputs. Raise error if not correct self.check_inputs(image, height, width, callback_steps) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 9eb8de2482..544f758398 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -509,8 +509,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): prompt: Union[str, List[str]], image: Union[torch.FloatTensor, PIL.Image.Image], mask_image: Union[torch.FloatTensor, PIL.Image.Image], - height: int = 512, - width: int = 512, + height: Optional[int] = None, + width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -538,9 +538,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. - height (`int`, *optional*, defaults to 512): + height (`int`, *optional*, defaults to self.unet.config.sample_size * 8): The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): + width (`int`, *optional*, defaults to self.unet.config.sample_size * 8): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -586,6 +586,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * 8 + width = width or self.unet.config.sample_size * 8 # 1. Check inputs self.check_inputs(prompt, height, width, callback_steps) diff --git a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py index cfa71b9242..1be95418d3 100644 --- a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +++ b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py @@ -495,8 +495,8 @@ class StableDiffusionPipelineSafe(DiffusionPipeline): def __call__( self, prompt: Union[str, List[str]], - height: int = 512, - width: int = 512, + height: Optional[int] = None, + width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -521,9 +521,9 @@ class StableDiffusionPipelineSafe(DiffusionPipeline): Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to 512): + height (`int`, *optional*, defaults to self.unet.config.sample_size * 8): The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): + width (`int`, *optional*, defaults to self.unet.config.sample_size * 8): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -589,6 +589,9 @@ class StableDiffusionPipelineSafe(DiffusionPipeline): list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * 8 + width = width or self.unet.config.sample_size * 8 # 1. Check inputs. Raise error if not correct self.check_inputs(prompt, height, width, callback_steps) diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py index 1280419c34..c60a8836ec 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py @@ -111,8 +111,8 @@ class VersatileDiffusionPipeline(DiffusionPipeline): def image_variation( self, image: Union[torch.FloatTensor, PIL.Image.Image], - height: int = 512, - width: int = 512, + height: Optional[int] = None, + width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -131,9 +131,9 @@ class VersatileDiffusionPipeline(DiffusionPipeline): Args: image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`): The image prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to 512): + height (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8): The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): + width (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -193,7 +193,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline): >>> pipe = pipe.to("cuda") >>> generator = torch.Generator(device="cuda").manual_seed(0) - >>> image = pipe(image, generator=generator).images[0] + >>> image = pipe.image_variation(image, generator=generator).images[0] >>> image.save("./car_variation.png") ``` @@ -227,8 +227,8 @@ class VersatileDiffusionPipeline(DiffusionPipeline): def text_to_image( self, prompt: Union[str, List[str]], - height: int = 512, - width: int = 512, + height: Optional[int] = None, + width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -247,9 +247,9 @@ class VersatileDiffusionPipeline(DiffusionPipeline): Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to 512): + height (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8): The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): + width (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -341,8 +341,8 @@ class VersatileDiffusionPipeline(DiffusionPipeline): prompt: Union[PIL.Image.Image, List[PIL.Image.Image]], image: Union[str, List[str]], text_to_image_strength: float = 0.5, - height: int = 512, - width: int = 512, + height: Optional[int] = None, + width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, num_images_per_prompt: Optional[int] = 1, @@ -360,9 +360,9 @@ class VersatileDiffusionPipeline(DiffusionPipeline): Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to 512): + height (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8): The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): + width (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py index ad4e8b0d0a..e5dc59389f 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py @@ -454,8 +454,8 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline): prompt: Union[PIL.Image.Image, List[PIL.Image.Image]], image: Union[str, List[str]], text_to_image_strength: float = 0.5, - height: int = 512, - width: int = 512, + height: Optional[int] = None, + width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, num_images_per_prompt: Optional[int] = 1, @@ -474,9 +474,9 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline): Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to 512): + height (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8): The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): + width (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -551,6 +551,9 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline): [`~pipelines.stable_diffusion.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. """ + # 0. Default height and width to unet + height = height or self.image_unet.config.sample_size * 8 + width = width or self.image_unet.config.sample_size * 8 # 1. Check inputs. Raise error if not correct self.check_inputs(prompt, image, height, width, callback_steps) diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py index 652b7b735a..53c67d8c2e 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py @@ -297,8 +297,8 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline): def __call__( self, image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor], - height: int = 512, - width: int = 512, + height: Optional[int] = None, + width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -318,9 +318,9 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline): Args: image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`): The image prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to 512): + height (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8): The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): + width (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -391,6 +391,9 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline): list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ + # 0. Default height and width to unet + height = height or self.image_unet.config.sample_size * 8 + width = width or self.image_unet.config.sample_size * 8 # 1. Check inputs. Raise error if not correct self.check_inputs(image, height, width, callback_steps) diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py index d07d734a64..f38b604bd6 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py @@ -357,8 +357,8 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): def __call__( self, prompt: Union[str, List[str]], - height: int = 512, - width: int = 512, + height: Optional[int] = None, + width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -378,9 +378,9 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to 512): + height (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8): The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): + width (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -443,6 +443,9 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ + # 0. Default height and width to unet + height = height or self.image_unet.config.sample_size * 8 + width = width or self.image_unet.config.sample_size * 8 # 1. Check inputs. Raise error if not correct self.check_inputs(prompt, height, width, callback_steps) diff --git a/tests/pipelines/altdiffusion/test_alt_diffusion.py b/tests/pipelines/altdiffusion/test_alt_diffusion.py index b743d100ce..bcbadfb3db 100644 --- a/tests/pipelines/altdiffusion/test_alt_diffusion.py +++ b/tests/pipelines/altdiffusion/test_alt_diffusion.py @@ -171,10 +171,8 @@ class AltDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): image_slice = image[0, -3:, -3:, -1] image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] - assert image.shape == (1, 128, 128, 3) - expected_slice = np.array( - [0.49249017, 0.46064827, 0.4790093, 0.50883967, 0.4811985, 0.51540506, 0.5084924, 0.4860553, 0.47318557] - ) + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.5748162, 0.60447145, 0.48821217, 0.50100636, 0.5431185, 0.45763683, 0.49657696, 0.48132733, 0.47573093]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 @@ -220,10 +218,8 @@ class AltDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): image_slice = image[0, -3:, -3:, -1] image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] - assert image.shape == (1, 128, 128, 3) - expected_slice = np.array( - [0.4786532, 0.45791715, 0.47507674, 0.50763345, 0.48375353, 0.515062, 0.51244247, 0.48673993, 0.47105807] - ) + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.51605093, 0.5707241, 0.47365507, 0.50578886, 0.5633877, 0.4642503, 0.5182081, 0.48763484, 0.49084237]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 @@ -259,7 +255,7 @@ class AltDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): generator = torch.Generator(device=torch_device).manual_seed(0) image = alt_pipe([prompt], generator=generator, num_inference_steps=2, output_type="np").images - assert image.shape == (1, 128, 128, 3) + assert image.shape == (1, 64, 64, 3) @slow diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 17a293e605..53f8024f6c 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -207,9 +207,10 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): )[0] image_slice = image[0, -3:, -3:, -1] + print(", ".join(image_slice.flatten().tolist())) image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] - assert image.shape == (1, 128, 128, 3) + assert image.shape == (1, 64, 64, 3) expected_slice = np.array([0.5112, 0.4692, 0.4715, 0.5206, 0.4894, 0.5114, 0.5096, 0.4932, 0.4755]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @@ -302,9 +303,10 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): )[0] image_slice = image[0, -3:, -3:, -1] + print(", ".join(image_slice.flatten().tolist())) image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] - assert image.shape == (1, 128, 128, 3) + assert image.shape == (1, 64, 64, 3) expected_slice = np.array([0.4937, 0.4649, 0.4716, 0.5145, 0.4889, 0.513, 0.513, 0.4905, 0.4738]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 @@ -368,9 +370,10 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): )[0] image_slice = image[0, -3:, -3:, -1] + print(", ".join(image_slice.flatten().tolist())) image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] - assert image.shape == (1, 128, 128, 3) + assert image.shape == (1, 64, 64, 3) expected_slice = np.array([0.5067, 0.4689, 0.4614, 0.5233, 0.4903, 0.5112, 0.524, 0.5069, 0.4785]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 @@ -413,9 +416,10 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): )[0] image_slice = image[0, -3:, -3:, -1] + print(", ".join(image_slice.flatten().tolist())) image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] - assert image.shape == (1, 128, 128, 3) + assert image.shape == (1, 64, 64, 3) expected_slice = np.array([0.5067, 0.4689, 0.4614, 0.5233, 0.4903, 0.5112, 0.524, 0.5069, 0.4785]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 @@ -458,9 +462,10 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): )[0] image_slice = image[0, -3:, -3:, -1] + print(", ".join(image_slice.flatten().tolist())) image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] - assert image.shape == (1, 128, 128, 3) + assert image.shape == (1, 64, 64, 3) expected_slice = np.array([0.5067, 0.4689, 0.4614, 0.5233, 0.4903, 0.5112, 0.524, 0.5069, 0.4785]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 @@ -533,7 +538,7 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): image = output.images image_slice = image[0, -3:, -3:, -1] - assert image.shape == (1, 128, 128, 3) + assert image.shape == (1, 64, 64, 3) expected_slice = np.array([0.4851, 0.4617, 0.4765, 0.5127, 0.4845, 0.5153, 0.5141, 0.4886, 0.4719]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @@ -563,13 +568,13 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): # test num_images_per_prompt=1 (default) images = sd_pipe(prompt, num_inference_steps=2, output_type="np").images - assert images.shape == (1, 128, 128, 3) + assert images.shape == (1, 64, 64, 3) # test num_images_per_prompt=1 (default) for batch of prompts batch_size = 2 images = sd_pipe([prompt] * batch_size, num_inference_steps=2, output_type="np").images - assert images.shape == (batch_size, 128, 128, 3) + assert images.shape == (batch_size, 64, 64, 3) # test num_images_per_prompt for single prompt num_images_per_prompt = 2 @@ -577,7 +582,7 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): prompt, num_inference_steps=2, output_type="np", num_images_per_prompt=num_images_per_prompt ).images - assert images.shape == (num_images_per_prompt, 128, 128, 3) + assert images.shape == (num_images_per_prompt, 64, 64, 3) # test num_images_per_prompt for batch of prompts batch_size = 2 @@ -585,7 +590,7 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): [prompt] * batch_size, num_inference_steps=2, output_type="np", num_images_per_prompt=num_images_per_prompt ).images - assert images.shape == (batch_size * num_images_per_prompt, 128, 128, 3) + assert images.shape == (batch_size * num_images_per_prompt, 64, 64, 3) @unittest.skipIf(torch_device != "cuda", "This test requires a GPU") def test_stable_diffusion_fp16(self): @@ -618,7 +623,7 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): generator = torch.Generator(device=torch_device).manual_seed(0) image = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="np").images - assert image.shape == (1, 128, 128, 3) + assert image.shape == (1, 64, 64, 3) def test_stable_diffusion_long_prompt(self): unet = self.dummy_cond_unet @@ -671,6 +676,43 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): assert cap_logger.out.count("@") == 25 assert cap_logger_3.out == "" + def test_stable_diffusion_height_width_opt(self): + unet = self.dummy_cond_unet + scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "hey" + + output = sd_pipe(prompt, number_of_steps=2, output_type="np") + image_shape = output.images[0].shape[:2] + assert image_shape == [32, 32] + + output = sd_pipe(prompt, number_of_steps=2, height=64, width=64, output_type="np") + image_shape = output.images[0].shape[:2] + assert image_shape == [64, 64] + + config = dict(sd_pipe.unet.config) + config["sample_size"] = 96 + sd_pipe.unet = UNet2DConditionModel.from_config(config) + output = sd_pipe(prompt, number_of_steps=2, output_type="np") + image_shape = output.images[0].shape[:2] + assert image_shape == [96, 96] + @slow @require_torch_gpu diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py index 2935275d0f..a992308922 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py @@ -157,7 +157,7 @@ class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unitte print(image_slice.flatten()) image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] - assert image.shape == (1, 128, 128, 3) + assert image.shape == (1, 64, 64, 3) expected_slice = np.array([0.4935, 0.4784, 0.4802, 0.5027, 0.4805, 0.5149, 0.5143, 0.4879, 0.4731]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-3 @@ -196,7 +196,7 @@ class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unitte image_slice = image[-1, -3:, -3:, -1] - assert image.shape == (2, 128, 128, 3) + assert image.shape == (2, 64, 64, 3) expected_slice = np.array([0.4939, 0.4627, 0.4831, 0.5710, 0.5387, 0.4428, 0.5230, 0.5545, 0.4586]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 @@ -228,7 +228,7 @@ class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unitte output_type="np", ).images - assert images.shape == (1, 128, 128, 3) + assert images.shape == (1, 64, 64, 3) # test num_images_per_prompt=1 (default) for batch of images batch_size = 2 @@ -238,7 +238,7 @@ class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unitte output_type="np", ).images - assert images.shape == (batch_size, 128, 128, 3) + assert images.shape == (batch_size, 64, 64, 3) # test num_images_per_prompt for single prompt num_images_per_prompt = 2 @@ -249,7 +249,7 @@ class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unitte num_images_per_prompt=num_images_per_prompt, ).images - assert images.shape == (num_images_per_prompt, 128, 128, 3) + assert images.shape == (num_images_per_prompt, 64, 64, 3) # test num_images_per_prompt for batch of prompts batch_size = 2 @@ -260,7 +260,7 @@ class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unitte num_images_per_prompt=num_images_per_prompt, ).images - assert images.shape == (batch_size * num_images_per_prompt, 128, 128, 3) + assert images.shape == (batch_size * num_images_per_prompt, 64, 64, 3) @unittest.skipIf(torch_device != "cuda", "This test requires a GPU") def test_stable_diffusion_img_variation_fp16(self): @@ -297,7 +297,7 @@ class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unitte output_type="np", ).images - assert image.shape == (1, 128, 128, 3) + assert image.shape == (1, 64, 64, 3) @slow diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index 2f9348c5b5..0d8abe5394 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -212,7 +212,7 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test image_slice = image[0, -3:, -3:, -1] image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] - assert image.shape == (1, 128, 128, 3) + assert image.shape == (1, 64, 64, 3) expected_slice = np.array([0.5075, 0.4485, 0.4558, 0.5369, 0.5369, 0.5236, 0.5127, 0.4983, 0.4776]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 @@ -300,7 +300,7 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test mask_image=mask_image, ).images - assert image.shape == (1, 128, 128, 3) + assert image.shape == (1, 64, 64, 3) @slow diff --git a/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py b/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py index dcb3f27303..d80275c257 100644 --- a/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py +++ b/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py @@ -155,7 +155,7 @@ class SafeDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): image_slice = image[0, -3:, -3:, -1] image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] - assert image.shape == (1, 128, 128, 3) + assert image.shape == (1, 64, 64, 3) expected_slice = np.array([0.5112, 0.4692, 0.4715, 0.5206, 0.4894, 0.5114, 0.5096, 0.4932, 0.4755]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @@ -201,7 +201,7 @@ class SafeDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): image_slice = image[0, -3:, -3:, -1] image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] - assert image.shape == (1, 128, 128, 3) + assert image.shape == (1, 64, 64, 3) expected_slice = np.array([0.4937, 0.4649, 0.4716, 0.5145, 0.4889, 0.513, 0.513, 0.4905, 0.4738]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 @@ -258,7 +258,7 @@ class SafeDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): generator = torch.Generator(device=torch_device).manual_seed(0) image = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="np").images - assert image.shape == (1, 128, 128, 3) + assert image.shape == (1, 64, 64, 3) @slow diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 19493e3231..f4b81b54c8 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -436,7 +436,7 @@ class PipelineFastTests(unittest.TestCase): assert image_inpaint.shape == (1, 32, 32, 3) assert image_img2img.shape == (1, 32, 32, 3) - assert image_text2img.shape == (1, 128, 128, 3) + assert image_text2img.shape == (1, 64, 64, 3) def test_set_scheduler(self): unet = self.dummy_cond_unet diff --git a/tests/test_pipelines_flax.py b/tests/test_pipelines_flax.py index 72316aad92..9b9dcddd60 100644 --- a/tests/test_pipelines_flax.py +++ b/tests/test_pipelines_flax.py @@ -78,7 +78,7 @@ class FlaxPipelineTests(unittest.TestCase): images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images - assert images.shape == (num_samples, 1, 128, 128, 3) + assert images.shape == (num_samples, 1, 64, 64, 3) if jax.device_count() == 8: assert np.abs(np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 3.1111548) < 1e-3 assert np.abs(np.abs(images, dtype=np.float32).sum() - 199746.95) < 5e-1