mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Merge branch 'main' of https://github.com/huggingface/diffusers into main
This commit is contained in:
@@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# Stable Diffusion text-to-image fine-tuning
|
||||
|
||||
The [`train_text_to_image.py`](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion) script shows how to fine-tune the stable diffusion model on your own dataset.
|
||||
The [`train_text_to_image.py`](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image) script shows how to fine-tune the stable diffusion model on your own dataset.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
|
||||
@@ -557,6 +557,9 @@ class CrossAttention(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
def _memory_efficient_attention_xformers(self, query, key, value):
|
||||
query = query.contiguous()
|
||||
key = key.contiguous()
|
||||
value = value.contiguous()
|
||||
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
|
||||
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
@@ -19,6 +19,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
import PIL
|
||||
from diffusers.utils import is_accelerate_available
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
@@ -178,6 +179,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
|
||||
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
@@ -197,14 +199,33 @@ class CycleDiffusionPipeline(DiffusionPipeline):
|
||||
slice_size = self.unet.config.attention_head_dim // 2
|
||||
self.unet.set_attention_slice(slice_size)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
|
||||
def disable_attention_slicing(self):
|
||||
r"""
|
||||
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
|
||||
back to computing attention in one step.
|
||||
"""
|
||||
# set slice_size = `None` to disable `set_attention_slice`
|
||||
# set slice_size = `None` to disable `attention slicing`
|
||||
self.enable_attention_slicing(None)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
|
||||
def enable_sequential_cpu_offload(self):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
||||
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
|
||||
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
|
||||
"""
|
||||
if is_accelerate_available():
|
||||
from accelerate import cpu_offload
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
device = torch.device("cuda")
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
|
||||
if cpu_offloaded_model is not None:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
|
||||
def _execution_device(self):
|
||||
@@ -224,6 +245,26 @@ class CycleDiffusionPipeline(DiffusionPipeline):
|
||||
return torch.device(module._hf_hook.execution_device)
|
||||
return self.device
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
|
||||
def enable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Enable memory efficient attention as implemented in xformers.
|
||||
|
||||
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
|
||||
time. Speed up at training time is not guaranteed.
|
||||
|
||||
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
|
||||
is used.
|
||||
"""
|
||||
self.unet.set_use_memory_efficient_attention_xformers(True)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention
|
||||
def disable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Disable memory efficient attention as implemented in xformers.
|
||||
"""
|
||||
self.unet.set_use_memory_efficient_attention_xformers(False)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
||||
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
|
||||
r"""
|
||||
@@ -310,6 +351,106 @@ class CycleDiffusionPipeline(DiffusionPipeline):
|
||||
|
||||
return text_embeddings
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
|
||||
def check_inputs(self, prompt, strength, callback_steps):
|
||||
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should in [1.0, 1.0] but is {strength}")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is not None:
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
||||
)
|
||||
else:
|
||||
has_nsfw_concept = None
|
||||
return image, has_nsfw_concept
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
|
||||
def decode_latents(self, latents):
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
return image
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
|
||||
def get_timesteps(self, num_inference_steps, strength, device):
|
||||
# get the original timestep using init_timestep
|
||||
offset = self.scheduler.config.get("steps_offset", 0)
|
||||
init_timestep = int(num_inference_steps * strength) + offset
|
||||
init_timestep = min(init_timestep, num_inference_steps)
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep + offset, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start:]
|
||||
|
||||
return timesteps
|
||||
|
||||
def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
|
||||
init_image = init_image.to(device=device, dtype=dtype)
|
||||
init_latent_dist = self.vae.encode(init_image).latent_dist
|
||||
init_latents = init_latent_dist.sample(generator=generator)
|
||||
init_latents = 0.18215 * init_latents
|
||||
|
||||
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
|
||||
# expand init_latents for batch_size
|
||||
deprecation_message = (
|
||||
f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
|
||||
" images (`init_image`). Initial images are now duplicating to match the number of text prompts. Note"
|
||||
" that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
|
||||
" your script to pass as many init images as text prompts to suppress this warning."
|
||||
)
|
||||
deprecate("len(prompt) != len(init_image)", "1.0.0", deprecation_message, standard_warn=False)
|
||||
additional_image_per_prompt = batch_size // init_latents.shape[0]
|
||||
init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
|
||||
elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
|
||||
)
|
||||
else:
|
||||
init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
|
||||
|
||||
# add noise to latents using the timestep
|
||||
noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
# get latents
|
||||
clean_latents = init_latents
|
||||
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
||||
latents = init_latents
|
||||
|
||||
return latents, clean_latents
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
@@ -384,112 +525,43 @@ class CycleDiffusionPipeline(DiffusionPipeline):
|
||||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
if isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if batch_size != 1:
|
||||
raise ValueError(
|
||||
"At the moment only `batch_size=1` is supported for prompts, but you seem to have passed multiple"
|
||||
f" prompts: {prompt}. Please make sure to pass only a single prompt."
|
||||
)
|
||||
|
||||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
# set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
if isinstance(init_image, PIL.Image.Image):
|
||||
init_image = preprocess(init_image)
|
||||
# 1. Check inputs
|
||||
self.check_inputs(prompt, strength, callback_steps)
|
||||
|
||||
# 2. Define call parameters
|
||||
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
||||
device = self._execution_device
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_embeddings = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance, None)
|
||||
source_text_embeddings = self._encode_prompt(
|
||||
source_prompt, device, num_images_per_prompt, do_classifier_free_guidance, None
|
||||
)
|
||||
|
||||
# encode the init image into latents and scale the latents
|
||||
latents_dtype = text_embeddings.dtype
|
||||
init_image = init_image.to(device=self.device, dtype=latents_dtype)
|
||||
init_latent_dist = self.vae.encode(init_image).latent_dist
|
||||
init_latents = init_latent_dist.sample(generator=generator)
|
||||
init_latents = 0.18215 * init_latents
|
||||
# 4. Preprocess image
|
||||
if isinstance(init_image, PIL.Image.Image):
|
||||
init_image = preprocess(init_image)
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
if len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] == 0:
|
||||
# expand init_latents for batch_size
|
||||
deprecation_message = (
|
||||
f"You have passed {len(prompt)} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
|
||||
" images (`init_image`). Initial images are now duplicating to match the number of text prompts. Note"
|
||||
" that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
|
||||
" your script to pass as many init images as text prompts to suppress this warning."
|
||||
)
|
||||
deprecate("len(prompt) != len(init_image)", "1.0.0", deprecation_message, standard_warn=False)
|
||||
additional_image_per_prompt = len(prompt) // init_latents.shape[0]
|
||||
init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
|
||||
elif len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {len(prompt)} text prompts."
|
||||
)
|
||||
else:
|
||||
init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
|
||||
# 5. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.get_timesteps(num_inference_steps, strength, device)
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
||||
|
||||
# get the original timestep using init_timestep
|
||||
offset = self.scheduler.config.get("steps_offset", 0)
|
||||
init_timestep = int(num_inference_steps * strength) + offset
|
||||
init_timestep = min(init_timestep, num_inference_steps)
|
||||
# 6. Prepare latent variables
|
||||
latents, clean_latents = self.prepare_latents(
|
||||
init_image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator
|
||||
)
|
||||
source_latents = latents
|
||||
|
||||
timesteps = self.scheduler.timesteps[-init_timestep]
|
||||
timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
|
||||
|
||||
# add noise to latents using the timesteps
|
||||
noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
|
||||
clean_latents = init_latents
|
||||
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
|
||||
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
|
||||
if not (accepts_eta and (0 < eta <= 1)):
|
||||
raise ValueError(
|
||||
"Currently, only the DDIM scheduler is supported. Please make sure that `pipeline.scheduler` is of"
|
||||
f" type {DDIMScheduler.__class__} and not {self.scheduler.__class__}."
|
||||
)
|
||||
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
latents = init_latents
|
||||
source_latents = init_latents
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep + offset, 0)
|
||||
|
||||
# Some schedulers like PNDM have timesteps as arrays
|
||||
# It's more optimized to move all timesteps to correct device beforehand
|
||||
timesteps = self.scheduler.timesteps[t_start:].to(self.device)
|
||||
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
generator = extra_step_kwargs.pop("generator", None)
|
||||
|
||||
# 8. Denoising loop
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
@@ -551,22 +623,13 @@ class CycleDiffusionPipeline(DiffusionPipeline):
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
# 9. Post-processing
|
||||
image = self.decode_latents(latents)
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||
|
||||
if self.safety_checker is not None:
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
|
||||
self.device
|
||||
)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
|
||||
)
|
||||
else:
|
||||
has_nsfw_concept = None
|
||||
# 10. Run safety checker
|
||||
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
|
||||
|
||||
# 11. Convert to PIL
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
|
||||
@@ -298,6 +298,73 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
|
||||
return text_embeddings
|
||||
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is not None:
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
||||
)
|
||||
else:
|
||||
has_nsfw_concept = None
|
||||
return image, has_nsfw_concept
|
||||
|
||||
def decode_latents(self, latents):
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
return image
|
||||
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_inputs(self, prompt, height, width, callback_steps):
|
||||
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // 8, width // 8)
|
||||
if latents is None:
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
||||
else:
|
||||
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
@@ -371,75 +438,45 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
if isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(prompt, height, width, callback_steps)
|
||||
|
||||
# 2. Define call parameters
|
||||
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
||||
device = self._execution_device
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_embeddings = self._encode_prompt(
|
||||
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
|
||||
)
|
||||
|
||||
# Unlike in other pipelines, latents need to be generated in the target device
|
||||
# for 1-to-1 results reproducibility with the CompVis implementation.
|
||||
# However this currently doesn't work in `mps`.
|
||||
|
||||
# get the initial random noise unless the user supplied it
|
||||
latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
|
||||
latents_dtype = text_embeddings.dtype
|
||||
if latents is None:
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(device)
|
||||
else:
|
||||
latents = torch.randn(latents_shape, generator=generator, device=device, dtype=latents_dtype)
|
||||
else:
|
||||
if latents.shape != latents_shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
||||
latents = latents.to(device)
|
||||
|
||||
# set timesteps and move to the correct device
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps_tensor = self.scheduler.timesteps
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.unet.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
text_embeddings.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
|
||||
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
|
||||
# 7. Denoising loop
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
@@ -459,22 +496,13 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
# 8. Post-processing
|
||||
image = self.decode_latents(latents)
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
if self.safety_checker is not None:
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
|
||||
)
|
||||
else:
|
||||
has_nsfw_concept = None
|
||||
# 9. Run safety checker
|
||||
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
|
||||
|
||||
# 10. Convert to PIL
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import (
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
@@ -78,6 +79,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
@@ -85,7 +87,12 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Union[
|
||||
DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
],
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
@@ -139,6 +146,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
|
||||
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
@@ -158,14 +166,16 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
slice_size = self.unet.config.attention_head_dim // 2
|
||||
self.unet.set_attention_slice(slice_size)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
|
||||
def disable_attention_slicing(self):
|
||||
r"""
|
||||
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
|
||||
back to computing attention in one step.
|
||||
"""
|
||||
# set slice_size = `None` to disable `set_attention_slice`
|
||||
# set slice_size = `None` to disable `attention slicing`
|
||||
self.enable_attention_slicing(None)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
|
||||
def enable_sequential_cpu_offload(self):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
||||
@@ -202,6 +212,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
return torch.device(module._hf_hook.execution_device)
|
||||
return self.device
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
|
||||
def enable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Enable memory efficient attention as implemented in xformers.
|
||||
@@ -214,6 +225,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
"""
|
||||
self.unet.set_use_memory_efficient_attention_xformers(True)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention
|
||||
def disable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Disable memory efficient attention as implemented in xformers.
|
||||
@@ -306,6 +318,103 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
|
||||
return text_embeddings
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is not None:
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
||||
)
|
||||
else:
|
||||
has_nsfw_concept = None
|
||||
return image, has_nsfw_concept
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
|
||||
def decode_latents(self, latents):
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
return image
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_inputs(self, prompt, strength, callback_steps):
|
||||
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should in [1.0, 1.0] but is {strength}")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
def get_timesteps(self, num_inference_steps, strength, device):
|
||||
# get the original timestep using init_timestep
|
||||
offset = self.scheduler.config.get("steps_offset", 0)
|
||||
init_timestep = int(num_inference_steps * strength) + offset
|
||||
init_timestep = min(init_timestep, num_inference_steps)
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep + offset, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start:]
|
||||
|
||||
return timesteps
|
||||
|
||||
def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
|
||||
init_image = init_image.to(device=device, dtype=dtype)
|
||||
init_latent_dist = self.vae.encode(init_image).latent_dist
|
||||
init_latents = init_latent_dist.sample(generator=generator)
|
||||
init_latents = 0.18215 * init_latents
|
||||
|
||||
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
|
||||
# expand init_latents for batch_size
|
||||
deprecation_message = (
|
||||
f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
|
||||
" images (`init_image`). Initial images are now duplicating to match the number of text prompts. Note"
|
||||
" that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
|
||||
" your script to pass as many init images as text prompts to suppress this warning."
|
||||
)
|
||||
deprecate("len(prompt) != len(init_image)", "1.0.0", deprecation_message, standard_warn=False)
|
||||
additional_image_per_prompt = batch_size // init_latents.shape[0]
|
||||
init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
|
||||
elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
|
||||
)
|
||||
else:
|
||||
init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
|
||||
|
||||
# add noise to latents using the timesteps
|
||||
noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
# get latents
|
||||
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
||||
latents = init_latents
|
||||
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
@@ -379,102 +488,40 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
if isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
# 1. Check inputs
|
||||
self.check_inputs(prompt, strength, callback_steps)
|
||||
|
||||
# 2. Define call parameters
|
||||
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
||||
device = self._execution_device
|
||||
|
||||
# set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
if isinstance(init_image, PIL.Image.Image):
|
||||
init_image = preprocess(init_image)
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_embeddings = self._encode_prompt(
|
||||
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
|
||||
)
|
||||
|
||||
# encode the init image into latents and scale the latents
|
||||
latents_dtype = text_embeddings.dtype
|
||||
init_image = init_image.to(device=device, dtype=latents_dtype)
|
||||
init_latent_dist = self.vae.encode(init_image).latent_dist
|
||||
init_latents = init_latent_dist.sample(generator=generator)
|
||||
init_latents = 0.18215 * init_latents
|
||||
# 4. Preprocess image
|
||||
if isinstance(init_image, PIL.Image.Image):
|
||||
init_image = preprocess(init_image)
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
if len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] == 0:
|
||||
# expand init_latents for batch_size
|
||||
deprecation_message = (
|
||||
f"You have passed {len(prompt)} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
|
||||
" images (`init_image`). Initial images are now duplicating to match the number of text prompts. Note"
|
||||
" that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
|
||||
" your script to pass as many init images as text prompts to suppress this warning."
|
||||
)
|
||||
deprecate("len(prompt) != len(init_image)", "1.0.0", deprecation_message, standard_warn=False)
|
||||
additional_image_per_prompt = len(prompt) // init_latents.shape[0]
|
||||
init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
|
||||
elif len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {len(prompt)} text prompts."
|
||||
)
|
||||
else:
|
||||
init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
|
||||
# 5. set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.get_timesteps(num_inference_steps, strength, device)
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
||||
|
||||
# get the original timestep using init_timestep
|
||||
offset = self.scheduler.config.get("steps_offset", 0)
|
||||
init_timestep = int(num_inference_steps * strength) + offset
|
||||
init_timestep = min(init_timestep, num_inference_steps)
|
||||
# 6. Prepare latent variables
|
||||
latents = self.prepare_latents(
|
||||
init_image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator
|
||||
)
|
||||
|
||||
timesteps = self.scheduler.timesteps[-init_timestep]
|
||||
timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=device)
|
||||
|
||||
# add noise to latents using the timesteps
|
||||
noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=latents_dtype)
|
||||
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
|
||||
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
|
||||
latents = init_latents
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep + offset, 0)
|
||||
|
||||
# Some schedulers like PNDM have timesteps as arrays
|
||||
# It's more optimized to move all timesteps to correct device beforehand
|
||||
timesteps = self.scheduler.timesteps[t_start:].to(device)
|
||||
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 8. Denoising loop
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
@@ -495,20 +542,13 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
# 9. Post-processing
|
||||
image = self.decode_latents(latents)
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||
|
||||
if self.safety_checker is not None:
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
|
||||
)
|
||||
else:
|
||||
has_nsfw_concept = None
|
||||
# 10. Run safety checker
|
||||
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
|
||||
|
||||
# 11. Convert to PIL
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
|
||||
@@ -139,6 +139,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
|
||||
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
@@ -158,6 +159,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
slice_size = self.unet.config.attention_head_dim // 2
|
||||
self.unet.set_attention_slice(slice_size)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
|
||||
def disable_attention_slicing(self):
|
||||
r"""
|
||||
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
|
||||
@@ -166,6 +168,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
# set slice_size = `None` to disable `attention slicing`
|
||||
self.enable_attention_slicing(None)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
|
||||
def enable_sequential_cpu_offload(self):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
||||
@@ -183,6 +186,26 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
if cpu_offloaded_model is not None:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
|
||||
def enable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Enable memory efficient attention as implemented in xformers.
|
||||
|
||||
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
|
||||
time. Speed up at training time is not guaranteed.
|
||||
|
||||
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
|
||||
is used.
|
||||
"""
|
||||
self.unet.set_use_memory_efficient_attention_xformers(True)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention
|
||||
def disable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Disable memory efficient attention as implemented in xformers.
|
||||
"""
|
||||
self.unet.set_use_memory_efficient_attention_xformers(False)
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
|
||||
def _execution_device(self):
|
||||
@@ -202,24 +225,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
return torch.device(module._hf_hook.execution_device)
|
||||
return self.device
|
||||
|
||||
def enable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Enable memory efficient attention as implemented in xformers.
|
||||
|
||||
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
|
||||
time. Speed up at training time is not guaranteed.
|
||||
|
||||
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
|
||||
is used.
|
||||
"""
|
||||
self.unet.set_use_memory_efficient_attention_xformers(True)
|
||||
|
||||
def disable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Disable memory efficient attention as implemented in xformers.
|
||||
"""
|
||||
self.unet.set_use_memory_efficient_attention_xformers(False)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
||||
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
|
||||
r"""
|
||||
@@ -306,6 +311,106 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
|
||||
return text_embeddings
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is not None:
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
||||
)
|
||||
else:
|
||||
has_nsfw_concept = None
|
||||
return image, has_nsfw_concept
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
|
||||
def decode_latents(self, latents):
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
return image
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
|
||||
def check_inputs(self, prompt, height, width, callback_steps):
|
||||
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // 8, width // 8)
|
||||
if latents is None:
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
||||
else:
|
||||
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def prepare_mask_latents(
|
||||
self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
|
||||
):
|
||||
# resize the mask to latents shape as we concatenate the mask to the latents
|
||||
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
||||
# and half precision
|
||||
mask = torch.nn.functional.interpolate(mask, size=(height // 8, width // 8))
|
||||
mask = mask.to(device=device, dtype=dtype)
|
||||
|
||||
masked_image = masked_image.to(device=device, dtype=dtype)
|
||||
|
||||
# encode the mask image into latents space so we can concatenate it to the latents
|
||||
masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
|
||||
masked_image_latents = 0.18215 * masked_image_latents
|
||||
|
||||
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
|
||||
mask = mask.repeat(batch_size, 1, 1, 1)
|
||||
masked_image_latents = masked_image_latents.repeat(batch_size, 1, 1, 1)
|
||||
|
||||
mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
|
||||
masked_image_latents = (
|
||||
torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
|
||||
)
|
||||
|
||||
# aligning device to prevent device errors when concating it with the latent model input
|
||||
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
|
||||
return mask, masked_image_latents
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
@@ -390,83 +495,59 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
|
||||
if isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
# 1. Check inputs
|
||||
self.check_inputs(prompt, height, width, callback_steps)
|
||||
|
||||
# 2. Define call parameters
|
||||
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
||||
device = self._execution_device
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_embeddings = self._encode_prompt(
|
||||
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
|
||||
)
|
||||
|
||||
# get the initial random noise unless the user supplied it
|
||||
# Unlike in other pipelines, latents need to be generated in the target device
|
||||
# for 1-to-1 results reproducibility with the CompVis implementation.
|
||||
# However this currently doesn't work in `mps`.
|
||||
# 4. Preprocess mask and image
|
||||
if isinstance(image, PIL.Image.Image) and isinstance(mask_image, PIL.Image.Image):
|
||||
mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
|
||||
|
||||
# 5. set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps_tensor = self.scheduler.timesteps
|
||||
|
||||
# 6. Prepare latent variables
|
||||
num_channels_latents = self.vae.config.latent_channels
|
||||
latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
|
||||
latents_dtype = text_embeddings.dtype
|
||||
if latents is None:
|
||||
if device.type == "mps":
|
||||
# randn does not exist on mps
|
||||
latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(device)
|
||||
else:
|
||||
latents = torch.randn(latents_shape, generator=generator, device=device, dtype=latents_dtype)
|
||||
else:
|
||||
if latents.shape != latents_shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
||||
latents = latents.to(device)
|
||||
|
||||
# prepare mask and masked_image
|
||||
mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
|
||||
|
||||
# resize the mask to latents shape as we concatenate the mask to the latents
|
||||
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
||||
# and half precision
|
||||
mask = torch.nn.functional.interpolate(mask, size=(height // 8, width // 8))
|
||||
mask = mask.to(device=device, dtype=text_embeddings.dtype)
|
||||
|
||||
masked_image = masked_image.to(device=device, dtype=text_embeddings.dtype)
|
||||
|
||||
# encode the mask image into latents space so we can concatenate it to the latents
|
||||
masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
|
||||
masked_image_latents = 0.18215 * masked_image_latents
|
||||
|
||||
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
|
||||
mask = mask.repeat(batch_size * num_images_per_prompt, 1, 1, 1)
|
||||
masked_image_latents = masked_image_latents.repeat(batch_size * num_images_per_prompt, 1, 1, 1)
|
||||
|
||||
mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
|
||||
masked_image_latents = (
|
||||
torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
text_embeddings.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# aligning device to prevent device errors when concating it with the latent model input
|
||||
masked_image_latents = masked_image_latents.to(device=device, dtype=text_embeddings.dtype)
|
||||
# 7. Prepare mask latent variables
|
||||
mask, masked_image_latents = self.prepare_mask_latents(
|
||||
mask,
|
||||
masked_image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
text_embeddings.dtype,
|
||||
device,
|
||||
generator,
|
||||
do_classifier_free_guidance,
|
||||
)
|
||||
|
||||
# 8. Check that sizes of mask, masked image and latents match
|
||||
num_channels_mask = mask.shape[1]
|
||||
num_channels_masked_image = masked_image_latents.shape[1]
|
||||
|
||||
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
|
||||
raise ValueError(
|
||||
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
||||
@@ -476,27 +557,10 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
" `pipeline.unet` or your `mask_image` or `image` input."
|
||||
)
|
||||
|
||||
# set timesteps and move to the correct device
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps_tensor = self.scheduler.timesteps
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
# 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 10. Denoising loop
|
||||
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
@@ -521,22 +585,13 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
# 11. Post-processing
|
||||
image = self.decode_latents(latents)
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
if self.safety_checker is not None:
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
|
||||
)
|
||||
else:
|
||||
has_nsfw_concept = None
|
||||
# 12. Run safety checker
|
||||
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
|
||||
|
||||
# 13. Convert to PIL
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
|
||||
@@ -19,13 +19,20 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
import PIL
|
||||
from tqdm.auto import tqdm
|
||||
from diffusers.utils import is_accelerate_available
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
from ...schedulers import (
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
)
|
||||
from ...utils import deprecate, logging
|
||||
from . import StableDiffusionPipelineOutput
|
||||
from .safety_checker import StableDiffusionSafetyChecker
|
||||
@@ -85,17 +92,26 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
],
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
@@ -143,6 +159,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
|
||||
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
@@ -162,14 +179,53 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
slice_size = self.unet.config.attention_head_dim // 2
|
||||
self.unet.set_attention_slice(slice_size)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
|
||||
def disable_attention_slicing(self):
|
||||
r"""
|
||||
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
|
||||
back to computing attention in one step.
|
||||
"""
|
||||
# set slice_size = `None` to disable `set_attention_slice`
|
||||
# set slice_size = `None` to disable `attention slicing`
|
||||
self.enable_attention_slicing(None)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
|
||||
def enable_sequential_cpu_offload(self):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
||||
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
|
||||
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
|
||||
"""
|
||||
if is_accelerate_available():
|
||||
from accelerate import cpu_offload
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
device = torch.device("cuda")
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
|
||||
if cpu_offloaded_model is not None:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
|
||||
def enable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Enable memory efficient attention as implemented in xformers.
|
||||
|
||||
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
|
||||
time. Speed up at training time is not guaranteed.
|
||||
|
||||
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
|
||||
is used.
|
||||
"""
|
||||
self.unet.set_use_memory_efficient_attention_xformers(True)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention
|
||||
def disable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Disable memory efficient attention as implemented in xformers.
|
||||
"""
|
||||
self.unet.set_use_memory_efficient_attention_xformers(False)
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
|
||||
def _execution_device(self):
|
||||
@@ -275,6 +331,88 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
|
||||
return text_embeddings
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is not None:
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
||||
)
|
||||
else:
|
||||
has_nsfw_concept = None
|
||||
return image, has_nsfw_concept
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
|
||||
def decode_latents(self, latents):
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
return image
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
|
||||
def check_inputs(self, prompt, strength, callback_steps):
|
||||
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should in [1.0, 1.0] but is {strength}")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
|
||||
def get_timesteps(self, num_inference_steps, strength, device):
|
||||
# get the original timestep using init_timestep
|
||||
offset = self.scheduler.config.get("steps_offset", 0)
|
||||
init_timestep = int(num_inference_steps * strength) + offset
|
||||
init_timestep = min(init_timestep, num_inference_steps)
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep + offset, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start:]
|
||||
|
||||
return timesteps
|
||||
|
||||
def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator):
|
||||
init_image = init_image.to(device=self.device, dtype=dtype)
|
||||
init_latent_dist = self.vae.encode(init_image).latent_dist
|
||||
init_latents = init_latent_dist.sample(generator=generator)
|
||||
init_latents = 0.18215 * init_latents
|
||||
|
||||
# Expand init_latents for batch_size and num_images_per_prompt
|
||||
init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
|
||||
init_latents_orig = init_latents
|
||||
|
||||
# add noise to latents using the timesteps
|
||||
noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=dtype)
|
||||
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
||||
latents = init_latents
|
||||
return latents, init_latents_orig, noise
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
@@ -353,98 +491,49 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
if isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
# 1. Check inputs
|
||||
self.check_inputs(prompt, strength, callback_steps)
|
||||
|
||||
# 2. Define call parameters
|
||||
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
||||
device = self._execution_device
|
||||
|
||||
# set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
# preprocess image
|
||||
if not isinstance(init_image, torch.FloatTensor):
|
||||
init_image = preprocess_image(init_image)
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_embeddings = self._encode_prompt(
|
||||
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
|
||||
)
|
||||
|
||||
# encode the init image into latents and scale the latents
|
||||
latents_dtype = text_embeddings.dtype
|
||||
init_image = init_image.to(device=self.device, dtype=latents_dtype)
|
||||
init_latent_dist = self.vae.encode(init_image).latent_dist
|
||||
init_latents = init_latent_dist.sample(generator=generator)
|
||||
init_latents = 0.18215 * init_latents
|
||||
# 4. Preprocess image and mask
|
||||
if not isinstance(init_image, torch.FloatTensor):
|
||||
init_image = preprocess_image(init_image)
|
||||
|
||||
# Expand init_latents for batch_size and num_images_per_prompt
|
||||
init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
|
||||
init_latents_orig = init_latents
|
||||
|
||||
# preprocess mask
|
||||
if not isinstance(mask_image, torch.FloatTensor):
|
||||
mask_image = preprocess_mask(mask_image)
|
||||
mask_image = mask_image.to(device=self.device, dtype=latents_dtype)
|
||||
mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
|
||||
|
||||
# check sizes
|
||||
if not mask.shape == init_latents.shape:
|
||||
raise ValueError("The mask and init_image should be the same size!")
|
||||
# 5. set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.get_timesteps(num_inference_steps, strength, device)
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
||||
|
||||
# get the original timestep using init_timestep
|
||||
offset = self.scheduler.config.get("steps_offset", 0)
|
||||
init_timestep = int(num_inference_steps * strength) + offset
|
||||
init_timestep = min(init_timestep, num_inference_steps)
|
||||
# 6. Prepare latent variables
|
||||
# encode the init image into latents and scale the latents
|
||||
latents, init_latents_orig, noise = self.prepare_latents(
|
||||
init_image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator
|
||||
)
|
||||
|
||||
timesteps = self.scheduler.timesteps[-init_timestep]
|
||||
timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
|
||||
# 7. Prepare mask latent
|
||||
mask = mask_image.to(device=self.device, dtype=latents.dtype)
|
||||
mask = torch.cat([mask] * batch_size * num_images_per_prompt)
|
||||
|
||||
# add noise to latents using the timesteps
|
||||
noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
|
||||
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
|
||||
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
|
||||
latents = init_latents
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep + offset, 0)
|
||||
|
||||
# Some schedulers like PNDM have timesteps as arrays
|
||||
# It's more optimized to move all timesteps to correct device beforehand
|
||||
timesteps = self.scheduler.timesteps[t_start:].to(self.device)
|
||||
|
||||
for i, t in tqdm(enumerate(timesteps)):
|
||||
# 9. Denoising loop
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
@@ -468,22 +557,13 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
# 10. Post-processing
|
||||
image = self.decode_latents(latents)
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||
|
||||
if self.safety_checker is not None:
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
|
||||
self.device
|
||||
)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
|
||||
)
|
||||
else:
|
||||
has_nsfw_concept = None
|
||||
# 11. Run safety checker
|
||||
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
|
||||
|
||||
# 12. Convert to PIL
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ import torch
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDIMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
StableDiffusionImg2ImgPipeline,
|
||||
@@ -479,7 +480,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
)
|
||||
init_image = init_image.resize((768, 512))
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/lewington/expected-images/resolve/main/fantasy_landscape.npy"
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/fantasy_landscape.npy"
|
||||
)
|
||||
|
||||
model_id = "CompVis/stable-diffusion-v1-4"
|
||||
@@ -506,7 +507,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
|
||||
assert image.shape == (512, 768, 3)
|
||||
# img2img is flaky across GPUs even in fp32, so using MAE here
|
||||
assert np.abs(expected_image - image).mean() < 1e-3
|
||||
assert np.abs(expected_image - image).max() < 1e-3
|
||||
|
||||
def test_stable_diffusion_img2img_pipeline_k_lms(self):
|
||||
init_image = load_image(
|
||||
@@ -515,7 +516,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
)
|
||||
init_image = init_image.resize((768, 512))
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/lewington/expected-images/resolve/main/fantasy_landscape_k_lms.npy"
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/fantasy_landscape_k_lms.npy"
|
||||
)
|
||||
|
||||
model_id = "CompVis/stable-diffusion-v1-4"
|
||||
@@ -543,8 +544,44 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
image = output.images[0]
|
||||
|
||||
assert image.shape == (512, 768, 3)
|
||||
# img2img is flaky across GPUs even in fp32, so using MAE here
|
||||
assert np.abs(expected_image - image).mean() < 1e-3
|
||||
assert np.abs(expected_image - image).max() < 1e-3
|
||||
|
||||
def test_stable_diffusion_img2img_pipeline_ddim(self):
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/img2img/sketch-mountains-input.jpg"
|
||||
)
|
||||
init_image = init_image.resize((768, 512))
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/fantasy_landscape_ddim.npy"
|
||||
)
|
||||
|
||||
model_id = "CompVis/stable-diffusion-v1-4"
|
||||
ddim = DDIMScheduler.from_config(model_id, subfolder="scheduler")
|
||||
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
|
||||
model_id,
|
||||
scheduler=ddim,
|
||||
safety_checker=None,
|
||||
)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
prompt = "A fantasy landscape, trending on artstation"
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
init_image=init_image,
|
||||
strength=0.75,
|
||||
guidance_scale=7.5,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
)
|
||||
image = output.images[0]
|
||||
|
||||
assert image.shape == (512, 768, 3)
|
||||
assert np.abs(expected_image - image).max() < 1e-3
|
||||
|
||||
def test_stable_diffusion_img2img_intermediate_state(self):
|
||||
number_of_steps = 0
|
||||
|
||||
@@ -387,7 +387,6 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
|
||||
assert np.abs(expected_image - image).max() < 1e-3
|
||||
|
||||
def test_stable_diffusion_inpaint_legacy_pipeline_k_lms(self):
|
||||
# TODO(Anton, Patrick) - I think we can remove this test soon
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/in_paint/overture-creations-5sI6fQgYIuo.png"
|
||||
|
||||
Reference in New Issue
Block a user