diff --git a/examples/dreambooth/train_dreambooth_sd_xl_lora.py b/examples/dreambooth/train_dreambooth_sd_xl_lora.py index 89e7e5272f..4092a2e0e0 100644 --- a/examples/dreambooth/train_dreambooth_sd_xl_lora.py +++ b/examples/dreambooth/train_dreambooth_sd_xl_lora.py @@ -16,7 +16,6 @@ import argparse import gc import hashlib -import itertools import logging import math import os @@ -47,7 +46,6 @@ from diffusers import ( DDPMScheduler, DiffusionPipeline, DPMSolverMultistepScheduler, - StableDiffusionPipeline, UNet2DConditionModel, ) from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin @@ -60,7 +58,7 @@ from diffusers.models.attention_processor import ( SlicedAttnAddedKVProcessor, ) from diffusers.optimization import get_scheduler -from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, check_min_version, is_wandb_available +from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -90,8 +88,8 @@ license: creativeml-openrail-m base_model: {base_model} instance_prompt: {prompt} tags: -- {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'} -- {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'} +- 'stable-diffusion-xl' +- 'stable-diffusion-xl-diffusers' - text-to-image - diffusers - lora @@ -110,10 +108,12 @@ LoRA for the text encoder was enabled: {train_text_encoder}. f.write(yaml + model_card) -def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" +): text_encoder_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path, - subfolder="text_encoder", + subfolder=subfolder, revision=revision, ) model_class = text_encoder_config.architectures[0] @@ -122,14 +122,10 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st from transformers import CLIPTextModel return CLIPTextModel - elif model_class == "RobertaSeriesModelWithTransformation": - from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation + elif model_class == "CLIPTextModelWithProjection": + from transformers import CLIPTextModelWithProjection - return RobertaSeriesModelWithTransformation - elif model_class == "T5EncoderModel": - from transformers import T5EncoderModel - - return T5EncoderModel + return CLIPTextModelWithProjection else: raise ValueError(f"{model_class} is not supported.") @@ -150,12 +146,6 @@ def parse_args(input_args=None): required=False, help="Revision of pretrained model identifier from huggingface.co/models.", ) - parser.add_argument( - "--tokenizer_name", - type=str, - default=None, - help="Pretrained tokenizer name or path if not the same as model_name", - ) parser.add_argument( "--instance_data_dir", type=str, @@ -405,37 +395,6 @@ def parse_args(input_args=None): parser.add_argument( "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." ) - parser.add_argument( - "--pre_compute_text_embeddings", - action="store_true", - help="Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. 
This is not compatible with `--train_text_encoder`.", - ) - parser.add_argument( - "--tokenizer_max_length", - type=int, - default=None, - required=False, - help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.", - ) - parser.add_argument( - "--text_encoder_use_attention_mask", - action="store_true", - required=False, - help="Whether to use attention mask for the text encoder", - ) - parser.add_argument( - "--validation_images", - required=False, - default=None, - nargs="+", - help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.", - ) - parser.add_argument( - "--class_labels_conditioning", - required=False, - default=None, - help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.", - ) if input_args is not None: args = parser.parse_args(input_args) @@ -557,8 +516,8 @@ def collate_fn(examples, with_prior_preservation=False): input_ids = [example["instance_prompt_ids"] for example in examples] pixel_values = [example["instance_images"] for example in examples] - unet_added_conditions = [example["instance_added_cond_kwargs"] for example in examples] - + add_text_embeds = [example["instance_added_cond_kwargs"]["text_embeds"] for example in examples] + add_time_ids = [example["instance_added_cond_kwargs"]["add_time_ids"] for example in examples] if has_attention_mask: attention_mask = [example["instance_attention_mask"] for example in examples] @@ -567,7 +526,9 @@ def collate_fn(examples, with_prior_preservation=False): if with_prior_preservation: input_ids += [example["class_prompt_ids"] for example in examples] pixel_values += [example["class_images"] for example in examples] - unet_added_conditions += [example["class_added_cond_kwargs"] for example in examples] + add_text_embeds += [example["class_added_cond_kwargs"]["text_embeds"] for example in examples] + add_time_ids += [example["class_added_cond_kwargs"]["add_time_ids"] for example in examples] + if has_attention_mask: attention_mask += [example["class_attention_mask"] for example in examples] @@ -576,7 +537,11 @@ def collate_fn(examples, with_prior_preservation=False): input_ids = torch.cat(input_ids, dim=0) - batch = {"input_ids": input_ids, "pixel_values": pixel_values, "unet_added_conditions": unet_added_conditions} + batch = { + "input_ids": input_ids, + "pixel_values": pixel_values, + "unet_added_conditions": {"text_embeds": add_text_embeds, "time_ids": add_time_ids}, + } if has_attention_mask: batch["attention_mask"] = attention_mask @@ -658,14 +623,8 @@ def main(args): # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. - # TODO (sayakpaul): Remove this check when gradient accumulation with two models is enabled in accelerate. if args.train_text_encoder: raise NotImplementedError("Text encoder training not yet supported.") - if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: - raise ValueError( - "Gradient accumulation is not supported when training the text encoder in distributed training. " - "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." - ) # Make one log on every process with the configuration for debugging. 
logging.basicConfig( @@ -742,50 +701,45 @@ def main(args): ).repo_id # Load the tokenizer - if args.tokenizer_name: - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) - elif args.pretrained_model_name_or_path: - tokenizer_one = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="tokenizer", - revision=args.revision, - use_fast=False, - ) - tokenizer_two = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="tokenizer_2", - revision=args.revision, - use_fast=False, - ) + tokenizer_one = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + tokenizer_two = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer_2", + revision=args.revision, + use_fast=False, + ) - # import correct text encoder class - text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) + # import correct text encoder classes + text_encoder_cls_one = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision + ) + text_encoder_cls_two = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" + ) # Load scheduler and models noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") - text_encoder_one = text_encoder_cls.from_pretrained( + text_encoder_one = text_encoder_cls_one.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision ) - text_encoder_two = text_encoder_cls.from_pretrained( + text_encoder_two = text_encoder_cls_two.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision ) - try: - vae = AutoencoderKL.from_pretrained( - args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision - ) - except OSError: - # IF does not have a VAE so let's just set it to None - # We don't have to error out here - vae = None + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision ) # We only train the additional adapter LoRA layers - if vae is not None: - vae.requires_grad_(False) - text_encoder.requires_grad_(False) + vae.requires_grad_(False) + text_encoder_one.requires_grad_(False) + text_encoder_two.requires_grad_(False) unet.requires_grad_(False) # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision @@ -798,9 +752,9 @@ def main(args): # Move unet, vae and text_encoder to device and cast to weight_dtype unet.to(accelerator.device, dtype=weight_dtype) - if vae is not None: - vae.to(accelerator.device, dtype=weight_dtype) - text_encoder.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + text_encoder_one.to(accelerator.device, dtype=weight_dtype) + text_encoder_two.to(accelerator.device, dtype=weight_dtype) if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): @@ -854,49 +808,17 @@ def main(args): unet.set_attn_processor(unet_lora_attn_procs) unet_lora_layers = AttnProcsLayers(unet.attn_processors) - # The text encoder comes from 🤗 transformers, so we cannot 
directly modify it. - # So, instead, we monkey-patch the forward calls of its attention-blocks. For this, - # we first load a dummy pipeline with the text encoder and then do the monkey-patching. - text_encoder_lora_layers = None - if args.train_text_encoder: - text_lora_attn_procs = {} - for name, module in text_encoder.named_modules(): - if name.endswith(TEXT_ENCODER_ATTN_MODULE): - text_lora_attn_procs[name] = LoRAAttnProcessor( - hidden_size=module.out_proj.out_features, cross_attention_dim=None - ) - text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs) - temp_pipeline = DiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, text_encoder=text_encoder - ) - temp_pipeline._modify_text_encoder(text_lora_attn_procs) - text_encoder = temp_pipeline.text_encoder - del temp_pipeline - # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): # there are only two options here. Either are just the unet attn processor layers # or there are the unet and text encoder atten layers unet_lora_layers_to_save = None - text_encoder_lora_layers_to_save = None - - if args.train_text_encoder: - text_encoder_keys = accelerator.unwrap_model(text_encoder_lora_layers).state_dict().keys() - unet_keys = accelerator.unwrap_model(unet_lora_layers).state_dict().keys() + accelerator.unwrap_model(unet_lora_layers).state_dict().keys() for model in models: state_dict = model.state_dict() - - if ( - text_encoder_lora_layers is not None - and text_encoder_keys is not None - and state_dict.keys() == text_encoder_keys - ): - # text encoder - text_encoder_lora_layers_to_save = state_dict - elif state_dict.keys() == unet_keys: - # unet - unet_lora_layers_to_save = state_dict + # unet + unet_lora_layers_to_save = state_dict # make sure to pop weight so that corresponding model is not saved again weights.pop() @@ -904,7 +826,7 @@ def main(args): LoraLoaderMixin.save_lora_weights( output_dir, unet_lora_layers=unet_lora_layers_to_save, - text_encoder_lora_layers=text_encoder_lora_layers_to_save, + text_encoder_lora_layers=None, ) def load_model_hook(models, input_dir): @@ -957,11 +879,7 @@ def main(args): optimizer_class = torch.optim.AdamW # Optimizer creation - params_to_optimize = ( - itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters()) - if args.train_text_encoder - else unet_lora_layers.parameters() - ) + params_to_optimize = unet_lora_layers.parameters() optimizer = optimizer_class( params_to_optimize, lr=args.learning_rate, @@ -970,28 +888,33 @@ def main(args): eps=args.adam_epsilon, ) - # We always pre-compute the additional condition embeddings needed for SDXL + # We ALWAYS pre-compute the additional condition embeddings needed for SDXL # UNet as the model is already big and it uses two text encoders. # TODO: when we add support for text encoder training, will reivist. 
tokenizers = [tokenizer_one, tokenizer_two] text_encoders = [text_encoder_one, text_encoder_two] - def compute_embeddings(prompt): - prompt_embeds = pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt) - add_text_embeds = pooled_prompt_embeds - crops_coords_top_left = (0, 0) - add_time_ids = torch.tensor( - [list(args.resolution + crops_coords_top_left + args.resolution)], dtype=torch.long - ) - prompt_embeds = prompt_embeds.to(accelerator.device) - add_text_embeds = add_text_embeds.to(accelerator.device) - unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + def compute_embeddings(prompt, text_encoders, tokenizers): + with torch.no_grad(): + prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt) + add_text_embeds = pooled_prompt_embeds + crops_coords_top_left = (0, 0) + add_time_ids = torch.tensor( + [list(args.resolution + crops_coords_top_left + args.resolution)], dtype=torch.long + ) + prompt_embeds = prompt_embeds.to(accelerator.device) + add_text_embeds = add_text_embeds.to(accelerator.device) + unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} return prompt_embeds, unet_added_cond_kwargs - instance_prompt_hidden_states, instance_unet_added_conditions = (compute_embeddings(args.instance_prompt),) + instance_prompt_hidden_states, instance_unet_added_conditions = ( + compute_embeddings(args.instance_prompt, text_encoders, tokenizers), + ) class_prompt_hidden_states, class_unet_added_conditions = None, None if args.with_prior_preservation: - class_prompt_hidden_states, class_unet_added_conditions = compute_embeddings(args.class_prompt) + class_prompt_hidden_states, class_unet_added_conditions = compute_embeddings( + args.class_prompt, text_encoders, tokenizers + ) del tokenizers, text_encoders @@ -1005,7 +928,6 @@ def main(args): class_data_root=args.class_data_dir if args.with_prior_preservation else None, class_prompt=args.class_prompt, class_num=args.num_class_images, - tokenizer=tokenizer, size=args.resolution, center_crop=args.center_crop, instance_prompt_hidden_states=instance_prompt_hidden_states, @@ -1039,14 +961,9 @@ def main(args): ) # Prepare everything with our `accelerator`. - if args.train_text_encoder: - unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler - ) - else: - unet_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet_lora_layers, optimizer, train_dataloader, lr_scheduler - ) + unet_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet_lora_layers, optimizer, train_dataloader, lr_scheduler + ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) @@ -1105,8 +1022,6 @@ def main(args): for epoch in range(first_epoch, args.num_train_epochs): unet.train() - if args.train_text_encoder: - text_encoder.train() for step, batch in enumerate(train_dataloader): # Skip steps until we reach the resumed step if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: @@ -1137,36 +1052,11 @@ def main(args): # (this is the forward diffusion process) noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) - # Get the text embedding for conditioning - if args.pre_compute_text_embeddings: - encoder_hidden_states = batch["input_ids"] - else: - encoder_hidden_states = encode_prompt( - text_encoder, - batch["input_ids"], - batch["attention_mask"], - text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, - ) - - if accelerator.unwrap_model(unet).config.in_channels == channels * 2: - noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1) - - if args.class_labels_conditioning == "timesteps": - class_labels = timesteps - else: - class_labels = None - # Predict the noise residual model_pred = unet( - noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels + noisy_model_input, timesteps, batch["input_ids"], added_cond_kwargs=batch["unet_added_conditions"] ).sample - # if model predicts variance, throw away the prediction. we will only train on the - # simplified training objective. This means that all schedulers using the fine tuned - # model must be configured to use one of the fixed variance variance types. - if model_pred.shape[1] == 6: - model_pred, _ = torch.chunk(model_pred, 2, dim=1) - # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": target = noise @@ -1193,11 +1083,7 @@ def main(args): accelerator.backward(loss) if accelerator.sync_gradients: - params_to_clip = ( - itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters()) - if args.train_text_encoder - else unet_lora_layers.parameters() - ) + params_to_clip = unet_lora_layers.parameters() accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() lr_scheduler.step() @@ -1251,7 +1137,6 @@ def main(args): pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, unet=accelerator.unwrap_model(unet), - text_encoder=None if args.pre_compute_text_embeddings else accelerator.unwrap_model(text_encoder), revision=args.revision, torch_dtype=weight_dtype, ) @@ -1276,13 +1161,7 @@ def main(args): # run inference generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None - if args.pre_compute_text_embeddings: - pipeline_args = { - "prompt_embeds": validation_prompt_encoder_hidden_states, - "negative_prompt_embeds": validation_prompt_negative_prompt_embeds, - } - else: - pipeline_args = {"prompt": args.validation_prompt} + pipeline_args = {"prompt": args.validation_prompt} if args.validation_images is None: images = [ @@ -1319,14 +1198,10 @@ def main(args): unet = unet.to(torch.float32) unet_lora_layers = accelerator.unwrap_model(unet_lora_layers) - if text_encoder is not None: - text_encoder = text_encoder.to(torch.float32) - text_encoder_lora_layers = accelerator.unwrap_model(text_encoder_lora_layers) - LoraLoaderMixin.save_lora_weights( save_directory=args.output_dir, unet_lora_layers=unet_lora_layers, - 
text_encoder_lora_layers=text_encoder_lora_layers, + text_encoder_lora_layers=None, ) # Final inference diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 192f9663f2..b31536b213 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -97,7 +97,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): The dimension of the cross attention features. num_transformer_blocks (`int` or `Tuple[int]`, *optional*, defaults to 1): - The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. encoder_hid_dim (`int`, *optional*, defaults to None): If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` dimension to `cross_attention_dim`. diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index c6a1cb9fac..9b20592775 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -47,7 +47,9 @@ EXAMPLE_DOC_STRING = """ >>> import torch >>> from diffusers import StableDiffusionXLPipeline - >>> pipe = StableDiffusionXLPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) + >>> pipe = StableDiffusionXLPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16 + ... ) >>> pipe = pipe.to("cuda") >>> prompt = "a photo of an astronaut riding a horse on mars" @@ -625,10 +627,10 @@ class StableDiffusionXLPipeline(DiffusionPipeline): Returns: [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. When returning a tuple, the first element is a list with the generated images, and the second + element is a list of `bool`s denoting whether the corresponding generated image likely represents + "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 0. 
Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 523da227ac..eae0c47bff 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -49,7 +49,9 @@ EXAMPLE_DOC_STRING = """ >>> import torch >>> from diffusers import StableDiffusionXLPipeline - >>> pipe = StableDiffusionXLPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) + >>> pipe = StableDiffusionXLPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16 + ... ) >>> pipe = pipe.to("cuda") >>> prompt = "a photo of an astronaut riding a horse on mars" @@ -683,10 +685,10 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline): Returns: [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. When returning a tuple, the first element is a list with the generated images, and the second + element is a list of `bool`s denoting whether the corresponding generated image likely represents + "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 1. Check inputs. Raise error if not correct self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
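
For reference, a minimal inference sketch for the LoRA weights this patch makes the script save through `LoraLoaderMixin.save_lora_weights` (UNet attention-processor layers only, since `--train_text_encoder` now raises `NotImplementedError` and `text_encoder_lora_layers` is saved as `None`). This is not part of the diff: the base-model identifier, output directory, and prompt below are placeholders, and it assumes the SDXL pipeline exposes `LoraLoaderMixin.load_lora_weights` for reading the saved file.

```python
# Hedged sketch (not from the patch): load the UNet-only LoRA produced by
# train_dreambooth_sd_xl_lora.py and run inference with the SDXL pipeline.
import torch
from diffusers import DiffusionPipeline

base_model = "stabilityai/stable-diffusion-xl-base-0.9"  # placeholder: the checkpoint the LoRA was trained against
lora_dir = "path/to/output_dir"                          # placeholder: the script's --output_dir

pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.float16)
pipe.to("cuda")

# The script stores only UNet LoRA layers, so loading the weights into the
# pipeline is enough; no text-encoder LoRA is expected in the file.
# Assumes the pipeline inherits LoraLoaderMixin.load_lora_weights; the
# lower-level pipe.unet.load_attn_procs(lora_dir) is an alternative, though it
# may require stripping the "unet." prefix from the saved keys.
pipe.load_lora_weights(lora_dir)

image = pipe("a photo of sks dog in a bucket", num_inference_steps=30).images[0]
image.save("sks_dog.png")
```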