Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
[WIP]Vae preprocessor refactor (PR1) (#3557)
[WIP] VaeImageProcessor.preprocess refactor (PR 1)

* refactored VaeImageProcessor:
  - allow passing optional `height` and `width` arguments to `resize()`
  - add `convert_to_rgb`
* refactored the `prepare_latents` method of the img2img pipelines so that if latents are passed directly as the image input, they are not encoded again
* added a test in test_pipelines_common.py for latents as image inputs
* refactored the img2img pipelines that accept latents as image: controlnet img2img, stable diffusion img2img, instruct_pix2pix

Co-authored-by: yiyixuxu <yixu310@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
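In practice the refactor makes `VaeImageProcessor.preprocess` the single entry point for image inputs. A minimal sketch of the resulting behavior, based on the hunks below (illustrative only, not part of the diff; the sizes are arbitrary):

import torch
from PIL import Image
from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(vae_scale_factor=8, do_convert_rgb=True)

# PIL images (and numpy arrays) come out as a batched torch tensor in [-1, 1],
# resized to a multiple of vae_scale_factor.
pil_image = Image.new("RGB", (515, 517))
tensor = processor.preprocess(pil_image, height=512, width=512)  # shape (1, 3, 512, 512)

# A 4-channel tensor is treated as latents and passed through unchanged,
# which is what lets the img2img pipelines in this PR skip re-encoding.
latents = torch.randn(1, 4, 64, 64)
assert torch.equal(processor.preprocess(latents), latents)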
@@ -30,7 +30,8 @@ class VaeImageProcessor(ConfigMixin):

     Args:
         do_resize (`bool`, *optional*, defaults to `True`):
-            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
+            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
+            `height` and `width` arguments from the `preprocess` method.
         vae_scale_factor (`int`, *optional*, defaults to `8`):
             VAE scale factor. If `do_resize` is True, the image will be automatically resized to multiples of this
             factor.

@@ -38,6 +39,8 @@ class VaeImageProcessor(ConfigMixin):
             Resampling filter to use when resizing the image.
         do_normalize (`bool`, *optional*, defaults to `True`):
             Whether to normalize the image to [-1,1]
+        do_convert_rgb (`bool`, *optional*, defaults to `False`):
+            Whether to convert the images to RGB format.
     """

     config_name = CONFIG_NAME

@@ -49,11 +52,12 @@ class VaeImageProcessor(ConfigMixin):
         vae_scale_factor: int = 8,
         resample: str = "lanczos",
         do_normalize: bool = True,
+        do_convert_rgb: bool = False,
     ):
         super().__init__()

     @staticmethod
-    def numpy_to_pil(images):
+    def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image:
         """
         Convert a numpy image or a batch of images to a PIL image.
         """

@@ -69,7 +73,19 @@ class VaeImageProcessor(ConfigMixin):
         return pil_images

     @staticmethod
-    def numpy_to_pt(images):
+    def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
+        """
+        Convert a PIL image or a list of PIL images to numpy arrays.
+        """
+        if not isinstance(images, list):
+            images = [images]
+        images = [np.array(image).astype(np.float32) / 255.0 for image in images]
+        images = np.stack(images, axis=0)
+
+        return images
+
+    @staticmethod
+    def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
         """
         Convert a numpy image to a pytorch tensor
         """
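Assuming the helpers behave as the hunks above show, the conversions round-trip. A small sketch (not part of the diff):

from PIL import Image
from diffusers.image_processor import VaeImageProcessor

pil = Image.new("RGB", (64, 64))
arr = VaeImageProcessor.pil_to_numpy(pil)    # (1, 64, 64, 3), float32 in [0, 1]
pt = VaeImageProcessor.numpy_to_pt(arr)      # (1, 3, 64, 64) torch tensor
back = VaeImageProcessor.pt_to_numpy(pt)     # back to (1, 64, 64, 3) numpy
pil_again = VaeImageProcessor.numpy_to_pil(back)[0]  # numpy_to_pil returns a list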
@@ -80,7 +96,7 @@ class VaeImageProcessor(ConfigMixin):
         return images

     @staticmethod
-    def pt_to_numpy(images):
+    def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
         """
         Convert a pytorch tensor to a numpy image
         """

@@ -101,18 +117,39 @@ class VaeImageProcessor(ConfigMixin):
         """
         return (images / 2 + 0.5).clamp(0, 1)

-    def resize(self, images: PIL.Image.Image) -> PIL.Image.Image:
+    @staticmethod
+    def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
+        """
+        Converts an image to RGB format.
+        """
+        image = image.convert("RGB")
+        return image
+
+    def resize(
+        self,
+        image: PIL.Image.Image,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+    ) -> PIL.Image.Image:
         """
         Resize a PIL image. Both height and width will be downscaled to the next integer multiple of `vae_scale_factor`
         """
-        w, h = images.size
-        w, h = (x - x % self.config.vae_scale_factor for x in (w, h))  # resize to integer multiple of vae_scale_factor
-        images = images.resize((w, h), resample=PIL_INTERPOLATION[self.config.resample])
-        return images
+        if height is None:
+            height = image.height
+        if width is None:
+            width = image.width
+
+        width, height = (
+            x - x % self.config.vae_scale_factor for x in (width, height)
+        )  # resize to integer multiple of vae_scale_factor
+        image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
+        return image

     def preprocess(
         self,
         image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
+        height: Optional[int] = None,
+        width: Optional[int] = None,
     ) -> torch.Tensor:
         """
         Preprocess the image input, accepted formats are PIL images, numpy arrays or pytorch tensors
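The reworked `resize` rounds the requested (or native) size down to the nearest multiple of `vae_scale_factor`. A quick sketch of the rounding (illustrative, not part of the diff):

from PIL import Image
from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(vae_scale_factor=8)
image = Image.new("RGB", (513, 769))

print(processor.resize(image).size)                         # (512, 768): native size rounded down
print(processor.resize(image, height=300, width=300).size)  # (296, 296): 300 - 300 % 8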
@@ -126,10 +163,11 @@ class VaeImageProcessor(ConfigMixin):
             )

         if isinstance(image[0], PIL.Image.Image):
+            if self.config.do_convert_rgb:
+                image = [self.convert_to_rgb(i) for i in image]
             if self.config.do_resize:
-                image = [self.resize(i) for i in image]
-            image = [np.array(i).astype(np.float32) / 255.0 for i in image]
-            image = np.stack(image, axis=0)  # to np
+                image = [self.resize(i, height, width) for i in image]
+            image = self.pil_to_numpy(image)  # to np
             image = self.numpy_to_pt(image)  # to pt

         elif isinstance(image[0], np.ndarray):

@@ -146,7 +184,12 @@ class VaeImageProcessor(ConfigMixin):

         elif isinstance(image[0], torch.Tensor):
             image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
-            _, _, height, width = image.shape
+            _, channel, height, width = image.shape
+
+            # don't need any preprocess if the image is latents
+            if channel == 4:
+                return image
+
             if self.config.do_resize and (
                 height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
             ):
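The `channel == 4` early return above is the latents pass-through. Sketch (not part of the diff):

import torch
from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(vae_scale_factor=8)

latents = torch.randn(2, 4, 64, 64)  # 4 channels: treated as latents
pixels = torch.randn(2, 3, 64, 64)   # 3 channels: a regular image batch

assert torch.equal(processor.preprocess(latents), latents)  # returned untouched
out = processor.preprocess(pixels)  # still goes through the usual checks/normalization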
@@ -69,6 +69,11 @@ EXAMPLE_DOC_STRING = """

 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
 def preprocess(image):
+    warnings.warn(
+        "The preprocess method is deprecated and will be removed in a future version. Please"
+        " use VaeImageProcessor.preprocess instead",
+        FutureWarning,
+    )
     if isinstance(image, torch.Tensor):
         return image
     elif isinstance(image, PIL.Image.Image):

@@ -538,21 +543,26 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
         image = image.to(device=device, dtype=dtype)

         batch_size = batch_size * num_images_per_prompt
-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )
-
-        if isinstance(generator, list):
-            init_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
-            ]
-            init_latents = torch.cat(init_latents, dim=0)
-        else:
-            init_latents = self.vae.encode(image).latent_dist.sample(generator)
-
-        init_latents = self.vae.config.scaling_factor * init_latents
+
+        if image.shape[1] == 4:
+            init_latents = image
+
+        else:
+            if isinstance(generator, list) and len(generator) != batch_size:
+                raise ValueError(
+                    f"You have passed a list of generators of length {len(generator)}, but requested an effective"
+                    f" batch size of {batch_size}. Make sure the batch size matches the length of the generators."
+                )
+
+            elif isinstance(generator, list):
+                init_latents = [
+                    self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+                ]
+                init_latents = torch.cat(init_latents, dim=0)
+            else:
+                init_latents = self.vae.encode(image).latent_dist.sample(generator)
+
+            init_latents = self.vae.config.scaling_factor * init_latents

         if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
             # expand init_latents for batch_size

@@ -586,7 +596,14 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+        image: Union[
+            torch.FloatTensor,
+            PIL.Image.Image,
+            np.ndarray,
+            List[torch.FloatTensor],
+            List[PIL.Image.Image],
+            List[np.ndarray],
+        ] = None,
         strength: float = 0.8,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,

@@ -609,9 +626,10 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
-            image (`torch.FloatTensor` or `PIL.Image.Image`):
+            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                 `Image`, or tensor representing an image batch, that will be used as the starting point for the
-                process.
+                process. Can also accept image latents as `image`; if latents are passed directly, they will not be
+                encoded again.
             strength (`float`, *optional*, defaults to 0.8):
                 Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
                 will be used as a starting point, adding more noise to it the larger the `strength`. The number of
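Combined with the `prepare_latents` change above, latents can now be fed straight back into an img2img call. A hedged usage sketch (the checkpoint id is a placeholder; any of the img2img pipelines touched by this PR works the same way):

import torch
from diffusers import StableDiffusionImg2ImgPipeline

pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Latents from an earlier run, shape [B, 4, H // 8, W // 8]. Because
# image.shape[1] == 4, prepare_latents skips vae.encode() entirely.
latents = torch.randn(1, 4, 64, 64)
result = pipe(prompt="a fantasy landscape", image=latents, strength=0.75)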
@@ -29,7 +29,6 @@ from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
-    PIL_INTERPOLATION,
     is_accelerate_available,
     is_accelerate_version,
     is_compiled_module,

@@ -172,7 +171,10 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
             feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
+        self.control_image_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+        )
         self.register_to_config(requires_safety_checker=requires_safety_checker)

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing

@@ -477,17 +479,12 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
         self,
         prompt,
         image,
-        height,
-        width,
         callback_steps,
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
         controlnet_conditioning_scale=1.0,
     ):
-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
-
         if (callback_steps is None) or (
             callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
         ):

@@ -592,21 +589,26 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
     def check_image(self, image, prompt, prompt_embeds):
         image_is_pil = isinstance(image, PIL.Image.Image)
         image_is_tensor = isinstance(image, torch.Tensor)
+        image_is_np = isinstance(image, np.ndarray)
         image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
         image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+        image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)

-        if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:
+        if (
+            not image_is_pil
+            and not image_is_tensor
+            and not image_is_np
+            and not image_is_pil_list
+            and not image_is_tensor_list
+            and not image_is_np_list
+        ):
             raise TypeError(
-                "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
+                "image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors"
             )

         if image_is_pil:
             image_batch_size = 1
         elif image_is_tensor:
             image_batch_size = image.shape[0]
-        elif image_is_pil_list:
-            image_batch_size = len(image)
-        elif image_is_tensor_list:
+        else:
             image_batch_size = len(image)

         if prompt is not None and isinstance(prompt, str):

@@ -633,29 +635,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
         do_classifier_free_guidance=False,
         guess_mode=False,
     ):
-        if not isinstance(image, torch.Tensor):
-            if isinstance(image, PIL.Image.Image):
-                image = [image]
-
-            if isinstance(image[0], PIL.Image.Image):
-                images = []
-
-                for image_ in image:
-                    image_ = image_.convert("RGB")
-                    image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
-                    image_ = np.array(image_)
-                    image_ = image_[None, :]
-                    images.append(image_)
-
-                image = images
-
-            image = np.concatenate(image, axis=0)
-            image = np.array(image).astype(np.float32) / 255.0
-            image = image.transpose(0, 3, 1, 2)
-            image = torch.from_numpy(image)
-        elif isinstance(image[0], torch.Tensor):
-            image = torch.cat(image, dim=0)
-
+        image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
         image_batch_size = image.shape[0]

         if image_batch_size == 1:
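`prepare_image` now delegates the manual PIL/numpy handling to the `control_image_processor` configured in `__init__` (RGB conversion on, `[-1, 1]` normalization off, since ControlNet conditioning stays in `[0, 1]`). A standalone sketch of what the one-liner does (not part of the diff):

from PIL import Image
from diffusers.image_processor import VaeImageProcessor

control_image_processor = VaeImageProcessor(
    vae_scale_factor=8, do_convert_rgb=True, do_normalize=False
)
control_pil = Image.new("RGB", (512, 512))
# float32 batch in [0, 1], resized to the requested size:
cond = control_image_processor.preprocess(control_pil, height=512, width=512)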
@@ -691,31 +671,6 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
             latents = latents * self.scheduler.init_noise_sigma
         return latents

-    def _default_height_width(self, height, width, image):
-        # NOTE: It is possible that a list of images have different
-        # dimensions for each image, so just checking the first image
-        # is not _exactly_ correct, but it is simple.
-        while isinstance(image, list):
-            image = image[0]
-
-        if height is None:
-            if isinstance(image, PIL.Image.Image):
-                height = image.height
-            elif isinstance(image, torch.Tensor):
-                height = image.shape[2]
-
-            height = (height // 8) * 8  # round down to nearest multiple of 8
-
-        if width is None:
-            if isinstance(image, PIL.Image.Image):
-                width = image.width
-            elif isinstance(image, torch.Tensor):
-                width = image.shape[3]
-
-            width = (width // 8) * 8  # round down to nearest multiple of 8
-
-        return height, width
-
     # override DiffusionPipeline
     def save_pretrained(
         self,

@@ -733,7 +688,14 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None,
+        image: Union[
+            torch.FloatTensor,
+            PIL.Image.Image,
+            np.ndarray,
+            List[torch.FloatTensor],
+            List[PIL.Image.Image],
+            List[np.ndarray],
+        ] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 50,

@@ -760,8 +722,8 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
-            image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
-                `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
+            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
+                `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
                 The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
                 the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
                 also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If

@@ -837,15 +799,11 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
             list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
             (nsfw) content, according to the `safety_checker`.
         """
-        # 0. Default height and width to unet
-        height, width = self._default_height_width(height, width, image)
-
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt,
             image,
-            height,
-            width,
             callback_steps,
             negative_prompt,
             prompt_embeds,

@@ -903,6 +861,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
                 do_classifier_free_guidance=do_classifier_free_guidance,
                 guess_mode=guess_mode,
             )
+            height, width = image.shape[-2:]
         elif isinstance(controlnet, MultiControlNetModel):
             images = []

@@ -922,6 +881,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
                 images.append(image_)

             image = images
+            height, width = image[0].shape[-2:]
         else:
             assert False
@@ -29,7 +29,6 @@ from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
-    PIL_INTERPOLATION,
     deprecate,
     is_accelerate_available,
     is_accelerate_version,

@@ -198,7 +197,10 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
             feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
+        self.control_image_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+        )
         self.register_to_config(requires_safety_checker=requires_safety_checker)

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing

@@ -503,17 +505,12 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
         self,
         prompt,
         image,
-        height,
-        width,
         callback_steps,
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
         controlnet_conditioning_scale=1.0,
     ):
-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
-
         if (callback_steps is None) or (
             callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
         ):

@@ -615,24 +612,30 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
         else:
             assert False

     # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
     def check_image(self, image, prompt, prompt_embeds):
         image_is_pil = isinstance(image, PIL.Image.Image)
         image_is_tensor = isinstance(image, torch.Tensor)
+        image_is_np = isinstance(image, np.ndarray)
         image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
         image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+        image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)

-        if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:
+        if (
+            not image_is_pil
+            and not image_is_tensor
+            and not image_is_np
+            and not image_is_pil_list
+            and not image_is_tensor_list
+            and not image_is_np_list
+        ):
             raise TypeError(
-                "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
+                "image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors"
             )

         if image_is_pil:
             image_batch_size = 1
         elif image_is_tensor:
             image_batch_size = image.shape[0]
-        elif image_is_pil_list:
-            image_batch_size = len(image)
-        elif image_is_tensor_list:
+        else:
             image_batch_size = len(image)

         if prompt is not None and isinstance(prompt, str):

@@ -660,29 +663,7 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
         do_classifier_free_guidance=False,
         guess_mode=False,
     ):
-        if not isinstance(image, torch.Tensor):
-            if isinstance(image, PIL.Image.Image):
-                image = [image]
-
-            if isinstance(image[0], PIL.Image.Image):
-                images = []
-
-                for image_ in image:
-                    image_ = image_.convert("RGB")
-                    image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
-                    image_ = np.array(image_)
-                    image_ = image_[None, :]
-                    images.append(image_)
-
-                image = images
-
-            image = np.concatenate(image, axis=0)
-            image = np.array(image).astype(np.float32) / 255.0
-            image = image.transpose(0, 3, 1, 2)
-            image = torch.from_numpy(image)
-        elif isinstance(image[0], torch.Tensor):
-            image = torch.cat(image, dim=0)
-
+        image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
         image_batch_size = image.shape[0]

         if image_batch_size == 1:

@@ -720,21 +701,26 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
         image = image.to(device=device, dtype=dtype)

         batch_size = batch_size * num_images_per_prompt
-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )
-
-        if isinstance(generator, list):
-            init_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
-            ]
-            init_latents = torch.cat(init_latents, dim=0)
-        else:
-            init_latents = self.vae.encode(image).latent_dist.sample(generator)
-
-        init_latents = self.vae.config.scaling_factor * init_latents
+
+        if image.shape[1] == 4:
+            init_latents = image
+
+        else:
+            if isinstance(generator, list) and len(generator) != batch_size:
+                raise ValueError(
+                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+                )
+
+            elif isinstance(generator, list):
+                init_latents = [
+                    self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+                ]
+                init_latents = torch.cat(init_latents, dim=0)
+            else:
+                init_latents = self.vae.encode(image).latent_dist.sample(generator)
+
+            init_latents = self.vae.config.scaling_factor * init_latents

         if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
             # expand init_latents for batch_size

@@ -763,31 +749,6 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi

         return latents

-    def _default_height_width(self, height, width, image):
-        # NOTE: It is possible that a list of images have different
-        # dimensions for each image, so just checking the first image
-        # is not _exactly_ correct, but it is simple.
-        while isinstance(image, list):
-            image = image[0]
-
-        if height is None:
-            if isinstance(image, PIL.Image.Image):
-                height = image.height
-            elif isinstance(image, torch.Tensor):
-                height = image.shape[2]
-
-            height = (height // 8) * 8  # round down to nearest multiple of 8
-
-        if width is None:
-            if isinstance(image, PIL.Image.Image):
-                width = image.width
-            elif isinstance(image, torch.Tensor):
-                width = image.shape[3]
-
-            width = (width // 8) * 8  # round down to nearest multiple of 8
-
-        return height, width
-
     # override DiffusionPipeline
     def save_pretrained(
         self,

@@ -805,9 +766,21 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None,
+        image: Union[
+            torch.FloatTensor,
+            PIL.Image.Image,
+            np.ndarray,
+            List[torch.FloatTensor],
+            List[PIL.Image.Image],
+            List[np.ndarray],
+        ] = None,
         control_image: Union[
-            torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]
+            torch.FloatTensor,
+            PIL.Image.Image,
+            np.ndarray,
+            List[torch.FloatTensor],
+            List[PIL.Image.Image],
+            List[np.ndarray],
         ] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,

@@ -836,8 +809,12 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
-            image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
-                `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
+            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
+                `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+                The initial image will be used as the starting point for the image generation process. Can also accept
+                image latents as `image`; if latents are passed directly, they will not be encoded again.
+            control_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
+                `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
                 The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
                 the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
                 also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If

@@ -914,15 +891,10 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
             list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
             (nsfw) content, according to the `safety_checker`.
         """
-        # 0. Default height and width to unet
-        height, width = self._default_height_width(height, width, image)
-
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt,
             control_image,
-            height,
-            width,
             callback_steps,
             negative_prompt,
             prompt_embeds,

@@ -966,10 +938,10 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
         )
-        # 4. Prepare image, and controlnet_conditioning_image
-        image = prepare_image(image)
+        # 4. Prepare image
+        image = self.image_processor.preprocess(image).to(dtype=torch.float32)

-        # 5. Prepare image
+        # 5. Prepare controlnet_conditioning_image
         if isinstance(controlnet, ControlNetModel):
             control_image = self.prepare_control_image(
                 image=control_image,
@@ -30,7 +30,6 @@ from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
-    PIL_INTERPOLATION,
     is_accelerate_available,
     is_accelerate_version,
     is_compiled_module,

@@ -316,6 +315,9 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.control_image_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+        )
         self.register_to_config(requires_safety_checker=requires_safety_checker)

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing

@@ -742,24 +744,30 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
         else:
             assert False

     # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
     def check_image(self, image, prompt, prompt_embeds):
         image_is_pil = isinstance(image, PIL.Image.Image)
         image_is_tensor = isinstance(image, torch.Tensor)
+        image_is_np = isinstance(image, np.ndarray)
         image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
         image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+        image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)

-        if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:
+        if (
+            not image_is_pil
+            and not image_is_tensor
+            and not image_is_np
+            and not image_is_pil_list
+            and not image_is_tensor_list
+            and not image_is_np_list
+        ):
             raise TypeError(
-                "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
+                "image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors"
             )

         if image_is_pil:
             image_batch_size = 1
         elif image_is_tensor:
             image_batch_size = image.shape[0]
-        elif image_is_pil_list:
-            image_batch_size = len(image)
-        elif image_is_tensor_list:
+        else:
             image_batch_size = len(image)

         if prompt is not None and isinstance(prompt, str):

@@ -787,29 +795,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
         do_classifier_free_guidance=False,
         guess_mode=False,
     ):
-        if not isinstance(image, torch.Tensor):
-            if isinstance(image, PIL.Image.Image):
-                image = [image]
-
-            if isinstance(image[0], PIL.Image.Image):
-                images = []
-
-                for image_ in image:
-                    image_ = image_.convert("RGB")
-                    image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
-                    image_ = np.array(image_)
-                    image_ = image_[None, :]
-                    images.append(image_)
-
-                image = images
-
-            image = np.concatenate(image, axis=0)
-            image = np.array(image).astype(np.float32) / 255.0
-            image = image.transpose(0, 3, 1, 2)
-            image = torch.from_numpy(image)
-        elif isinstance(image[0], torch.Tensor):
-            image = torch.cat(image, dim=0)
-
+        image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
         image_batch_size = image.shape[0]

         if image_batch_size == 1:

@@ -983,7 +969,12 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
         image: Union[torch.Tensor, PIL.Image.Image] = None,
         mask_image: Union[torch.Tensor, PIL.Image.Image] = None,
         control_image: Union[
-            torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]
+            torch.FloatTensor,
+            PIL.Image.Image,
+            np.ndarray,
+            List[torch.FloatTensor],
+            List[PIL.Image.Image],
+            List[np.ndarray],
         ] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
@@ -13,6 +13,7 @@
 # limitations under the License.


+import warnings
 from typing import List, Optional, Tuple, Union

 import numpy as np

@@ -30,6 +31,11 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
 def _preprocess_image(image: Union[List, PIL.Image.Image, torch.Tensor]):
+    warnings.warn(
+        "The preprocess method is deprecated and will be removed in a future version. Please"
+        " use VaeImageProcessor.preprocess instead",
+        FutureWarning,
+    )
     if isinstance(image, torch.Tensor):
         return image
     elif isinstance(image, PIL.Image.Image):

@@ -40,6 +40,11 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
 def preprocess(image):
+    warnings.warn(
+        "The preprocess method is deprecated and will be removed in a future version. Please"
+        " use VaeImageProcessor.preprocess instead",
+        FutureWarning,
+    )
     if isinstance(image, torch.Tensor):
         return image
     elif isinstance(image, PIL.Image.Image):

@@ -549,21 +554,26 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         image = image.to(device=device, dtype=dtype)

         batch_size = image.shape[0]
-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )
-
-        if isinstance(generator, list):
-            init_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
-            ]
-            init_latents = torch.cat(init_latents, dim=0)
-        else:
-            init_latents = self.vae.encode(image).latent_dist.sample(generator)
-
-        init_latents = self.vae.config.scaling_factor * init_latents
+
+        if image.shape[1] == 4:
+            init_latents = image
+
+        else:
+            if isinstance(generator, list) and len(generator) != batch_size:
+                raise ValueError(
+                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+                )
+
+            if isinstance(generator, list):
+                init_latents = [
+                    self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+                ]
+                init_latents = torch.cat(init_latents, dim=0)
+            else:
+                init_latents = self.vae.encode(image).latent_dist.sample(generator)
+
+            init_latents = self.vae.config.scaling_factor * init_latents

         if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
             # expand init_latents for batch_size

@@ -599,7 +609,14 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         self,
         prompt: Union[str, List[str]],
         source_prompt: Union[str, List[str]],
-        image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+        image: Union[
+            torch.FloatTensor,
+            PIL.Image.Image,
+            np.ndarray,
+            List[torch.FloatTensor],
+            List[PIL.Image.Image],
+            List[np.ndarray],
+        ] = None,
         strength: float = 0.8,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,

@@ -619,9 +636,10 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         Args:
             prompt (`str` or `List[str]`):
                 The prompt or prompts to guide the image generation.
-            image (`torch.FloatTensor` or `PIL.Image.Image`):
+            image (`torch.FloatTensor`, `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                 `Image`, or tensor representing an image batch, that will be used as the starting point for the
-                process.
+                process. Can also accept image latents as `image`; if latents are passed directly, they will not be
+                encoded again.
             strength (`float`, *optional*, defaults to 0.8):
                 Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
                 will be used as a starting point, adding more noise to it the larger the `strength`. The number of

@@ -699,7 +717,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         )

         # 4. Preprocess image
-        image = preprocess(image)
+        image = self.image_processor.preprocess(image)

         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)

@@ -13,6 +13,7 @@
 # limitations under the License.

 import inspect
+import warnings
 from typing import Callable, List, Optional, Union

 import numpy as np

@@ -33,6 +34,13 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess with 8->64
 def preprocess(image):
+    warnings.warn(
+        (
+            "The preprocess method is deprecated and will be removed in a future version. Please"
+            " use VaeImageProcessor.preprocess instead"
+        ),
+        FutureWarning,
+    )
     if isinstance(image, torch.Tensor):
         return image
     elif isinstance(image, PIL.Image.Image):

@@ -37,6 +37,11 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
 def preprocess(image):
+    warnings.warn(
+        "The preprocess method is deprecated and will be removed in a future version. Please"
+        " use VaeImageProcessor.preprocess instead",
+        FutureWarning,
+    )
     if isinstance(image, torch.Tensor):
         return image
     elif isinstance(image, PIL.Image.Image):

@@ -423,21 +428,26 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
         image = image.to(device=device, dtype=dtype)

         batch_size = batch_size * num_images_per_prompt
-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )
-
-        if isinstance(generator, list):
-            init_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
-            ]
-            init_latents = torch.cat(init_latents, dim=0)
-        else:
-            init_latents = self.vae.encode(image).latent_dist.sample(generator)
-
-        init_latents = self.vae.config.scaling_factor * init_latents
+
+        if image.shape[1] == 4:
+            init_latents = image
+
+        else:
+            if isinstance(generator, list) and len(generator) != batch_size:
+                raise ValueError(
+                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+                )
+
+            elif isinstance(generator, list):
+                init_latents = [
+                    self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+                ]
+                init_latents = torch.cat(init_latents, dim=0)
+            else:
+                init_latents = self.vae.encode(image).latent_dist.sample(generator)
+
+            init_latents = self.vae.config.scaling_factor * init_latents

         if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
             # expand init_latents for batch_size

@@ -474,6 +484,8 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader

         if isinstance(image[0], PIL.Image.Image):
             width, height = image[0].size
+        elif isinstance(image[0], np.ndarray):
+            width, height = image[0].shape[:-1]
         else:
             height, width = image[0].shape[-2:]

@@ -512,7 +524,14 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+        image: Union[
+            torch.FloatTensor,
+            PIL.Image.Image,
+            np.ndarray,
+            List[torch.FloatTensor],
+            List[PIL.Image.Image],
+            List[np.ndarray],
+        ] = None,
         depth_map: Optional[torch.FloatTensor] = None,
         strength: float = 0.8,
         num_inference_steps: Optional[int] = 50,

@@ -535,9 +554,12 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
-            image (`torch.FloatTensor` or `PIL.Image.Image`):
+            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                 `Image`, or tensor representing an image batch, that will be used as the starting point for the
-                process.
+                process. Can accept image latents as `image` only if `depth_map` is not `None`.
             depth_map (`torch.FloatTensor`, *optional*):
                 depth prediction that will be used as additional conditioning for the image generation process. If not
                 defined, it will automatically predicts the depth via `self.depth_estimator`.
             strength (`float`, *optional*, defaults to 0.8):
                 Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
                 will be used as a starting point, adding more noise to it the larger the `strength`. The number of

@@ -664,7 +686,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
         )

         # 5. Preprocess image
-        image = preprocess(image)
+        image = self.image_processor.preprocess(image)

         # 6. Set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)

@@ -159,6 +159,11 @@ def kl_divergence(hidden_states):

 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
 def preprocess(image):
+    warnings.warn(
+        "The preprocess method is deprecated and will be removed in a future version. Please"
+        " use VaeImageProcessor.preprocess instead",
+        FutureWarning,
+    )
     if isinstance(image, torch.Tensor):
         return image
     elif isinstance(image, PIL.Image.Image):

@@ -799,19 +804,25 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM

         image = image.to(device=device, dtype=dtype)

-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )
-
-        if isinstance(generator, list):
-            latents = [self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)]
-            latents = torch.cat(latents, dim=0)
-        else:
-            latents = self.vae.encode(image).latent_dist.sample(generator)
-
-        latents = self.vae.config.scaling_factor * latents
+        if image.shape[1] == 4:
+            latents = image
+
+        else:
+            if isinstance(generator, list) and len(generator) != batch_size:
+                raise ValueError(
+                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+                )
+
+            if isinstance(generator, list):
+                latents = [
+                    self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+                ]
+                latents = torch.cat(latents, dim=0)
+            else:
+                latents = self.vae.encode(image).latent_dist.sample(generator)
+
+            latents = self.vae.config.scaling_factor * latents

         if batch_size != latents.shape[0]:
             if batch_size % latents.shape[0] == 0:

@@ -73,6 +73,11 @@ EXAMPLE_DOC_STRING = """


 def preprocess(image):
+    warnings.warn(
+        "The preprocess method is deprecated and will be removed in a future version. Please"
+        " use VaeImageProcessor.preprocess instead",
+        FutureWarning,
+    )
     if isinstance(image, torch.Tensor):
         return image
     elif isinstance(image, PIL.Image.Image):

@@ -441,6 +446,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi

         return prompt_embeds

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
         if self.safety_checker is None:
             has_nsfw_concept = None

@@ -455,6 +461,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi
         )
         return image, has_nsfw_concept

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         warnings.warn(
             "The decode_latents method is deprecated and will be removed in a future version. Please"

@@ -544,21 +551,26 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi
         image = image.to(device=device, dtype=dtype)

         batch_size = batch_size * num_images_per_prompt
-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )
-
-        if isinstance(generator, list):
-            init_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
-            ]
-            init_latents = torch.cat(init_latents, dim=0)
-        else:
-            init_latents = self.vae.encode(image).latent_dist.sample(generator)
-
-        init_latents = self.vae.config.scaling_factor * init_latents
+
+        if image.shape[1] == 4:
+            init_latents = image
+
+        else:
+            if isinstance(generator, list) and len(generator) != batch_size:
+                raise ValueError(
+                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+                )
+
+            elif isinstance(generator, list):
+                init_latents = [
+                    self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+                ]
+                init_latents = torch.cat(init_latents, dim=0)
+            else:
+                init_latents = self.vae.encode(image).latent_dist.sample(generator)
+
+            init_latents = self.vae.config.scaling_factor * init_latents

         if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
             # expand init_latents for batch_size

@@ -592,7 +604,14 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+        image: Union[
+            torch.FloatTensor,
+            PIL.Image.Image,
+            np.ndarray,
+            List[torch.FloatTensor],
+            List[PIL.Image.Image],
+            List[np.ndarray],
+        ] = None,
         strength: float = 0.8,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,

@@ -615,9 +634,10 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
-            image (`torch.FloatTensor` or `PIL.Image.Image`):
+            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                 `Image`, or tensor representing an image batch, that will be used as the starting point for the
-                process.
+                process. Can also accept image latents as `image`; if latents are passed directly, they will not be
+                encoded again.
             strength (`float`, *optional*, defaults to 0.8):
                 Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
                 will be used as a starting point, adding more noise to it the larger the `strength`. The number of

@@ -43,6 +43,11 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
 def preprocess(image):
+    warnings.warn(
+        "The preprocess method is deprecated and will be removed in a future version. Please"
+        " use VaeImageProcessor.preprocess instead",
+        FutureWarning,
+    )
     if isinstance(image, torch.Tensor):
         return image
     elif isinstance(image, PIL.Image.Image):

@@ -145,7 +150,14 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
    def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+        image: Union[
+            torch.FloatTensor,
+            PIL.Image.Image,
+            np.ndarray,
+            List[torch.FloatTensor],
+            List[PIL.Image.Image],
+            List[np.ndarray],
+        ] = None,
         num_inference_steps: int = 100,
         guidance_scale: float = 7.5,
         image_guidance_scale: float = 1.5,

@@ -168,8 +180,9 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
-            image (`PIL.Image.Image`):
-                `Image`, or tensor representing an image batch which will be repainted according to `prompt`.
+            image (`torch.FloatTensor`, `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+                `Image`, or tensor representing an image batch which will be repainted according to `prompt`. Can also
+                accept image latents as `image`; if latents are passed directly, they will not be encoded again.
             num_inference_steps (`int`, *optional*, defaults to 100):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.

@@ -290,8 +303,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
         )

         # 3. Preprocess image
-        image = preprocess(image)
-        height, width = image.shape[-2:]
+        image = self.image_processor.preprocess(image)

         # 4. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)

@@ -308,6 +320,10 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
             generator,
         )

+        height, width = image_latents.shape[-2:]
+        height = height * self.vae_scale_factor
+        width = width * self.vae_scale_factor
+
         # 6. Prepare latent variables
         num_channels_latents = self.vae.config.latent_channels
         latents = self.prepare_latents(
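Because `preprocess` may now return latents untouched, the pipeline recovers the output resolution from the latent shape rather than from the preprocessed image, as the hunk above shows. The arithmetic, as a sketch:

import torch

vae_scale_factor = 8  # 2 ** (len(block_out_channels) - 1) for the standard SD VAE
image_latents = torch.randn(1, 4, 64, 96)

height, width = image_latents.shape[-2:]
height = height * vae_scale_factor  # 512
width = width * vae_scale_factor    # 768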
@@ -746,17 +762,21 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
|
||||
batch_size = batch_size * num_images_per_prompt
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if isinstance(generator, list):
|
||||
image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)]
|
||||
image_latents = torch.cat(image_latents, dim=0)
|
||||
if image.shape[1] == 4:
|
||||
image_latents = image
|
||||
else:
|
||||
image_latents = self.vae.encode(image).latent_dist.mode()
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if isinstance(generator, list):
|
||||
image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)]
|
||||
image_latents = torch.cat(image_latents, dim=0)
|
||||
else:
|
||||
image_latents = self.vae.encode(image).latent_dist.mode()
|
||||
|
||||
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
||||
# expand image_latents for batch_size
|
||||
|
||||
@@ -94,7 +94,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline):
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic")
|
||||
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
@@ -291,7 +291,14 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline):
    def __call__(
        self,
        prompt: Union[str, List[str]],
        image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]],
        image: Union[
            torch.FloatTensor,
            PIL.Image.Image,
            np.ndarray,
            List[torch.FloatTensor],
            List[PIL.Image.Image],
            List[np.ndarray],
        ] = None,
        num_inference_steps: int = 75,
        guidance_scale: float = 9.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -308,7 +315,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline):
        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image upscaling.
            image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`):
            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                `Image`, or tensor representing an image batch which will be upscaled. If it's a tensor, it can be
                either a latent output from a stable diffusion model, or an image tensor in the range `[-1, 1]`. It
                will be considered a `latent` if `image.shape[1]` is `4`; otherwise, it will be considered to be an
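A sketch of the chaining this docstring describes, feeding a text-to-image pipeline's latent output straight into the upscaler (the checkpoint ids are illustrative assumptions):

import torch
from diffusers import StableDiffusionPipeline, StableDiffusionLatentUpscalePipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")
upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained("stabilityai/sd-x2-latent-upscaler").to("cuda")

prompt = "a photo of an astronaut riding a horse"
low_res_latents = pipe(prompt, output_type="latent").images  # shape[1] == 4, treated as latents
image = upscaler(prompt=prompt, image=low_res_latents, num_inference_steps=20).images[0]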
@@ -413,7 +420,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline):
        )

        # 4. Preprocess image
        image = preprocess(image)
        image = self.image_processor.preprocess(image)
        image = image.to(dtype=text_embeddings.dtype, device=device)
        if image.shape[1] == 3:
            # encode image if not in latent-space yet
@@ -177,6 +177,11 @@ EXAMPLE_INVERT_DOC_STRING = """

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
def preprocess(image):
    warnings.warn(
        "The preprocess method is deprecated and will be removed in a future version. Please"
        " use VaeImageProcessor.preprocess instead",
        FutureWarning,
    )
    if isinstance(image, torch.Tensor):
        return image
    elif isinstance(image, PIL.Image.Image):
@@ -629,7 +634,6 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
    def check_inputs(
        self,
        prompt,
        image,
        source_embeds,
        target_embeds,
        callback_steps,
@@ -727,19 +731,25 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):

        image = image.to(device=device, dtype=dtype)

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )
        if image.shape[1] == 4:
            latents = image

        if isinstance(generator, list):
            latents = [self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)]
            latents = torch.cat(latents, dim=0)
        else:
            latents = self.vae.encode(image).latent_dist.sample(generator)
            if isinstance(generator, list) and len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )

        latents = self.vae.config.scaling_factor * latents
            if isinstance(generator, list):
                latents = [
                    self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
                ]
                latents = torch.cat(latents, dim=0)
            else:
                latents = self.vae.encode(image).latent_dist.sample(generator)

            latents = self.vae.config.scaling_factor * latents

        if batch_size != latents.shape[0]:
            if batch_size % latents.shape[0] == 0:
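Unlike InstructPix2Pix above, this pipeline samples from the latent distribution and applies the VAE scaling factor, using one generator per sample when a list is passed. A self-contained sketch with a deliberately tiny, randomly initialized VAE (the configuration is chosen only to keep the example light):

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL(
    in_channels=3, out_channels=3, latent_channels=4, layers_per_block=1,
    block_out_channels=(32,), down_block_types=("DownEncoderBlock2D",),
    up_block_types=("UpDecoderBlock2D",),
)
image = torch.randn(2, 3, 32, 32)  # stand-in for a preprocessed batch in [-1, 1]
generators = [torch.Generator().manual_seed(s) for s in range(2)]  # one per sample
latents = torch.cat(
    [vae.encode(image[i : i + 1]).latent_dist.sample(generators[i]) for i in range(2)], dim=0
)
latents = vae.config.scaling_factor * latents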
@@ -804,7 +814,6 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
        source_embeds: torch.Tensor = None,
        target_embeds: torch.Tensor = None,
        height: Optional[int] = None,
@@ -905,7 +914,6 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            image,
            source_embeds,
            target_embeds,
            callback_steps,
@@ -1085,7 +1093,14 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
    def invert(
        self,
        prompt: Optional[str] = None,
        image: Union[torch.FloatTensor, PIL.Image.Image] = None,
        image: Union[
            torch.FloatTensor,
            PIL.Image.Image,
            np.ndarray,
            List[torch.FloatTensor],
            List[PIL.Image.Image],
            List[np.ndarray],
        ] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -1109,8 +1124,9 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            image (`PIL.Image.Image`, *optional*):
                `Image`, or tensor representing an image batch which will be used for conditioning.
            image (`torch.FloatTensor`, `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                `Image`, or tensor representing an image batch which will be used for conditioning. Can also accept
                image latents as `image`; if latents are passed directly, they will not be encoded again.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
@@ -1179,7 +1195,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Preprocess image
        image = preprocess(image)
        image = self.image_processor.preprocess(image)

        # 4. Prepare latent variables
        latents = self.prepare_image_latents(image, batch_size, self.vae.dtype, device, generator)
@@ -1267,16 +1283,13 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
        inverted_latents = latents.detach().clone()

        # 8. Post-processing
        image = self.decode_latents(latents.detach())
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        # 9. Convert to PIL.
        if output_type == "pil":
            image = self.image_processor.numpy_to_pil(image)

        if not return_dict:
            return (inverted_latents, image)
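The replacement collapses decode, denormalization, and PIL conversion into a single `postprocess` call. A self-contained sketch of the new path (again with a tiny random VAE for illustration only):

import torch
from diffusers import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor

vae = AutoencoderKL(
    in_channels=3, out_channels=3, latent_channels=4, layers_per_block=1,
    block_out_channels=(32,), down_block_types=("DownEncoderBlock2D",),
    up_block_types=("UpDecoderBlock2D",),
)
image_processor = VaeImageProcessor(vae_scale_factor=2 ** (len(vae.config.block_out_channels) - 1))

latents = torch.randn(1, 4, 32, 32)
with torch.no_grad():
    image = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]
images = image_processor.postprocess(image, output_type="pil")  # list of PIL images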
@@ -1,4 +1,5 @@
import inspect
import warnings
from dataclasses import dataclass
from typing import Callable, List, Optional, Union

@@ -34,6 +35,11 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
def preprocess(image):
    warnings.warn(
        "The preprocess method is deprecated and will be removed in a future version. Please"
        " use VaeImageProcessor.preprocess instead",
        FutureWarning,
    )
    if isinstance(image, torch.Tensor):
        return image
    elif isinstance(image, PIL.Image.Image):
@@ -40,6 +40,7 @@ class AltDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMix
    params = TEXT_TO_IMAGE_PARAMS
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS

    def get_dummy_components(self):
        torch.manual_seed(0)
@@ -41,7 +41,9 @@ from diffusers.utils.testing_utils import (
)

from ..pipeline_params import (
    IMAGE_TO_IMAGE_IMAGE_PARAMS,
    TEXT_TO_IMAGE_BATCH_PARAMS,
    TEXT_TO_IMAGE_IMAGE_PARAMS,
    TEXT_TO_IMAGE_PARAMS,
)
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
@@ -99,7 +101,8 @@ class ControlNetPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin
    pipeline_class = StableDiffusionControlNetPipeline
    params = TEXT_TO_IMAGE_PARAMS
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = frozenset([])  # TO_DO: add image_params once refactored VaeImageProcessor.preprocess
    image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS

    def get_dummy_components(self):
        torch.manual_seed(0)
@@ -38,6 +38,7 @@ from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu

from ..pipeline_params import (
    IMAGE_TO_IMAGE_IMAGE_PARAMS,
    TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
    TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
)
@@ -51,7 +52,8 @@ class ControlNetImg2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTest
    pipeline_class = StableDiffusionControlNetImg2ImgPipeline
    params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
    batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
    image_params = frozenset([])  # TO_DO: add image_params once refactored VaeImageProcessor.preprocess
    image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS.union({"control_image"})
    image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS

    def get_dummy_components(self):
        torch.manual_seed(0)
@@ -40,6 +40,7 @@ from diffusers.utils.testing_utils import enable_full_determinism, require_torch
from ..pipeline_params import (
    TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
    TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
    TEXT_TO_IMAGE_IMAGE_PARAMS,
)
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin

@@ -51,7 +52,8 @@ class ControlNetInpaintPipelineFastTests(PipelineLatentTesterMixin, PipelineTest
    pipeline_class = StableDiffusionControlNetInpaintPipeline
    params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
    batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
    image_params = frozenset([])
    image_params = frozenset({"control_image"})  # skip `image` and `mask` for now, only test for control_image
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS

    def get_dummy_components(self):
        torch.manual_seed(0)
@@ -25,7 +25,11 @@ from diffusers import AutoencoderKL, CycleDiffusionPipeline, DDIMScheduler, UNet
from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, skip_mps

from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..pipeline_params import (
    IMAGE_TO_IMAGE_IMAGE_PARAMS,
    TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
    TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
)
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin


@@ -42,7 +46,8 @@ class CycleDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterM
    }
    required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
    batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS.union({"source_prompt"})
    image_params = frozenset([])  # TO_DO: add image_params once refactored VaeImageProcessor.preprocess
    image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS

    def get_dummy_components(self):
        torch.manual_seed(0)
@@ -101,6 +106,7 @@ class CycleDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterM

    def get_dummy_inputs(self, device, seed=0):
        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
        image = image / 2 + 0.5
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
@@ -93,6 +93,7 @@ class StableDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTester
    params = TEXT_TO_IMAGE_PARAMS
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS

    def get_dummy_components(self):
        torch.manual_seed(0)

@@ -47,6 +47,7 @@ class StableDiffusionImageVariationPipelineFastTests(
    batch_params = IMAGE_VARIATION_BATCH_PARAMS
    image_params = frozenset([])
    # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
    image_latents_params = frozenset([])

    def get_dummy_components(self):
        torch.manual_seed(0)
@@ -32,7 +32,6 @@ from diffusers import (
    StableDiffusionImg2ImgPipeline,
    UNet2DConditionModel,
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
from diffusers.utils.testing_utils import (
    enable_full_determinism,
@@ -91,6 +90,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipelin
    required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
    batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
    image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS

    def get_dummy_components(self):
        torch.manual_seed(0)
@@ -142,6 +142,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipelin

    def get_dummy_inputs(self, device, seed=0):
        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
        image = image / 2 + 0.5
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
@@ -160,12 +161,10 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipelin
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionImg2ImgPipeline(**components)
        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=True)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        inputs["image"] = inputs["image"] / 2 + 0.5
        image = sd_pipe(**inputs).images
        image_slice = image[0, -3:, -3:, -1]

@@ -178,12 +177,10 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipelin
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionImg2ImgPipeline(**components)
        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=True)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        inputs["image"] = inputs["image"] / 2 + 0.5
        negative_prompt = "french fries"
        output = sd_pipe(**inputs, negative_prompt=negative_prompt)
        image = output.images
@@ -198,14 +195,12 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipelin
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionImg2ImgPipeline(**components)
        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=True)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        inputs["prompt"] = [inputs["prompt"]] * 2
        inputs["image"] = inputs["image"].repeat(2, 1, 1, 1)
        inputs["image"] = inputs["image"] / 2 + 0.5
        image = sd_pipe(**inputs).images
        image_slice = image[-1, -3:, -3:, -1]

@@ -221,12 +216,10 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipelin
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
        )
        sd_pipe = StableDiffusionImg2ImgPipeline(**components)
        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=True)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        inputs["image"] = inputs["image"] / 2 + 0.5
        image = sd_pipe(**inputs).images
        image_slice = image[0, -3:, -3:, -1]
@@ -88,6 +88,7 @@ class StableDiffusionInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipelin
    batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
    image_params = frozenset([])
    # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
    image_latents_params = frozenset([])

    def get_dummy_components(self):
        torch.manual_seed(0)
@@ -31,10 +31,15 @@ from diffusers import (
    StableDiffusionInstructPix2PixPipeline,
    UNet2DConditionModel,
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import floats_tensor, load_image, slow, torch_device
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu

from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..pipeline_params import (
    IMAGE_TO_IMAGE_IMAGE_PARAMS,
    TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
    TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
)
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
@@ -47,9 +52,8 @@ class StableDiffusionInstructPix2PixPipelineFastTests(
    pipeline_class = StableDiffusionInstructPix2PixPipeline
    params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width", "cross_attention_kwargs"}
    batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
    image_params = frozenset(
        []
    )  # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
    image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS

    def get_dummy_components(self):
        torch.manual_seed(0)
@@ -163,6 +167,7 @@ class StableDiffusionInstructPix2PixPipelineFastTests(

        image = np.array(inputs["image"]).astype(np.float32) / 255.0
        image = torch.from_numpy(image).unsqueeze(0).to(device)
        image = image / 2 + 0.5
        image = image.permute(0, 3, 1, 2)
        inputs["image"] = image.repeat(2, 1, 1, 1)
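The conversion above is the usual HWC-to-NCHW dance; in isolation:

import numpy as np
import torch

array = np.zeros((32, 32, 3), dtype=np.uint8)  # HWC, as PIL would hand it over
image = torch.from_numpy(array.astype(np.float32) / 255.0)
image = image.unsqueeze(0).permute(0, 3, 1, 2)  # -> NCHW batch of one
print(image.shape)  # torch.Size([1, 3, 32, 32])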
@@ -199,6 +204,28 @@ class StableDiffusionInstructPix2PixPipelineFastTests(
    def test_inference_batch_single_identical(self):
        super().test_inference_batch_single_identical(expected_max_diff=3e-3)

    # Overwrite the default test_latents_input because pix2pix encodes the image differently
    def test_latents_input(self):
        components = self.get_dummy_components()
        pipe = StableDiffusionInstructPix2PixPipeline(**components)
        pipe.image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        out = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="pt"))[0]

        vae = components["vae"]
        inputs = self.get_dummy_inputs_by_type(torch_device, input_image_type="pt")

        for image_param in self.image_latents_params:
            if image_param in inputs.keys():
                inputs[image_param] = vae.encode(inputs[image_param]).latent_dist.mode()

        out_latents_inputs = pipe(**inputs)[0]

        max_diff = np.abs(out - out_latents_inputs).max()
        self.assertLess(max_diff, 1e-4, "passing latents as image input generates a different result from passing image")
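The override is needed because InstructPix2Pix encodes image latents with the distribution mode and no scaling factor, whereas the default test in the common mixin samples and scales. The contrast in miniature (tiny random VAE, illustration only):

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL(
    in_channels=3, out_channels=3, latent_channels=4, layers_per_block=1,
    block_out_channels=(32,), down_block_types=("DownEncoderBlock2D",),
    up_block_types=("UpDecoderBlock2D",),
)
dist = vae.encode(torch.randn(1, 3, 32, 32)).latent_dist
pix2pix_latents = dist.mode()                                # deterministic, unscaled
default_latents = dist.sample() * vae.config.scaling_factor  # what the common test uses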
@slow
@require_torch_gpu

@@ -44,6 +44,7 @@ class StableDiffusionModelEditingPipelineFastTests(PipelineLatentTesterMixin, Pi
    params = TEXT_TO_IMAGE_PARAMS
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS

    def get_dummy_components(self):
        torch.manual_seed(0)
@@ -45,6 +45,7 @@ class StableDiffusionPanoramaPipelineFastTests(PipelineLatentTesterMixin, Pipeli
    params = TEXT_TO_IMAGE_PARAMS
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS

    def get_dummy_components(self):
        torch.manual_seed(0)
@@ -32,11 +32,16 @@ from diffusers import (
    StableDiffusionPix2PixZeroPipeline,
    UNet2DConditionModel,
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import floats_tensor, load_numpy, slow, torch_device
from diffusers.utils.testing_utils import enable_full_determinism, load_image, load_pt, require_torch_gpu, skip_mps

from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
from ..pipeline_params import (
    TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
    TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
    TEXT_TO_IMAGE_IMAGE_PARAMS,
)
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin, assert_mean_pixel_difference


enable_full_determinism()

@@ -45,11 +50,10 @@ enable_full_determinism()
@skip_mps
class StableDiffusionPix2PixZeroPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
    pipeline_class = StableDiffusionPix2PixZeroPipeline
    params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
    params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"image"}
    batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
    image_params = frozenset(
        []
    )  # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS

    @classmethod
    def setUpClass(cls):
@@ -130,6 +134,7 @@ class StableDiffusionPix2PixZeroPipelineFastTests(PipelineLatentTesterMixin, Pip

    def get_dummy_inversion_inputs(self, device, seed=0):
        dummy_image = floats_tensor((2, 3, 32, 32), rng=random.Random(seed)).to(torch_device)
        dummy_image = dummy_image / 2 + 0.5
        generator = torch.manual_seed(seed)

        inputs = {
@@ -145,6 +150,24 @@ class StableDiffusionPix2PixZeroPipelineFastTests(PipelineLatentTesterMixin, Pip
        }
        return inputs
    def get_dummy_inversion_inputs_by_type(self, device, seed=0, input_image_type="pt", output_type="np"):
        inputs = self.get_dummy_inversion_inputs(device, seed)

        if input_image_type == "pt":
            image = inputs["image"]
        elif input_image_type == "np":
            image = VaeImageProcessor.pt_to_numpy(inputs["image"])
        elif input_image_type == "pil":
            image = VaeImageProcessor.pt_to_numpy(inputs["image"])
            image = VaeImageProcessor.numpy_to_pil(image)
        else:
            raise ValueError(f"unsupported input_image_type {input_image_type}")

        inputs["image"] = image
        inputs["output_type"] = output_type

        return inputs
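The helper leans on the processor's static converters; the round trip in isolation:

import torch
from diffusers.image_processor import VaeImageProcessor

image_pt = torch.rand(1, 3, 32, 32)                   # NCHW in [0, 1]
image_np = VaeImageProcessor.pt_to_numpy(image_pt)    # NHWC float32
image_pil = VaeImageProcessor.numpy_to_pil(image_np)  # list of PIL images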
    def test_save_load_optional_components(self):
        if not hasattr(self.pipeline_class, "_optional_components"):
            return
@@ -281,6 +304,41 @@ class StableDiffusionPix2PixZeroPipelineFastTests(PipelineLatentTesterMixin, Pip

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
    def test_stable_diffusion_pix2pix_zero_inversion_pt_np_pil_outputs_equivalent(self):
        device = torch_device
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionPix2PixZeroPipeline(**components)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        output_pt = sd_pipe.invert(**self.get_dummy_inversion_inputs_by_type(device, output_type="pt")).images
        output_np = sd_pipe.invert(**self.get_dummy_inversion_inputs_by_type(device, output_type="np")).images
        output_pil = sd_pipe.invert(**self.get_dummy_inversion_inputs_by_type(device, output_type="pil")).images

        max_diff = np.abs(output_pt.cpu().numpy().transpose(0, 2, 3, 1) - output_np).max()
        self.assertLess(max_diff, 1e-4, "`output_type=='pt'` generates different results from `output_type=='np'`")

        max_diff = np.abs(np.array(output_pil[0]) - (output_np[0] * 255).round()).max()
        self.assertLess(max_diff, 2.0, "`output_type=='pil'` generates different results from `output_type=='np'`")

    def test_stable_diffusion_pix2pix_zero_inversion_pt_np_pil_inputs_equivalent(self):
        device = torch_device
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionPix2PixZeroPipeline(**components)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        out_input_pt = sd_pipe.invert(**self.get_dummy_inversion_inputs_by_type(device, input_image_type="pt")).images
        out_input_np = sd_pipe.invert(**self.get_dummy_inversion_inputs_by_type(device, input_image_type="np")).images
        out_input_pil = sd_pipe.invert(
            **self.get_dummy_inversion_inputs_by_type(device, input_image_type="pil")
        ).images

        max_diff = np.abs(out_input_pt - out_input_np).max()
        self.assertLess(max_diff, 1e-4, "`input_type=='pt'` generates different results from `input_type=='np'`")

        assert_mean_pixel_difference(out_input_pil, out_input_np, expected_max_diff=1)

    # Non-determinism caused by the scheduler optimizing the latent inputs during inference
    @unittest.skip("non-deterministic pipeline")
    def test_inference_batch_single_identical(self):
@@ -41,6 +41,7 @@ class StableDiffusionSAGPipelineFastTests(PipelineLatentTesterMixin, PipelineTes
    params = TEXT_TO_IMAGE_PARAMS
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    test_cpu_offload = False

    def get_dummy_components(self):

@@ -47,6 +47,7 @@ class StableDiffusion2PipelineFastTests(PipelineLatentTesterMixin, PipelineTeste
    params = TEXT_TO_IMAGE_PARAMS
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS

    def get_dummy_components(self):
        torch.manual_seed(0)

@@ -45,6 +45,7 @@ class StableDiffusionAttendAndExcitePipelineFastTests(
    params = TEXT_TO_IMAGE_PARAMS
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"token_indices"})
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS

    # Attend and excite requires being able to run a backward pass at
    # inference time. There's no deterministic backward operator for pad
@@ -51,7 +51,12 @@ from diffusers.utils import (
)
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, skip_mps

from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..pipeline_params import (
    IMAGE_TO_IMAGE_IMAGE_PARAMS,
    TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
    TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
    TEXT_TO_IMAGE_IMAGE_PARAMS,
)
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin


@@ -65,9 +70,8 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
    params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
    required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
    batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
    image_params = frozenset(
        []
    )  # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
    image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS

    def get_dummy_components(self):
        torch.manual_seed(0)
@@ -49,6 +49,7 @@ class StableDiffusionDiffEditPipelineFastTests(PipelineLatentTesterMixin, Pipeli
    image_params = frozenset(
        []
    )  # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
    image_latents_params = frozenset([])

    def get_dummy_components(self):
        torch.manual_seed(0)

@@ -40,6 +40,7 @@ class StableDiffusion2InpaintPipelineFastTests(PipelineLatentTesterMixin, Pipeli
    image_params = frozenset(
        []
    )  # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
    image_latents_params = frozenset([])

    def get_dummy_components(self):
        torch.manual_seed(0)

@@ -52,6 +52,7 @@ class StableDiffusionLatentUpscalePipelineFastTests(PipelineLatentTesterMixin, P
    image_params = frozenset(
        []
    )  # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
    image_latents_params = frozenset([])

    test_cpu_offload = True
@@ -27,6 +27,7 @@ class StableUnCLIPPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMix
    params = TEXT_TO_IMAGE_PARAMS
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS

    # TODO(will) Expected attn_bias.stride(1) == 0 to be true, but got false
    test_xformers_attention = False

@@ -46,6 +46,7 @@ class StableUnCLIPImg2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTe
    image_params = frozenset(
        []
    )  # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
    image_latents_params = frozenset([])

    def get_dummy_components(self):
        embedder_hidden_size = 32
@@ -8,6 +8,7 @@ import unittest
from typing import Callable, Union

import numpy as np
import PIL
import torch

import diffusers
@@ -39,9 +40,28 @@ class PipelineLatentTesterMixin:
            "`image_params` are tested to check that all accepted input image types (i.e. `pt`, `pil`, `np`) produce the same results"
        )
    @property
    def image_latents_params(self) -> frozenset:
        raise NotImplementedError(
            "You need to set the attribute `image_latents_params` in the child test class. "
            "`image_latents_params` are tested to check that passing latents directly produces the same results"
        )

    def get_dummy_inputs_by_type(self, device, seed=0, input_image_type="pt", output_type="np"):
        inputs = self.get_dummy_inputs(device, seed)

        def convert_to_pt(image):
            if isinstance(image, torch.Tensor):
                input_image = image
            elif isinstance(image, np.ndarray):
                input_image = VaeImageProcessor.numpy_to_pt(image)
            elif isinstance(image, PIL.Image.Image):
                input_image = VaeImageProcessor.pil_to_numpy(image)
                input_image = VaeImageProcessor.numpy_to_pt(input_image)
            else:
                raise ValueError(f"unsupported input_image_type {type(image)}")
            return input_image

        def convert_pt_to_type(image, input_image_type):
            if input_image_type == "pt":
                input_image = image
@@ -56,21 +76,32 @@ class PipelineLatentTesterMixin:

        for image_param in self.image_params:
            if image_param in inputs.keys():
                inputs[image_param] = convert_pt_to_type(inputs[image_param], input_image_type)
                inputs[image_param] = convert_pt_to_type(
                    convert_to_pt(inputs[image_param]).to(device), input_image_type
                )

        inputs["output_type"] = output_type

        return inputs

    def test_pt_np_pil_outputs_equivalent(self, expected_max_diff=1e-4):
        self._test_pt_np_pil_outputs_equivalent(expected_max_diff=expected_max_diff)

    def _test_pt_np_pil_outputs_equivalent(self, expected_max_diff=1e-4, input_image_type="pt"):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        output_pt = pipe(**self.get_dummy_inputs_by_type(torch_device, output_type="pt"))[0]
        output_np = pipe(**self.get_dummy_inputs_by_type(torch_device, output_type="np"))[0]
        output_pil = pipe(**self.get_dummy_inputs_by_type(torch_device, output_type="pil"))[0]
        output_pt = pipe(
            **self.get_dummy_inputs_by_type(torch_device, input_image_type=input_image_type, output_type="pt")
        )[0]
        output_np = pipe(
            **self.get_dummy_inputs_by_type(torch_device, input_image_type=input_image_type, output_type="np")
        )[0]
        output_pil = pipe(
            **self.get_dummy_inputs_by_type(torch_device, input_image_type=input_image_type, output_type="pil")
        )[0]

        max_diff = np.abs(output_pt.cpu().numpy().transpose(0, 2, 3, 1) - output_np).max()
        self.assertLess(
@@ -98,6 +129,31 @@ class PipelineLatentTesterMixin:
        max_diff = np.abs(out_input_pil - out_input_np).max()
        self.assertLess(max_diff, 1e-2, "`input_type=='pil'` generates different results from `input_type=='np'`")
    def test_latents_input(self):
        if len(self.image_latents_params) == 0:
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        out = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="pt"))[0]

        vae = components["vae"]
        inputs = self.get_dummy_inputs_by_type(torch_device, input_image_type="pt")
        generator = inputs["generator"]
        for image_param in self.image_latents_params:
            if image_param in inputs.keys():
                inputs[image_param] = (
                    vae.encode(inputs[image_param]).latent_dist.sample(generator) * vae.config.scaling_factor
                )
        out_latents_inputs = pipe(**inputs)[0]

        max_diff = np.abs(out - out_latents_inputs).max()
        self.assertLess(max_diff, 1e-4, "passing latents as image input generates a different result from passing image")
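Note that the test swaps in a pass-through processor (`do_resize=False, do_normalize=False`) so that latent inputs reach the pipeline untouched, and that the reference encoding is the usual sample-and-scale. A small sketch of the pass-through behavior:

import torch
from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(do_resize=False, do_normalize=False)
latents = torch.randn(1, 4, 32, 32)
processed = processor.preprocess(latents)  # returned unchanged: no resize, no normalization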
@require_torch
class PipelineTesterMixin: