1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Upscaling fixed (#1402)

* Upscaling fixed

* up

* more fixes

* fix

* more fixes

* finish again

* up
This commit is contained in:
Patrick von Platen
2022-11-24 20:33:52 +01:00
committed by GitHub
parent cbfed0c256
commit 05a36d5c1a
26 changed files with 226 additions and 120 deletions

View File

@@ -544,7 +544,14 @@ class DiffusionPipeline(ConfigMixin):
init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
# remove `null` components
init_dict = {k: v for k, v in init_dict.items() if v[0] is not None}
def load_module(name, value):
if value[0] is None:
return False
if name in passed_class_obj and passed_class_obj[name] is None:
return False
return True
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
if len(unused_kwargs) > 0:
logger.warning(f"Keyword arguments {unused_kwargs} not recognized.")
@@ -560,12 +567,11 @@ class DiffusionPipeline(ConfigMixin):
is_pipeline_module = hasattr(pipelines, library_name)
loaded_sub_model = None
sub_model_should_be_defined = True
# if the model is in a pipeline module, then we load it from the pipeline
if name in passed_class_obj:
# 1. check that passed_class_obj has correct parent class
if not is_pipeline_module and passed_class_obj[name] is not None:
if not is_pipeline_module:
library = importlib.import_module(library_name)
class_obj = getattr(library, class_name)
importable_classes = LOADABLE_CLASSES[library_name]
@@ -581,12 +587,6 @@ class DiffusionPipeline(ConfigMixin):
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
f" {expected_class_obj}"
)
elif passed_class_obj[name] is None and name not in pipeline_class._optional_components:
logger.warning(
f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note"
f" that this might lead to problems when using {pipeline_class} and is not recommended."
)
sub_model_should_be_defined = False
else:
logger.warning(
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
@@ -608,7 +608,7 @@ class DiffusionPipeline(ConfigMixin):
importable_classes = LOADABLE_CLASSES[library_name]
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
if loaded_sub_model is None and sub_model_should_be_defined:
if loaded_sub_model is None:
load_method_name = None
for class_name, class_candidate in class_candidates.items():
if class_candidate is not None and issubclass(class_obj, class_candidate):

View File

@@ -141,6 +141,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
def enable_xformers_memory_efficient_attention(self):
@@ -379,7 +380,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
)
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // 8, width // 8)
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
@@ -420,9 +421,9 @@ class AltDiffusionPipeline(DiffusionPipeline):
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -469,8 +470,8 @@ class AltDiffusionPipeline(DiffusionPipeline):
(nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * 8
width = width or self.unet.config.sample_size * 8
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, callback_steps)

View File

@@ -154,6 +154,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):

View File

