Mirror of https://github.com/huggingface/diffusers.git
Commit: Improve pipelines more
@@ -131,7 +131,7 @@ if __name__ == "__main__":
type=str,
default=None,
required=False,
help="Set to a path, hub id to an already converted vae to not convert it again."
help="Set to a path, hub id to an already converted vae to not convert it again.",
)
args = parser.parse_args()

@@ -160,8 +160,8 @@ else:
StableDiffusionPix2PixZeroPipeline,
StableDiffusionSAGPipeline,
StableDiffusionUpscalePipeline,
StableDiffusionXLPipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLPipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
TextToVideoSDPipeline,

@@ -89,7 +89,7 @@ else:
StableUnCLIPPipeline,
)
from .stable_diffusion_safe import StableDiffusionPipelineSafe
from .stable_diffusion_xl import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
from .stable_diffusion_xl import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline
from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroPipeline
from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
from .unidiffuser import ImageTextPipelineOutput, UniDiffuserModel, UniDiffuserPipeline, UniDiffuserTextDecoder

@@ -257,7 +257,11 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
resolution //= 2

if unet_params.transformer_depth is not None:
num_transformer_blocks = unet_params.transformer_depth if isinstance(unet_params.transformer_depth, int) else list(unet_params.transformer_depth)
num_transformer_blocks = (
unet_params.transformer_depth
if isinstance(unet_params.transformer_depth, int)
else list(unet_params.transformer_depth)
)
else:
num_transformer_blocks = 1

@@ -270,7 +274,7 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
if use_linear_projection:
# stable diffusion 2-base-512 and 2-768
if head_dim is None:
head_dim_mult = unet_params.model_channels // unet_params.num_head_channels
head_dim_mult = unet_params.model_channels // unet_params.num_head_channels
head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)]

class_embed_type = None

@@ -280,7 +284,9 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
context_dim = None

if unet_params.context_dim is not None:
context_dim = unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0]
context_dim = (
unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0]
)

if "num_classes" in unet_params:
if unet_params.num_classes == "sequential":

@@ -775,7 +781,7 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False):
for key in keys:
for prefix in remove_prefixes:
if key.startswith(prefix):
text_model_dict[key[len(prefix + "."):]] = checkpoint[key]
text_model_dict[key[len(prefix + ".") :]] = checkpoint[key]

text_model.load_state_dict(text_model_dict)

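For context, the loop above remaps LDM text-encoder weights by stripping prefixes such as `cond_stage_model.transformer` from the checkpoint keys before loading them into the transformers CLIP model; only the spacing of the slice changed here. A minimal sketch of the same key-renaming pattern, with an illustrative prefix and a dummy state dict rather than anything from this commit:

    import torch

    def strip_prefixes(state_dict, prefixes):
        # return a copy of state_dict with any matching "<prefix>." removed from the keys
        out = {}
        for key, value in state_dict.items():
            for prefix in prefixes:
                if key.startswith(prefix):
                    out[key[len(prefix + ".") :]] = value
                    break
            else:
                out[key] = value  # keys without a known prefix are kept unchanged
        return out

    ckpt = {"cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": torch.zeros(2, 4)}
    print(list(strip_prefixes(ckpt, ["cond_stage_model.transformer"])))
    # ['text_model.embeddings.token_embedding.weight']
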
@@ -787,7 +793,7 @@ textenc_conversion_lst = [
("token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
("ln_final.weight", "text_model.final_layer_norm.weight"),
("ln_final.bias", "text_model.final_layer_norm.bias"),
("text_projection", "text_projection.weight")
("text_projection", "text_projection.weight"),
]
textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}

@@ -876,7 +882,9 @@ def convert_paint_by_example_checkpoint(checkpoint):

def convert_open_clip_checkpoint(checkpoint, prefix="cond_stage_model.model."):
# text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
text_model = CLIPTextModelWithProjection.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280)
text_model = CLIPTextModelWithProjection.from_pretrained(
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280
)

keys = list(checkpoint.keys())

@@ -892,13 +900,13 @@ def convert_open_clip_checkpoint(checkpoint, prefix="cond_stage_model.model."):
for key in keys:
# if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer
# continue
if key[len(prefix):] in textenc_conversion_map:
if key[len(prefix) :] in textenc_conversion_map:
if key.endswith("text_projection"):
value = checkpoint[key].T
else:
value = checkpoint[key]

text_model_dict[textenc_conversion_map[key[len(prefix):]]] = value
text_model_dict[textenc_conversion_map[key[len(prefix) :]]] = value

if key.startswith(prefix + "transformer."):
new_key = key[len(prefix + "transformer.") :]

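The `.T` on `text_projection` above reflects a layout difference between the two codebases: OpenCLIP stores the projection as a plain matrix that is right-multiplied (`pooled = x @ text_projection`), while the transformers `CLIPTextModelWithProjection` implements it as a bias-free `nn.Linear`, which applies `x @ weight.T`. Transposing the checkpoint tensor makes the two conventions agree. A small sanity check under those assumptions (sizes are illustrative):

    import torch

    hidden_dim, proj_dim = 1280, 1280
    x = torch.randn(2, hidden_dim)
    open_clip_proj = torch.randn(hidden_dim, proj_dim)  # OpenCLIP-style parameter

    linear = torch.nn.Linear(hidden_dim, proj_dim, bias=False)
    linear.weight.data = open_clip_proj.T  # what the converter stores as text_projection.weight

    assert torch.allclose(x @ open_clip_proj, linear(x), atol=1e-4)
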
@@ -1122,10 +1130,10 @@ def download_from_original_stable_diffusion_ckpt(
PaintByExamplePipeline,
StableDiffusionControlNetPipeline,
StableDiffusionPipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLPipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
StableDiffusionXLPipeline,
StableDiffusionXLImg2ImgPipeline,
)

if pipeline_class is None:

@@ -24,7 +24,12 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import AttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor, LoRAAttnProcessor2_0
from ...models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers
from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline

@@ -13,20 +13,22 @@
# limitations under the License.

import inspect
import warnings
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
from ...schedulers import KarrasDiffusionSchedulers
from ...models.attention_processor import AttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor, LoRAAttnProcessor2_0
from ...utils import (
deprecate,
is_accelerate_available,
is_accelerate_version,
logging,
@@ -35,7 +37,6 @@ from ...utils import (
)
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionXLPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker


logger = logging.get_logger(__name__) # pylint: disable=invalid-name

@@ -115,6 +116,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
tokenizer_2: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
force_zeros_for_empty_prompt: bool = True,
# safety_checker: StableDiffusionSafetyChecker,
# feature_extractor: CLIPImageProcessor,
):

@@ -148,7 +150,6 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
# feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

def enable_vae_slicing(self):

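For context on the two `vae_scale_factor` lines above: the derived form counts the VAE's downsampling stages, and for the usual four-level Stable Diffusion autoencoder it evaluates to the same value as the hardcoded constant. A quick check, assuming a typical `AutoencoderKL` config:

    block_out_channels = (128, 256, 512, 512)  # illustrative four-level VAE config
    vae_scale_factor = 2 ** (len(block_out_channels) - 1)
    assert vae_scale_factor == 8  # e.g. a 1024x1024 image maps to 128x128 latents
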
@@ -305,7 +306,9 @@ class StableDiffusionXLPipeline(DiffusionPipeline):

# Define tokenizers and text encoders
tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
text_encoders = [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
text_encoders = (
[self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
)

if prompt_embeds is None:
# textual inversion: procecss multi-vector tokens if necessary

@@ -327,14 +330,12 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = tokenizer.batch_decode(
untruncated_ids[:, tokenizer.model_max_length - 1 : -1]
)
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {tokenizer.model_max_length} tokens: {removed_text}"
)


prompt_embeds = text_encoder(
text_input_ids.to(device),
output_hidden_states=True,

@@ -354,10 +355,12 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None and negative_prompt is None:
zero_out_negative_prompt = negative_prompt is None and self.force_zeros_for_empty_prompt
if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
elif do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
uncond_tokens: List[str]
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(

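The `zero_out_negative_prompt` change ties the shortcut to the new `force_zeros_for_empty_prompt` flag: if no negative prompt is given and the flag is set, the unconditional embeddings become zero tensors rather than the encoding of an empty string. A hedged sketch of just that selection step, with dummy tensors and illustrative SDXL-like shapes:

    import torch

    def make_negative_embeds(prompt_embeds, pooled_prompt_embeds, negative_prompt, force_zeros_for_empty_prompt):
        # returns (negative_prompt_embeds, negative_pooled_prompt_embeds) for classifier-free guidance
        if negative_prompt is None and force_zeros_for_empty_prompt:
            return torch.zeros_like(prompt_embeds), torch.zeros_like(pooled_prompt_embeds)
        raise NotImplementedError("encoding an explicit negative prompt is omitted from this sketch")

    neg, neg_pooled = make_negative_embeds(torch.randn(1, 77, 2048), torch.randn(1, 1280), None, True)
    print(neg.shape, neg_pooled.shape)  # torch.Size([1, 77, 2048]) torch.Size([1, 1280])
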
@@ -405,7 +408,9 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)

negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
negative_prompt_embeds = negative_prompt_embeds.view(
batch_size * num_images_per_prompt, seq_len, -1
)

# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch

@@ -415,11 +420,12 @@ class StableDiffusionXLPipeline(DiffusionPipeline):

negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])


pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(bs_embed * num_images_per_prompt, -1)
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(bs_embed * num_images_per_prompt, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)

return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

@@ -518,6 +524,18 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
latents = latents * self.scheduler.init_noise_sigma
return latents

def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score):
add_time_ids = [list(original_size + crops_coords_top_left + target_size)]

passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.unet.config.cross_attention_dim
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features

if expected_add_embed_dim != passed_add_embed_dim:
raise ValueError(f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `unet.config.cross_attention_dim`.")

add_time_ids = torch.tensor(add_time_ids, dtype=torch.long)
return add_time_ids

@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(

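The new `_get_add_time_ids` helper packs SDXL's size and crop micro-conditioning, `original_size + crops_coords_top_left + target_size`, into one integer tensor and checks that its embedded width matches what `unet.add_embedding` expects; later in this commit the expected widths appear explicitly as 1280 + 6 * 256 for the base model and 1280 + 5 * 256 for the refiner. A minimal sketch of the tuple-to-tensor step with illustrative values:

    import torch

    # illustrative conditioning for a 1024x1024 generation with no cropping
    original_size = (1024, 1024)
    crops_coords_top_left = (0, 0)
    target_size = (1024, 1024)

    add_time_ids = [list(original_size + crops_coords_top_left + target_size)]  # one row of six integers
    add_time_ids = torch.tensor(add_time_ids, dtype=torch.long)
    print(add_time_ids.shape)  # torch.Size([1, 6])

    # rough width check, assuming 256-dim embeddings per value plus a 1280-dim pooled text embedding
    addition_time_embed_dim, pooled_dim = 256, 1280
    print(add_time_ids.shape[-1] * addition_time_embed_dim + pooled_dim)  # 2816 == 1280 + 6 * 256
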
@@ -652,7 +670,12 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self.encode_prompt(
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.encode_prompt(
prompt,
device,
num_images_per_prompt,

@@ -684,11 +707,9 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

# 7. Prepare added time ids & embeddings
add_text_embeds = pooled_prompt_embeds
add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + target_size)], dtype=torch.long)
add_time_ids = self._get_add_time_ids(original_size, crops_coords_top_left, target_size)

if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

@@ -699,6 +720,8 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device).repeat(num_images_per_prompt, 1)

# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance

@@ -753,11 +776,10 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
else:
latents = latents.float()


if not output_type == "latent":
# CHECK there is problem here (PVP)
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
#image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
# image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
has_nsfw_concept = None
else:
image = latents

@@ -13,22 +13,25 @@
# limitations under the License.

import inspect
import warnings
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import PIL.Image

import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
from pytorch_lightning import seed_everything
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
from ...schedulers import KarrasDiffusionSchedulers
from ...models.attention_processor import AttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor, LoRAAttnProcessor2_0
from ...utils import (
deprecate,
is_accelerate_available,
is_accelerate_version,
logging,
@@ -37,7 +40,6 @@ from ...utils import (
)
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionXLPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker


logger = logging.get_logger(__name__) # pylint: disable=invalid-name

@@ -117,6 +119,8 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline):
tokenizer_2: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
requires_aesthetics_score: bool = False,
force_zeros_for_empty_prompt: bool = True,
# safety_checker: StableDiffusionSafetyChecker,
# feature_extractor: CLIPImageProcessor,
):

@@ -307,7 +311,9 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline):

# Define tokenizers and text encoders
tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
text_encoders = [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
text_encoders = (
[self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
)

if prompt_embeds is None:
# textual inversion: procecss multi-vector tokens if necessary

@@ -329,14 +335,12 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline):
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = tokenizer.batch_decode(
untruncated_ids[:, tokenizer.model_max_length - 1 : -1]
)
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {tokenizer.model_max_length} tokens: {removed_text}"
)


prompt_embeds = text_encoder(
text_input_ids.to(device),
output_hidden_states=True,

@@ -358,10 +362,12 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline):
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None and negative_prompt is None:
zero_out_negative_prompt = negative_prompt is None and self.force_zeros_for_empty_prompt
if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
elif do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
uncond_tokens: List[str]
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(

@@ -410,7 +416,9 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline):
negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)

negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
negative_prompt_embeds = negative_prompt_embeds.view(
batch_size * num_images_per_prompt, seq_len, -1
)

# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch

@@ -420,10 +428,12 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline):

negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(bs_embed * num_images_per_prompt, -1)
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(bs_embed * num_images_per_prompt, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)

return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

@@ -502,6 +512,8 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline):
# get the original timestep using init_timestep
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

# TODO(Patrick): Make sure to remove +1 later here - that's just to compare with CompVis
# t_start = max(num_inference_steps - init_timestep, 0) + 1
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]

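In `get_timesteps` above, `strength` decides how far into the noise schedule the input image starts: only `init_timestep` of the scheduled steps are actually run, beginning at index `t_start`. Worked through with the new default `strength=0.3` and 50 inference steps (illustrative numbers, first-order scheduler):

    num_inference_steps, strength, scheduler_order = 50, 0.3, 1

    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 15
    t_start = max(num_inference_steps - init_timestep, 0)                          # 35
    print(init_timestep, t_start)
    # the pipeline then runs only the last 15 of the 50 scheduled timesteps:
    # timesteps = scheduler.timesteps[t_start * scheduler_order :]
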
@@ -521,7 +533,6 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline):
init_latents = image

else:

# make sure the VAE is in float32 mode, as it overflows in float16
image = image.float()
self.vae.to(dtype=torch.float32)

@@ -553,7 +564,10 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline):
init_latents = torch.cat([init_latents], dim=0)

shape = init_latents.shape
seed_everything(0)
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
print("noise", noise.abs().sum())
print("image", init_latents.abs().sum())

# get latents
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)

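`scheduler.add_noise` is what pushes the encoded image to the chosen starting timestep. The exact formula depends on the scheduler: DDPM-style schedulers mix signal and noise as sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, while Karras/Euler-style schedulers add sigma_t-scaled noise directly. A hedged sketch of the DDPM-style version with made-up values:

    import torch

    def add_noise(x0, noise, alpha_bar_t):
        # forward-process noising: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
        return alpha_bar_t**0.5 * x0 + (1 - alpha_bar_t) ** 0.5 * noise

    x0 = torch.randn(1, 4, 128, 128)  # illustrative SDXL-sized latents
    eps = torch.randn_like(x0)
    noisy = add_noise(x0, eps, alpha_bar_t=0.5)  # halfway-noised latents
    print(noisy.shape)
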
@@ -561,6 +575,28 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline):

return latents

def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score):
if self.requires_aesthetics_score:
add_time_ids = [list(original_size + crops_coords_top_left + (aesthetic_score,))]
add_neg_time_ids = [list(original_size + crops_coords_top_left + (negative_aesthetic_score,))]
else:
add_time_ids = [list(original_size + crops_coords_top_left + target_size)]
add_neg_time_ids = [list(original_size + crops_coords_top_left + target_size)]

passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.unet.config.cross_attention_dim
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features

if expected_add_embed_dim > passed_add_embed_dim and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim:
raise ValueError(f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model.")
elif expected_add_embed_dim < passed_add_embed_dim and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim:
raise ValueError(f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model.")
elif expected_add_embed_dim != passed_add_embed_dim:
raise ValueError(f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `unet.config.cross_attention_dim`.")

add_time_ids = torch.tensor(add_time_ids, dtype=torch.long)
add_neg_time_ids = torch.tensor(add_neg_time_ids, dtype=torch.long)

return add_time_ids, add_neg_time_ids

@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)

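In this img2img variant the last conditioning slot depends on `requires_aesthetics_score`: with it enabled (the refiner case) the two `target_size` integers are replaced by a single `aesthetic_score` (and `negative_aesthetic_score` for the unconditional row), so each row carries 5 values instead of the base model's 6. That lines up with the hardcoded widths checked further down in this commit, 1280 + 5 * 256 = 2560 versus 1280 + 6 * 256 = 2816. An illustrative sketch of the two tuple layouts (the score values here are placeholders, not taken from this commit):

    original_size, crops_coords_top_left, target_size = (1024, 1024), (0, 0), (1024, 1024)
    aesthetic_score, negative_aesthetic_score = 6.0, 2.5  # placeholder values

    refiner_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))  # 5 values
    refiner_neg_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,))
    base_ids = list(original_size + crops_coords_top_left + target_size)  # 6 values

    print(len(refiner_ids), len(base_ids))  # 5 6
    print(1280 + 5 * 256, 1280 + 6 * 256)   # 2560 2816
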
@@ -575,7 +611,7 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline):
List[PIL.Image.Image],
List[np.ndarray],
] = None,
strength: float = 0.5,
strength: float = 0.3,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,

@@ -705,7 +741,12 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline):
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self.encode_prompt(
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.encode_prompt(
prompt,
device,
num_images_per_prompt,

@@ -719,7 +760,7 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline):
# 4. Preprocess image
image = self.image_processor.preprocess(image)

# 4. Prepare timesteps
# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

@@ -729,30 +770,23 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline):
image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
)

# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
# 7. Prepare extra step kwargs.
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

# 7. Denoising loop
# 8. Prepare added time ids & embeddings
add_text_embeds = pooled_prompt_embeds

if self.unet.add_embedding.linear_1.in_features == (1280 + 5 * 256):
# refiner
add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + (aesthetic_score,))], dtype=torch.long)
neg_add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + (negative_aesthetic_score,))], dtype=torch.long)
elif self.unet.add_embedding.linear_1.in_features == (1280 + 6 * 256):
# SD-XL Base
add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + target_size)], dtype=torch.long)
neg_add_time_ids = add_time_ids.clone()
add_time_ids, add_neg_time_ids = self._get_add_time_ids(original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score)

if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_time_ids = torch.cat([add_time_ids, neg_add_time_ids], dim=0)
add_time_ids = torch.cat([add_time_ids, add_neg_time_ids], dim=0)

prompt_embeds = prompt_embeds.to(device)
add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device).repeat(num_images_per_prompt, 1)

# 9. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):

@@ -811,7 +845,7 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline):
if not output_type == "latent":
# CHECK there is problem here (PVP)
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
#image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
# image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
has_nsfw_concept = None
else:
image = latents

@@ -183,7 +183,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]

sample = sample / ((sigma **2 + 1) ** 0.5)
sample = sample / ((sigma**2 + 1) ** 0.5)

self.is_scale_input_called = True
return sample

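Only spacing changes in this hunk, but the line itself is the input scaling `EulerDiscreteScheduler` applies before each UNet call: dividing by sqrt(sigma^2 + 1) keeps the model input near unit variance when the sample carries sigma-scaled noise on top of roughly unit-variance data. A quick illustrative check:

    import torch

    sigma = 14.6  # illustrative large starting sigma
    x = torch.randn(100000) * (sigma**2 + 1) ** 0.5  # sample whose std is about sqrt(sigma^2 + 1)
    scaled = x / ((sigma**2 + 1) ** 0.5)
    print(scaled.std())  # close to 1.0
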
@@ -202,12 +202,14 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):

# "linspace" and "leading" corresponds to annotation of Table 1. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[
::-1
].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
timesteps += self.config.steps_offset

sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)

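The "leading" timesteps above now stay floating point instead of being cast to int64. For context on the two spacings named in the comment, here is how each subsamples the training schedule, with illustrative settings of 1000 training timesteps, 10 inference steps, and steps_offset=1:

    import numpy as np

    num_train_timesteps, num_inference_steps, steps_offset = 1000, 10, 1

    # "linspace": evenly spaced over the full training range, descending
    linspace_ts = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()

    # "leading": multiples of the integer step ratio, shifted by steps_offset, descending
    step_ratio = num_train_timesteps // num_inference_steps
    leading_ts = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
    leading_ts += steps_offset

    print(linspace_ts)  # [999. 888. 777. 666. 555. 444. 333. 222. 111.   0.]
    print(leading_ts)   # [901. 801. 701. 601. 501. 401. 301. 201. 101.   1.]
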