From 05a36d5c1a17d0a3dbc7a1efb83ba1bdbfbf1fb2 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 24 Nov 2022 20:33:52 +0100 Subject: [PATCH] Upscaling fixed (#1402) * Upscaling fixed * up * more fixes * fix * more fixes * finish again * up --- src/diffusers/pipeline_utils.py | 20 ++-- .../alt_diffusion/pipeline_alt_diffusion.py | 11 +- .../pipeline_alt_diffusion_img2img.py | 1 + .../pipeline_latent_diffusion.py | 9 +- .../pipeline_flax_stable_diffusion.py | 20 ++-- .../pipeline_onnx_stable_diffusion.py | 12 +- .../pipeline_onnx_stable_diffusion_inpaint.py | 16 ++- ...ne_onnx_stable_diffusion_inpaint_legacy.py | 7 +- .../pipeline_stable_diffusion.py | 11 +- ...peline_stable_diffusion_image_variation.py | 11 +- .../pipeline_stable_diffusion_img2img.py | 1 + .../pipeline_stable_diffusion_inpaint.py | 15 ++- ...ipeline_stable_diffusion_inpaint_legacy.py | 7 +- .../pipeline_stable_diffusion_safe.py | 11 +- .../pipeline_versatile_diffusion.py | 13 ++- ...ipeline_versatile_diffusion_dual_guided.py | 11 +- ...ine_versatile_diffusion_image_variation.py | 11 +- ...eline_versatile_diffusion_text_to_image.py | 11 +- .../altdiffusion/test_alt_diffusion.py | 8 +- .../latent_diffusion/test_latent_diffusion.py | 4 +- .../stable_diffusion/test_stable_diffusion.py | 104 +++++++++++++++--- .../test_stable_diffusion_image_variation.py | 5 +- .../test_stable_diffusion_inpaint.py | 15 +-- .../test_stable_diffusion_inpaint_legacy.py | 6 +- .../test_safe_diffusion.py | 4 +- tests/test_pipelines.py | 2 +- 26 files changed, 226 insertions(+), 120 deletions(-) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 224d6a7f8e..d2c5516220 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -544,7 +544,14 @@ class DiffusionPipeline(ConfigMixin): init_kwargs = {**init_kwargs, **passed_pipe_kwargs} # remove `null` components - init_dict = {k: v for k, v in init_dict.items() if v[0] is not None} + def load_module(name, value): + if value[0] is None: + return False + if name in passed_class_obj and passed_class_obj[name] is None: + return False + return True + + init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)} if len(unused_kwargs) > 0: logger.warning(f"Keyword arguments {unused_kwargs} not recognized.") @@ -560,12 +567,11 @@ class DiffusionPipeline(ConfigMixin): is_pipeline_module = hasattr(pipelines, library_name) loaded_sub_model = None - sub_model_should_be_defined = True # if the model is in a pipeline module, then we load it from the pipeline if name in passed_class_obj: # 1. check that passed_class_obj has correct parent class - if not is_pipeline_module and passed_class_obj[name] is not None: + if not is_pipeline_module: library = importlib.import_module(library_name) class_obj = getattr(library, class_name) importable_classes = LOADABLE_CLASSES[library_name] @@ -581,12 +587,6 @@ class DiffusionPipeline(ConfigMixin): f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be" f" {expected_class_obj}" ) - elif passed_class_obj[name] is None and name not in pipeline_class._optional_components: - logger.warning( - f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note" - f" that this might lead to problems when using {pipeline_class} and is not recommended." - ) - sub_model_should_be_defined = False else: logger.warning( f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it" @@ -608,7 +608,7 @@ class DiffusionPipeline(ConfigMixin): importable_classes = LOADABLE_CLASSES[library_name] class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} - if loaded_sub_model is None and sub_model_should_be_defined: + if loaded_sub_model is None: load_method_name = None for class_name, class_candidate in class_candidates.items(): if class_candidate is not None and issubclass(class_obj, class_candidate): diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 718b6b652a..1a11bfa454 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -141,6 +141,7 @@ class AltDiffusionPipeline(DiffusionPipeline): safety_checker=safety_checker, feature_extractor=feature_extractor, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) def enable_xformers_memory_efficient_attention(self): @@ -379,7 +380,7 @@ class AltDiffusionPipeline(DiffusionPipeline): ) def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // 8, width // 8) + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if latents is None: if device.type == "mps": # randn does not work reproducibly on mps @@ -420,9 +421,9 @@ class AltDiffusionPipeline(DiffusionPipeline): Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to self.unet.config.sample_size * 8): + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.unet.config.sample_size * 8): + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -469,8 +470,8 @@ class AltDiffusionPipeline(DiffusionPipeline): (nsfw) content, according to the `safety_checker`. """ # 0. Default height and width to unet - height = height or self.unet.config.sample_size * 8 - width = width or self.unet.config.sample_size * 8 + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs(prompt, height, width, callback_steps) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index ba4462bc20..2d81a42554 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -154,6 +154,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline): safety_checker=safety_checker, feature_extractor=feature_extractor, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py index a52f19ca31..0e903cb836 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -60,6 +60,7 @@ class LDMTextToImagePipeline(DiffusionPipeline): ): super().__init__() self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler) + self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1) @torch.no_grad() def __call__( @@ -79,9 +80,9 @@ class LDMTextToImagePipeline(DiffusionPipeline): Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to self.unet.config.sample_size * 8): + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.unet.config.sample_size * 8): + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -107,8 +108,8 @@ class LDMTextToImagePipeline(DiffusionPipeline): generated images. """ # 0. Default height and width to unet - height = height or self.unet.config.sample_size * 8 - width = width or self.unet.config.sample_size * 8 + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor if isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index 8c2ed06e49..e682030c89 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -106,6 +106,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): safety_checker=safety_checker, feature_extractor=feature_extractor, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) def prepare_inputs(self, prompt: Union[str, List[str]]): if not isinstance(prompt, (str, list)): @@ -168,8 +169,8 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): neg_prompt_ids: jnp.array = None, ): # 0. Default height and width to unet - height = height or self.unet.config.sample_size * 8 - width = width or self.unet.config.sample_size * 8 + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -192,7 +193,12 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): uncond_embeddings = self.text_encoder(uncond_input, params=params["text_encoder"])[0] context = jnp.concatenate([uncond_embeddings, text_embeddings]) - latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) + latents_shape = ( + batch_size, + self.unet.in_channels, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) if latents is None: latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) else: @@ -269,9 +275,9 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to self.unet.config.sample_size * 8): + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.unet.config.sample_size * 8): + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -307,8 +313,8 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 0. Default height and width to unet - height = height or self.unet.config.sample_size * 8 - width = width or self.unet.config.sample_size * 8 + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor if jit: images = _p_generate( diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py index 92abf6bf77..71f5dae034 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py @@ -108,6 +108,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline): safety_checker=safety_checker, feature_extractor=feature_extractor, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): @@ -206,8 +207,8 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline): **kwargs, ): # 0. Default height and width to unet - height = height or self.unet.config.sample_size * 8 - width = width or self.unet.config.sample_size * 8 + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor if isinstance(prompt, str): batch_size = 1 @@ -241,7 +242,12 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline): # get the initial random noise unless the user supplied it latents_dtype = text_embeddings.dtype - latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8) + latents_shape = ( + batch_size * num_images_per_prompt, + 4, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) if latents is None: latents = generator.randn(*latents_shape).astype(latents_dtype) elif latents.shape != latents_shape: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py index 1ceb74fb4f..94da6bdbb2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py @@ -158,6 +158,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): safety_checker=safety_checker, feature_extractor=feature_extractor, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt @@ -273,9 +274,9 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. - height (`int`, *optional*, defaults to self.unet.config.sample_size * 8): + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.unet.config.sample_size * 8): + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -321,8 +322,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): (nsfw) content, according to the `safety_checker`. """ # 0. Default height and width to unet - height = height or self.unet.config.sample_size * 8 - width = width or self.unet.config.sample_size * 8 + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor if isinstance(prompt, str): batch_size = 1 @@ -358,7 +359,12 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): ) num_channels_latents = NUM_LATENT_CHANNELS - latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8) + latents_shape = ( + batch_size * num_images_per_prompt, + num_channels_latents, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) latents_dtype = text_embeddings.dtype if latents is None: latents = generator.randn(*latents_shape).astype(latents_dtype) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py index 03c3e5397b..55631f0a58 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py @@ -27,11 +27,11 @@ def preprocess(image): return 2.0 * image - 1.0 -def preprocess_mask(mask): +def preprocess_mask(mask, scale_factor=8): mask = mask.convert("L") w, h = mask.size w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST) + mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL.Image.NEAREST) mask = np.array(mask).astype(np.float32) / 255.0 mask = np.tile(mask, (4, 1, 1)) mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? @@ -143,6 +143,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): safety_checker=safety_checker, feature_extractor=feature_extractor, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt @@ -349,7 +350,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): # preprocess mask if not isinstance(mask_image, np.ndarray): - mask_image = preprocess_mask(mask_image) + mask_image = preprocess_mask(mask_image, self.vae_scale_factor) mask_image = mask_image.astype(latents_dtype) mask = np.concatenate([mask_image] * num_images_per_prompt, axis=0) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index d88f6f31d3..28acc7fcbd 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -140,6 +140,7 @@ class StableDiffusionPipeline(DiffusionPipeline): safety_checker=safety_checker, feature_extractor=feature_extractor, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) def enable_xformers_memory_efficient_attention(self): @@ -378,7 +379,7 @@ class StableDiffusionPipeline(DiffusionPipeline): ) def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // 8, width // 8) + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if latents is None: if device.type == "mps": # randn does not work reproducibly on mps @@ -419,9 +420,9 @@ class StableDiffusionPipeline(DiffusionPipeline): Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to self.unet.config.sample_size * 8): + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.unet.config.sample_size * 8): + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -468,8 +469,8 @@ class StableDiffusionPipeline(DiffusionPipeline): (nsfw) content, according to the `safety_checker`. """ # 0. Default height and width to unet - height = height or self.unet.config.sample_size * 8 - width = width or self.unet.config.sample_size * 8 + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs(prompt, height, width, callback_steps) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py index debffb5a60..2a351e5665 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py @@ -108,6 +108,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline): safety_checker=safety_checker, feature_extractor=feature_extractor, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention @@ -281,7 +282,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // 8, width // 8) + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if latents is None: if device.type == "mps": # randn does not work reproducibly on mps @@ -324,9 +325,9 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline): configuration of [this](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json) `CLIPFeatureExtractor` - height (`int`, *optional*, defaults to self.unet.config.sample_size * 8): + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.unet.config.sample_size * 8): + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -370,8 +371,8 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline): (nsfw) content, according to the `safety_checker`. """ # 0. Default height and width to unet - height = height or self.unet.config.sample_size * 8 - width = width or self.unet.config.sample_size * 8 + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs(image, height, width, callback_steps) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index fbc779a36f..97b0c20eb2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -153,6 +153,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): safety_checker=safety_checker, feature_extractor=feature_extractor, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 24c94572b1..e768108d46 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -218,6 +218,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): safety_checker=safety_checker, feature_extractor=feature_extractor, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing @@ -468,7 +469,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // 8, width // 8) + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if latents is None: if device.type == "mps": # randn does not work reproducibly on mps @@ -490,7 +491,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): # resize the mask to latents shape as we concatenate the mask to the latents # we do that before converting to dtype to avoid breaking in case we're using cpu_offload # and half precision - mask = torch.nn.functional.interpolate(mask, size=(height // 8, width // 8)) + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) mask = mask.to(device=device, dtype=dtype) masked_image = masked_image.to(device=device, dtype=dtype) @@ -547,9 +550,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. - height (`int`, *optional*, defaults to self.unet.config.sample_size * 8): + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.unet.config.sample_size * 8): + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -596,8 +599,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): (nsfw) content, according to the `safety_checker`. """ # 0. Default height and width to unet - height = height or self.unet.config.sample_size * 8 - width = width or self.unet.config.sample_size * 8 + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs self.check_inputs(prompt, height, width, callback_steps) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index e9c53e554a..d28e2bef5a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -51,11 +51,11 @@ def preprocess_image(image): return 2.0 * image - 1.0 -def preprocess_mask(mask): +def preprocess_mask(mask, scale_factor=8): mask = mask.convert("L") w, h = mask.size w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - mask = mask.resize((w // 8, h // 8), resample=PIL_INTERPOLATION["nearest"]) + mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"]) mask = np.array(mask).astype(np.float32) / 255.0 mask = np.tile(mask, (4, 1, 1)) mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? @@ -166,6 +166,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): safety_checker=safety_checker, feature_extractor=feature_extractor, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing @@ -541,7 +542,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): init_image = preprocess_image(init_image) if not isinstance(mask_image, torch.FloatTensor): - mask_image = preprocess_mask(mask_image) + mask_image = preprocess_mask(mask_image, self.vae_scale_factor) # 5. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) diff --git a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py index 592542cd73..c5adb16895 100644 --- a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +++ b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py @@ -136,6 +136,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline): feature_extractor=feature_extractor, ) self._safety_text_concept = safety_concept + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) @property @@ -443,7 +444,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // 8, width // 8) + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if latents is None: if device.type == "mps": # randn does not work reproducibly on mps @@ -531,9 +532,9 @@ class StableDiffusionPipelineSafe(DiffusionPipeline): Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to self.unet.config.sample_size * 8): + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.unet.config.sample_size * 8): + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -600,8 +601,8 @@ class StableDiffusionPipelineSafe(DiffusionPipeline): (nsfw) content, according to the `safety_checker`. """ # 0. Default height and width to unet - height = height or self.unet.config.sample_size * 8 - width = width or self.unet.config.sample_size * 8 + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs(prompt, height, width, callback_steps) diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py index c60a8836ec..7be7f4d3ae 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py @@ -78,6 +78,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline): vae=vae, scheduler=scheduler, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): r""" @@ -131,9 +132,9 @@ class VersatileDiffusionPipeline(DiffusionPipeline): Args: image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`): The image prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8): + height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8): + width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -247,9 +248,9 @@ class VersatileDiffusionPipeline(DiffusionPipeline): Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8): + height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8): + width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -360,9 +361,9 @@ class VersatileDiffusionPipeline(DiffusionPipeline): Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8): + height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8): + width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py index e5dc59389f..2e150ca897 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py @@ -87,6 +87,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline): vae=vae, scheduler=scheduler, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if self.text_unet is not None and ( "dual_cross_attention" not in self.image_unet.config or not self.image_unet.config.dual_cross_attention @@ -419,7 +420,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // 8, width // 8) + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if latents is None: if device.type == "mps": # randn does not work reproducibly on mps @@ -474,9 +475,9 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline): Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8): + height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8): + width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -552,8 +553,8 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline): returning a tuple, the first element is a list with the generated images. """ # 0. Default height and width to unet - height = height or self.image_unet.config.sample_size * 8 - width = width or self.image_unet.config.sample_size * 8 + height = height or self.image_unet.config.sample_size * self.vae_scale_factor + width = width or self.image_unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs(prompt, image, height, width, callback_steps) diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py index 53c67d8c2e..2594757062 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py @@ -71,6 +71,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline): vae=vae, scheduler=scheduler, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention with unet->image_unet def enable_xformers_memory_efficient_attention(self): @@ -277,7 +278,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // 8, width // 8) + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if latents is None: if device.type == "mps": # randn does not work reproducibly on mps @@ -318,9 +319,9 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline): Args: image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`): The image prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8): + height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8): + width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -392,8 +393,8 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline): (nsfw) content, according to the `safety_checker`. """ # 0. Default height and width to unet - height = height or self.image_unet.config.sample_size * 8 - width = width or self.image_unet.config.sample_size * 8 + height = height or self.image_unet.config.sample_size * self.vae_scale_factor + width = width or self.image_unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs(image, height, width, callback_steps) diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py index f38b604bd6..4845d5cab5 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py @@ -75,6 +75,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): vae=vae, scheduler=scheduler, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if self.text_unet is not None: self._swap_unet_attention_blocks() @@ -337,7 +338,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // 8, width // 8) + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if latents is None: if device.type == "mps": # randn does not work reproducibly on mps @@ -378,9 +379,9 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8): + height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8): + width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -444,8 +445,8 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): (nsfw) content, according to the `safety_checker`. """ # 0. Default height and width to unet - height = height or self.image_unet.config.sample_size * 8 - width = width or self.image_unet.config.sample_size * 8 + height = height or self.image_unet.config.sample_size * self.vae_scale_factor + width = width or self.image_unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs(prompt, height, width, callback_steps) diff --git a/tests/pipelines/altdiffusion/test_alt_diffusion.py b/tests/pipelines/altdiffusion/test_alt_diffusion.py index bcbadfb3db..91fe764449 100644 --- a/tests/pipelines/altdiffusion/test_alt_diffusion.py +++ b/tests/pipelines/altdiffusion/test_alt_diffusion.py @@ -172,7 +172,9 @@ class AltDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.5748162, 0.60447145, 0.48821217, 0.50100636, 0.5431185, 0.45763683, 0.49657696, 0.48132733, 0.47573093]) + expected_slice = np.array( + [0.5748162, 0.60447145, 0.48821217, 0.50100636, 0.5431185, 0.45763683, 0.49657696, 0.48132733, 0.47573093] + ) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 @@ -219,7 +221,9 @@ class AltDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.51605093, 0.5707241, 0.47365507, 0.50578886, 0.5633877, 0.4642503, 0.5182081, 0.48763484, 0.49084237]) + expected_slice = np.array( + [0.51605093, 0.5707241, 0.47365507, 0.50578886, 0.5633877, 0.4642503, 0.5182081, 0.48763484, 0.49084237] + ) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/latent_diffusion/test_latent_diffusion.py b/tests/pipelines/latent_diffusion/test_latent_diffusion.py index 085cdb4e76..9d5c07809d 100644 --- a/tests/pipelines/latent_diffusion/test_latent_diffusion.py +++ b/tests/pipelines/latent_diffusion/test_latent_diffusion.py @@ -111,8 +111,8 @@ class LDMTextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase): image_slice = image[0, -3:, -3:, -1] image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] - assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.5074, 0.5026, 0.4998, 0.4056, 0.3523, 0.4649, 0.5289, 0.5299, 0.4897]) + assert image.shape == (1, 16, 16, 3) + expected_slice = np.array([0.6806, 0.5454, 0.5638, 0.4893, 0.4656, 0.4257, 0.6248, 0.5217, 0.5498]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 53f8024f6c..b63aeefce5 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -207,11 +207,22 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): )[0] image_slice = image[0, -3:, -3:, -1] - print(", ".join(image_slice.flatten().tolist())) image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.5112, 0.4692, 0.4715, 0.5206, 0.4894, 0.5114, 0.5096, 0.4932, 0.4755]) + expected_slice = np.array( + [ + 0.5643956661224365, + 0.6017904281616211, + 0.4799129366874695, + 0.5267305374145508, + 0.5584856271743774, + 0.46413588523864746, + 0.5159522294998169, + 0.4963662028312683, + 0.47919973731040955, + ] + ) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 @@ -256,12 +267,13 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): num_inference_steps=2, output_type="np", ) + sd_pipe.enable_attention_slicing() image = output.images image_slice = image[0, -3:, -3:, -1] - assert image.shape == (1, 134, 134, 3) - expected_slice = np.array([0.7834, 0.5488, 0.5781, 0.46, 0.3609, 0.5369, 0.542, 0.4855, 0.5557]) + assert image.shape == (1, 536, 536, 3) + expected_slice = np.array([0.5445, 0.8108, 0.6242, 0.4863, 0.5779, 0.5423, 0.4749, 0.4589, 0.4616]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @@ -303,11 +315,22 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): )[0] image_slice = image[0, -3:, -3:, -1] - print(", ".join(image_slice.flatten().tolist())) image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.4937, 0.4649, 0.4716, 0.5145, 0.4889, 0.513, 0.513, 0.4905, 0.4738]) + expected_slice = np.array( + [ + 0.5094760060310364, + 0.5674174427986145, + 0.46675148606300354, + 0.5125715136528015, + 0.5696930289268494, + 0.4674668312072754, + 0.5277683734893799, + 0.4964486062526703, + 0.494540274143219, + ] + ) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 @@ -370,11 +393,22 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): )[0] image_slice = image[0, -3:, -3:, -1] - print(", ".join(image_slice.flatten().tolist())) image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.5067, 0.4689, 0.4614, 0.5233, 0.4903, 0.5112, 0.524, 0.5069, 0.4785]) + expected_slice = np.array( + [ + 0.47082293033599854, + 0.5371589064598083, + 0.4562119245529175, + 0.5220914483070374, + 0.5733777284622192, + 0.4795039892196655, + 0.5465868711471558, + 0.5074326395988464, + 0.5042197108268738, + ] + ) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 @@ -416,11 +450,22 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): )[0] image_slice = image[0, -3:, -3:, -1] - print(", ".join(image_slice.flatten().tolist())) image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.5067, 0.4689, 0.4614, 0.5233, 0.4903, 0.5112, 0.524, 0.5069, 0.4785]) + expected_slice = np.array( + [ + 0.4707113206386566, + 0.5372191071510315, + 0.4563021957874298, + 0.5220003724098206, + 0.5734264850616455, + 0.4794946610927582, + 0.5463782548904419, + 0.5074145197868347, + 0.504422664642334, + ] + ) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 @@ -462,11 +507,22 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): )[0] image_slice = image[0, -3:, -3:, -1] - print(", ".join(image_slice.flatten().tolist())) image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.5067, 0.4689, 0.4614, 0.5233, 0.4903, 0.5112, 0.524, 0.5069, 0.4785]) + expected_slice = np.array( + [ + 0.47082313895225525, + 0.5371587872505188, + 0.4562119245529175, + 0.5220913887023926, + 0.5733776688575745, + 0.47950395941734314, + 0.546586811542511, + 0.5074326992034912, + 0.5042197108268738, + ] + ) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 @@ -539,7 +595,19 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.4851, 0.4617, 0.4765, 0.5127, 0.4845, 0.5153, 0.5141, 0.4886, 0.4719]) + expected_slice = np.array( + [ + 0.5108221173286438, + 0.5688379406929016, + 0.4685141146183014, + 0.5098261833190918, + 0.5657756328582764, + 0.4631010890007019, + 0.5226285457611084, + 0.49129390716552734, + 0.4899061322212219, + ] + ) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 def test_stable_diffusion_num_images_per_prompt(self): @@ -700,18 +768,18 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): output = sd_pipe(prompt, number_of_steps=2, output_type="np") image_shape = output.images[0].shape[:2] - assert image_shape == [32, 32] + assert image_shape == (64, 64) - output = sd_pipe(prompt, number_of_steps=2, height=64, width=64, output_type="np") + output = sd_pipe(prompt, number_of_steps=2, height=96, width=96, output_type="np") image_shape = output.images[0].shape[:2] - assert image_shape == [64, 64] + assert image_shape == (96, 96) config = dict(sd_pipe.unet.config) config["sample_size"] = 96 - sd_pipe.unet = UNet2DConditionModel.from_config(config) + sd_pipe.unet = UNet2DConditionModel.from_config(config).to(torch_device) output = sd_pipe(prompt, number_of_steps=2, output_type="np") image_shape = output.images[0].shape[:2] - assert image_shape == [96, 96] + assert image_shape == (192, 192) @slow diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py index a992308922..9b350d42e1 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py @@ -154,11 +154,10 @@ class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unitte )[0] image_slice = image[0, -3:, -3:, -1] - print(image_slice.flatten()) image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.4935, 0.4784, 0.4802, 0.5027, 0.4805, 0.5149, 0.5143, 0.4879, 0.4731]) + expected_slice = np.array([0.5093, 0.5717, 0.4806, 0.4891, 0.5552, 0.4594, 0.5177, 0.4894, 0.4904]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-3 @@ -197,7 +196,7 @@ class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unitte image_slice = image[-1, -3:, -3:, -1] assert image.shape == (2, 64, 64, 3) - expected_slice = np.array([0.4939, 0.4627, 0.4831, 0.5710, 0.5387, 0.4428, 0.5230, 0.5545, 0.4586]) + expected_slice = np.array([0.6427, 0.5452, 0.5602, 0.5478, 0.5968, 0.6211, 0.5538, 0.5514, 0.5281]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 def test_stable_diffusion_img_variation_num_images_per_prompt(self): diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index 0d8abe5394..e85fae939e 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -167,8 +167,8 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0] - init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((128, 128)) - mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128)) + init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64)) + mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64)) # make sure here that pndm scheduler skips prk sd_pipe = StableDiffusionInpaintPipeline( @@ -213,7 +213,8 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.5075, 0.4485, 0.4558, 0.5369, 0.5369, 0.5236, 0.5127, 0.4983, 0.4776]) + expected_slice = np.array([0.4723, 0.5731, 0.3939, 0.5441, 0.5922, 0.4392, 0.5059, 0.4651, 0.4474]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 @@ -226,8 +227,8 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0] - init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((128, 128)) - mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128)) + init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64)) + mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64)) # make sure here that pndm scheduler skips prk sd_pipe = StableDiffusionInpaintPipeline( @@ -268,8 +269,8 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0] - init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((128, 128)) - mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128)) + init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64)) + mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64)) # put models in fp16 unet = unet.half() diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py index 94106b6ba8..4b972c7b7d 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py @@ -168,7 +168,7 @@ class StableDiffusionInpaintLegacyPipelineFastTests(PipelineTesterMixin, unittes image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0] init_image = Image.fromarray(np.uint8(image)).convert("RGB") - mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128)) + mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((32, 32)) # make sure here that pndm scheduler skips prk sd_pipe = StableDiffusionInpaintPipelineLegacy( @@ -227,7 +227,7 @@ class StableDiffusionInpaintLegacyPipelineFastTests(PipelineTesterMixin, unittes image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0] init_image = Image.fromarray(np.uint8(image)).convert("RGB") - mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128)) + mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((32, 32)) # make sure here that pndm scheduler skips prk sd_pipe = StableDiffusionInpaintPipelineLegacy( @@ -273,7 +273,7 @@ class StableDiffusionInpaintLegacyPipelineFastTests(PipelineTesterMixin, unittes image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0] init_image = Image.fromarray(np.uint8(image)).convert("RGB") - mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128)) + mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((32, 32)) # make sure here that pndm scheduler skips prk sd_pipe = StableDiffusionInpaintPipelineLegacy( diff --git a/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py b/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py index d80275c257..dbb9914793 100644 --- a/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py +++ b/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py @@ -156,7 +156,7 @@ class SafeDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.5112, 0.4692, 0.4715, 0.5206, 0.4894, 0.5114, 0.5096, 0.4932, 0.4755]) + expected_slice = np.array([0.5644, 0.6018, 0.4799, 0.5267, 0.5585, 0.4641, 0.516, 0.4964, 0.4792]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 @@ -202,7 +202,7 @@ class SafeDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.4937, 0.4649, 0.4716, 0.5145, 0.4889, 0.513, 0.513, 0.4905, 0.4738]) + expected_slice = np.array([0.5095, 0.5674, 0.4668, 0.5126, 0.5697, 0.4675, 0.5278, 0.4964, 0.4945]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 7e0304972e..a1bee29696 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -367,7 +367,7 @@ class PipelineFastTests(unittest.TestCase): image = self.dummy_image().cpu().permute(0, 2, 3, 1)[0] init_image = Image.fromarray(np.uint8(image)).convert("RGB") - mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128)) + mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((32, 32)) # make sure here that pndm scheduler skips prk inpaint = StableDiffusionInpaintPipelineLegacy(