@@ -60,6 +60,7 @@ class LDMTextToImagePipeline(DiffusionPipeline):
):
super().__init__()
self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
@torch.no_grad()
def __call__(
@@ -79,9 +80,9 @@ class LDMTextToImagePipeline(DiffusionPipeline):
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -107,8 +108,8 @@ class LDMTextToImagePipeline(DiffusionPipeline):
generated images.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * 8
width = width or self.unet.config.sample_size * 8
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
if isinstance(prompt, str):
batch_size = 1

View File

@@ -106,6 +106,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
def prepare_inputs(self, prompt: Union[str, List[str]]):
if not isinstance(prompt, (str, list)):
@@ -168,8 +169,8 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
neg_prompt_ids: jnp.array = None,
):
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * 8
width = width or self.unet.config.sample_size * 8
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -192,7 +193,12 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
uncond_embeddings = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
context = jnp.concatenate([uncond_embeddings, text_embeddings])
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
latents_shape = (
batch_size,
self.unet.in_channels,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if latents is None:
latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
else:
@@ -269,9 +275,9 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -307,8 +313,8 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * 8
width = width or self.unet.config.sample_size * 8
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
if jit:
images = _p_generate(

View File

@@ -108,6 +108,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
@@ -206,8 +207,8 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
**kwargs,
):
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * 8
width = width or self.unet.config.sample_size * 8
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
if isinstance(prompt, str):
batch_size = 1
@@ -241,7 +242,12 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
# get the initial random noise unless the user supplied it
latents_dtype = text_embeddings.dtype
latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
latents_shape = (
batch_size * num_images_per_prompt,
4,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if latents is None:
latents = generator.randn(*latents_shape).astype(latents_dtype)
elif latents.shape != latents_shape:

View File

@@ -158,6 +158,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
@@ -273,9 +274,9 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
instead of 3, so the expected shape would be `(B, H, W, 1)`.
height (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -321,8 +322,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
(nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * 8
width = width or self.unet.config.sample_size * 8
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
if isinstance(prompt, str):
batch_size = 1
@@ -358,7 +359,12 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
)
num_channels_latents = NUM_LATENT_CHANNELS
latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
latents_shape = (
batch_size * num_images_per_prompt,
num_channels_latents,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
latents_dtype = text_embeddings.dtype
if latents is None:
latents = generator.randn(*latents_shape).astype(latents_dtype)

View File

@@ -27,11 +27,11 @@ def preprocess(image):
return 2.0 * image - 1.0
def preprocess_mask(mask):
def preprocess_mask(mask, scale_factor=8):
mask = mask.convert("L")
w, h = mask.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL.Image.NEAREST)
mask = np.array(mask).astype(np.float32) / 255.0
mask = np.tile(mask, (4, 1, 1))
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
@@ -143,6 +143,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
@@ -349,7 +350,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
# preprocess mask
if not isinstance(mask_image, np.ndarray):
mask_image = preprocess_mask(mask_image)
mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
mask_image = mask_image.astype(latents_dtype)
mask = np.concatenate([mask_image] * num_images_per_prompt, axis=0)

View File

@@ -140,6 +140,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
def enable_xformers_memory_efficient_attention(self):
@@ -378,7 +379,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
)
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // 8, width // 8)
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
@@ -419,9 +420,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -468,8 +469,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
(nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * 8
width = width or self.unet.config.sample_size * 8
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, callback_steps)

View File

@@ -108,6 +108,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
@@ -281,7 +282,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // 8, width // 8)
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
@@ -324,9 +325,9 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
configuration of
[this](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
`CLIPFeatureExtractor`
height (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -370,8 +371,8 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
(nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * 8
width = width or self.unet.config.sample_size * 8
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(image, height, width, callback_steps)

View File

@@ -153,6 +153,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing

View File

@@ -218,6 +218,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
@@ -468,7 +469,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // 8, width // 8)
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
@@ -490,7 +491,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
# resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
mask = torch.nn.functional.interpolate(mask, size=(height // 8, width // 8))
mask = torch.nn.functional.interpolate(
mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
)
mask = mask.to(device=device, dtype=dtype)
masked_image = masked_image.to(device=device, dtype=dtype)
@@ -547,9 +550,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
instead of 3, so the expected shape would be `(B, H, W, 1)`.
height (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -596,8 +599,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
(nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * 8
width = width or self.unet.config.sample_size * 8
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs
self.check_inputs(prompt, height, width, callback_steps)

View File

@@ -51,11 +51,11 @@ def preprocess_image(image):
return 2.0 * image - 1.0
def preprocess_mask(mask):
def preprocess_mask(mask, scale_factor=8):
mask = mask.convert("L")
w, h = mask.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
mask = mask.resize((w // 8, h // 8), resample=PIL_INTERPOLATION["nearest"])
mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
mask = np.array(mask).astype(np.float32) / 255.0
mask = np.tile(mask, (4, 1, 1))
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
@@ -166,6 +166,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
@@ -541,7 +542,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
init_image = preprocess_image(init_image)
if not isinstance(mask_image, torch.FloatTensor):
mask_image = preprocess_mask(mask_image)
mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)

View File

@@ -136,6 +136,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
feature_extractor=feature_extractor,
)
self._safety_text_concept = safety_concept
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@property
@@ -443,7 +444,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // 8, width // 8)
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
@@ -531,9 +532,9 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -600,8 +601,8 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
(nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * 8
width = width or self.unet.config.sample_size * 8
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, callback_steps)

View File

@@ -78,6 +78,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
vae=vae,
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
@@ -131,9 +132,9 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
Args:
image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`):
The image prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8):
height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8):
width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -247,9 +248,9 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8):
height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8):
width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -360,9 +361,9 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8):
height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8):
width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the

View File

@@ -87,6 +87,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
vae=vae,
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
if self.text_unet is not None and (
"dual_cross_attention" not in self.image_unet.config or not self.image_unet.config.dual_cross_attention
@@ -419,7 +420,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // 8, width // 8)
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
@@ -474,9 +475,9 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8):
height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8):
width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -552,8 +553,8 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
returning a tuple, the first element is a list with the generated images.
"""
# 0. Default height and width to unet
height = height or self.image_unet.config.sample_size * 8
width = width or self.image_unet.config.sample_size * 8
height = height or self.image_unet.config.sample_size * self.vae_scale_factor
width = width or self.image_unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, image, height, width, callback_steps)

View File

@@ -71,6 +71,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
vae=vae,
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention with unet->image_unet
def enable_xformers_memory_efficient_attention(self):
@@ -277,7 +278,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // 8, width // 8)
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
@@ -318,9 +319,9 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
Args:
image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`):
The image prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8):
height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8):
width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -392,8 +393,8 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
(nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height = height or self.image_unet.config.sample_size * 8
width = width or self.image_unet.config.sample_size * 8
height = height or self.image_unet.config.sample_size * self.vae_scale_factor
width = width or self.image_unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(image, height, width, callback_steps)

View File

@@ -75,6 +75,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
vae=vae,
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
if self.text_unet is not None:
self._swap_unet_attention_blocks()
@@ -337,7 +338,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // 8, width // 8)
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
@@ -378,9 +379,9 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8):
height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8):
width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -444,8 +445,8 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
(nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height = height or self.image_unet.config.sample_size * 8
width = width or self.image_unet.config.sample_size * 8
height = height or self.image_unet.config.sample_size * self.vae_scale_factor
width = width or self.image_unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, callback_steps)

View File

@@ -172,7 +172,9 @@ class AltDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.5748162, 0.60447145, 0.48821217, 0.50100636, 0.5431185, 0.45763683, 0.49657696, 0.48132733, 0.47573093])
expected_slice = np.array(
[0.5748162, 0.60447145, 0.48821217, 0.50100636, 0.5431185, 0.45763683, 0.49657696, 0.48132733, 0.47573093]
)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
@@ -219,7 +221,9 @@ class AltDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.51605093, 0.5707241, 0.47365507, 0.50578886, 0.5633877, 0.4642503, 0.5182081, 0.48763484, 0.49084237])
expected_slice = np.array(
[0.51605093, 0.5707241, 0.47365507, 0.50578886, 0.5633877, 0.4642503, 0.5182081, 0.48763484, 0.49084237]
)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

View File

@@ -111,8 +111,8 @@ class LDMTextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.5074, 0.5026, 0.4998, 0.4056, 0.3523, 0.4649, 0.5289, 0.5299, 0.4897])
assert image.shape == (1, 16, 16, 3)
expected_slice = np.array([0.6806, 0.5454, 0.5638, 0.4893, 0.4656, 0.4257, 0.6248, 0.5217, 0.5498])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

View File

@@ -207,11 +207,22 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
)[0]
image_slice = image[0, -3:, -3:, -1]
print(", ".join(image_slice.flatten().tolist()))
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.5112, 0.4692, 0.4715, 0.5206, 0.4894, 0.5114, 0.5096, 0.4932, 0.4755])
expected_slice = np.array(
[
0.5643956661224365,
0.6017904281616211,
0.4799129366874695,
0.5267305374145508,
0.5584856271743774,
0.46413588523864746,
0.5159522294998169,
0.4963662028312683,
0.47919973731040955,
]
)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
@@ -256,12 +267,13 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
num_inference_steps=2,
output_type="np",
)
sd_pipe.enable_attention_slicing()
image = output.images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 134, 134, 3)
expected_slice = np.array([0.7834, 0.5488, 0.5781, 0.46, 0.3609, 0.5369, 0.542, 0.4855, 0.5557])
assert image.shape == (1, 536, 536, 3)
expected_slice = np.array([0.5445, 0.8108, 0.6242, 0.4863, 0.5779, 0.5423, 0.4749, 0.4589, 0.4616])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -303,11 +315,22 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
)[0]
image_slice = image[0, -3:, -3:, -1]
print(", ".join(image_slice.flatten().tolist()))
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.4937, 0.4649, 0.4716, 0.5145, 0.4889, 0.513, 0.513, 0.4905, 0.4738])
expected_slice = np.array(
[
0.5094760060310364,
0.5674174427986145,
0.46675148606300354,
0.5125715136528015,
0.5696930289268494,
0.4674668312072754,
0.5277683734893799,
0.4964486062526703,
0.494540274143219,
]
)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
@@ -370,11 +393,22 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
)[0]
image_slice = image[0, -3:, -3:, -1]
print(", ".join(image_slice.flatten().tolist()))
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.5067, 0.4689, 0.4614, 0.5233, 0.4903, 0.5112, 0.524, 0.5069, 0.4785])
expected_slice = np.array(
[
0.47082293033599854,
0.5371589064598083,
0.4562119245529175,
0.5220914483070374,
0.5733777284622192,
0.4795039892196655,
0.5465868711471558,
0.5074326395988464,
0.5042197108268738,
]
)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
@@ -416,11 +450,22 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
)[0]
image_slice = image[0, -3:, -3:, -1]
print(", ".join(image_slice.flatten().tolist()))
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.5067, 0.4689, 0.4614, 0.5233, 0.4903, 0.5112, 0.524, 0.5069, 0.4785])
expected_slice = np.array(
[
0.4707113206386566,
0.5372191071510315,
0.4563021957874298,
0.5220003724098206,
0.5734264850616455,
0.4794946610927582,
0.5463782548904419,
0.5074145197868347,
0.504422664642334,
]
)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
@@ -462,11 +507,22 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
)[0]
image_slice = image[0, -3:, -3:, -1]
print(", ".join(image_slice.flatten().tolist()))
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.5067, 0.4689, 0.4614, 0.5233, 0.4903, 0.5112, 0.524, 0.5069, 0.4785])
expected_slice = np.array(
[
0.47082313895225525,
0.5371587872505188,
0.4562119245529175,
0.5220913887023926,
0.5733776688575745,
0.47950395941734314,
0.546586811542511,
0.5074326992034912,
0.5042197108268738,
]
)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
@@ -539,7 +595,19 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.4851, 0.4617, 0.4765, 0.5127, 0.4845, 0.5153, 0.5141, 0.4886, 0.4719])
expected_slice = np.array(
[
0.5108221173286438,
0.5688379406929016,
0.4685141146183014,
0.5098261833190918,
0.5657756328582764,
0.4631010890007019,
0.5226285457611084,
0.49129390716552734,
0.4899061322212219,
]
)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_num_images_per_prompt(self):
@@ -700,18 +768,18 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
output = sd_pipe(prompt, number_of_steps=2, output_type="np")
image_shape = output.images[0].shape[:2]
assert image_shape == [32, 32]
assert image_shape == (64, 64)
output = sd_pipe(prompt, number_of_steps=2, height=64, width=64, output_type="np")
output = sd_pipe(prompt, number_of_steps=2, height=96, width=96, output_type="np")
image_shape = output.images[0].shape[:2]
assert image_shape == [64, 64]
assert image_shape == (96, 96)
config = dict(sd_pipe.unet.config)
config["sample_size"] = 96
sd_pipe.unet = UNet2DConditionModel.from_config(config)
sd_pipe.unet = UNet2DConditionModel.from_config(config).to(torch_device)
output = sd_pipe(prompt, number_of_steps=2, output_type="np")
image_shape = output.images[0].shape[:2]
assert image_shape == [96, 96]
assert image_shape == (192, 192)
@slow

View File

@@ -154,11 +154,10 @@ class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unitte
)[0]
image_slice = image[0, -3:, -3:, -1]
print(image_slice.flatten())
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.4935, 0.4784, 0.4802, 0.5027, 0.4805, 0.5149, 0.5143, 0.4879, 0.4731])
expected_slice = np.array([0.5093, 0.5717, 0.4806, 0.4891, 0.5552, 0.4594, 0.5177, 0.4894, 0.4904])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-3
@@ -197,7 +196,7 @@ class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unitte
image_slice = image[-1, -3:, -3:, -1]
assert image.shape == (2, 64, 64, 3)
expected_slice = np.array([0.4939, 0.4627, 0.4831, 0.5710, 0.5387, 0.4428, 0.5230, 0.5545, 0.4586])
expected_slice = np.array([0.6427, 0.5452, 0.5602, 0.5478, 0.5968, 0.6211, 0.5538, 0.5514, 0.5281])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
def test_stable_diffusion_img_variation_num_images_per_prompt(self):

View File

@@ -167,8 +167,8 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((128, 128))
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64))
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionInpaintPipeline(
@@ -213,7 +213,8 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.5075, 0.4485, 0.4558, 0.5369, 0.5369, 0.5236, 0.5127, 0.4983, 0.4776])
expected_slice = np.array([0.4723, 0.5731, 0.3939, 0.5441, 0.5922, 0.4392, 0.5059, 0.4651, 0.4474])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
@@ -226,8 +227,8 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((128, 128))
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64))
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionInpaintPipeline(
@@ -268,8 +269,8 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((128, 128))
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64))
# put models in fp16
unet = unet.half()

View File

@@ -168,7 +168,7 @@ class StableDiffusionInpaintLegacyPipelineFastTests(PipelineTesterMixin, unittes
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
init_image = Image.fromarray(np.uint8(image)).convert("RGB")
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((32, 32))
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionInpaintPipelineLegacy(
@@ -227,7 +227,7 @@ class StableDiffusionInpaintLegacyPipelineFastTests(PipelineTesterMixin, unittes
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
init_image = Image.fromarray(np.uint8(image)).convert("RGB")
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((32, 32))
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionInpaintPipelineLegacy(
@@ -273,7 +273,7 @@ class StableDiffusionInpaintLegacyPipelineFastTests(PipelineTesterMixin, unittes
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
init_image = Image.fromarray(np.uint8(image)).convert("RGB")
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((32, 32))
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionInpaintPipelineLegacy(

View File

@@ -156,7 +156,7 @@ class SafeDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.5112, 0.4692, 0.4715, 0.5206, 0.4894, 0.5114, 0.5096, 0.4932, 0.4755])
expected_slice = np.array([0.5644, 0.6018, 0.4799, 0.5267, 0.5585, 0.4641, 0.516, 0.4964, 0.4792])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
@@ -202,7 +202,7 @@ class SafeDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.4937, 0.4649, 0.4716, 0.5145, 0.4889, 0.513, 0.513, 0.4905, 0.4738])
expected_slice = np.array([0.5095, 0.5674, 0.4668, 0.5126, 0.5697, 0.4675, 0.5278, 0.4964, 0.4945])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

View File

@@ -367,7 +367,7 @@ class PipelineFastTests(unittest.TestCase):
image = self.dummy_image().cpu().permute(0, 2, 3, 1)[0]
init_image = Image.fromarray(np.uint8(image)).convert("RGB")
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((32, 32))
# make sure here that pndm scheduler skips prk
inpaint = StableDiffusionInpaintPipelineLegacy(