Mirror of https://github.com/huggingface/diffusers.git
Commit: finish one.
@@ -16,7 +16,6 @@
 import argparse
 import gc
 import hashlib
-import itertools
 import logging
 import math
 import os
@@ -47,7 +46,6 @@ from diffusers import (
     DDPMScheduler,
     DiffusionPipeline,
     DPMSolverMultistepScheduler,
-    StableDiffusionPipeline,
     UNet2DConditionModel,
 )
 from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
@@ -60,7 +58,7 @@ from diffusers.models.attention_processor import (
     SlicedAttnAddedKVProcessor,
 )
 from diffusers.optimization import get_scheduler
-from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, check_min_version, is_wandb_available
+from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available


@@ -90,8 +88,8 @@ license: creativeml-openrail-m
 base_model: {base_model}
 instance_prompt: {prompt}
 tags:
-- {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'}
-- {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'}
+- 'stable-diffusion-xl'
+- 'stable-diffusion-xl-diffusers'
 - text-to-image
 - diffusers
 - lora
@@ -110,10 +108,12 @@ LoRA for the text encoder was enabled: {train_text_encoder}.
         f.write(yaml + model_card)


-def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
+def import_model_class_from_model_name_or_path(
+    pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
     text_encoder_config = PretrainedConfig.from_pretrained(
         pretrained_model_name_or_path,
-        subfolder="text_encoder",
+        subfolder=subfolder,
         revision=revision,
     )
     model_class = text_encoder_config.architectures[0]
@@ -122,14 +122,10 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
         from transformers import CLIPTextModel

         return CLIPTextModel
-    elif model_class == "RobertaSeriesModelWithTransformation":
-        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
+    elif model_class == "CLIPTextModelWithProjection":
+        from transformers import CLIPTextModelWithProjection

-        return RobertaSeriesModelWithTransformation
-    elif model_class == "T5EncoderModel":
-        from transformers import T5EncoderModel
-
-        return T5EncoderModel
+        return CLIPTextModelWithProjection
     else:
         raise ValueError(f"{model_class} is not supported.")

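Note: the new `subfolder` argument is what lets this helper resolve both of SDXL's text encoders from a single checkpoint. A minimal usage sketch, with an illustrative checkpoint id that is not part of this commit:

model_id = "stabilityai/stable-diffusion-xl-base-1.0"  # illustrative checkpoint id

# The first encoder lives under "text_encoder" (the default), the second under "text_encoder_2".
cls_one = import_model_class_from_model_name_or_path(model_id, revision=None)
cls_two = import_model_class_from_model_name_or_path(model_id, revision=None, subfolder="text_encoder_2")
# For SDXL these resolve to CLIPTextModel and CLIPTextModelWithProjection, respectively.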
@@ -150,12 +146,6 @@ def parse_args(input_args=None):
         required=False,
         help="Revision of pretrained model identifier from huggingface.co/models.",
     )
-    parser.add_argument(
-        "--tokenizer_name",
-        type=str,
-        default=None,
-        help="Pretrained tokenizer name or path if not the same as model_name",
-    )
     parser.add_argument(
         "--instance_data_dir",
         type=str,
@@ -405,37 +395,6 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
-    parser.add_argument(
-        "--pre_compute_text_embeddings",
-        action="store_true",
-        help="Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.",
-    )
-    parser.add_argument(
-        "--tokenizer_max_length",
-        type=int,
-        default=None,
-        required=False,
-        help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.",
-    )
-    parser.add_argument(
-        "--text_encoder_use_attention_mask",
-        action="store_true",
-        required=False,
-        help="Whether to use attention mask for the text encoder",
-    )
-    parser.add_argument(
-        "--validation_images",
-        required=False,
-        default=None,
-        nargs="+",
-        help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.",
-    )
-    parser.add_argument(
-        "--class_labels_conditioning",
-        required=False,
-        default=None,
-        help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.",
-    )

     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -557,8 +516,8 @@ def collate_fn(examples, with_prior_preservation=False):

     input_ids = [example["instance_prompt_ids"] for example in examples]
     pixel_values = [example["instance_images"] for example in examples]
-    unet_added_conditions = [example["instance_added_cond_kwargs"] for example in examples]
-
+    add_text_embeds = [example["instance_added_cond_kwargs"]["text_embeds"] for example in examples]
+    add_time_ids = [example["instance_added_cond_kwargs"]["add_time_ids"] for example in examples]
     if has_attention_mask:
         attention_mask = [example["instance_attention_mask"] for example in examples]

@@ -567,7 +526,9 @@ def collate_fn(examples, with_prior_preservation=False):
     if with_prior_preservation:
         input_ids += [example["class_prompt_ids"] for example in examples]
         pixel_values += [example["class_images"] for example in examples]
-        unet_added_conditions += [example["class_added_cond_kwargs"] for example in examples]
+        add_text_embeds += [example["class_added_cond_kwargs"]["text_embeds"] for example in examples]
+        add_time_ids += [example["class_added_cond_kwargs"]["add_time_ids"] for example in examples]
+
         if has_attention_mask:
             attention_mask += [example["class_attention_mask"] for example in examples]

@@ -576,7 +537,11 @@ def collate_fn(examples, with_prior_preservation=False):

     input_ids = torch.cat(input_ids, dim=0)

-    batch = {"input_ids": input_ids, "pixel_values": pixel_values, "unet_added_conditions": unet_added_conditions}
+    batch = {
+        "input_ids": input_ids,
+        "pixel_values": pixel_values,
+        "unet_added_conditions": {"text_embeds": add_text_embeds, "time_ids": add_time_ids},
+    }

     if has_attention_mask:
         batch["attention_mask"] = attention_mask
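Note: with this change every batch carries the pre-computed SDXL conditioning next to the pixel data. A hedged sketch of the resulting layout, with illustrative shapes for two 1024x1024 examples:

import torch

batch = {
    "input_ids": torch.randn(2, 77, 2048),          # pre-computed prompt embeddings, not token ids
    "pixel_values": torch.randn(2, 3, 1024, 1024),  # training images
    "unet_added_conditions": {
        "text_embeds": torch.randn(2, 1280),              # pooled text embeddings
        "time_ids": torch.zeros(2, 6, dtype=torch.long),  # original_size + crop + target_size
    },
}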
@@ -658,14 +623,8 @@ def main(args):

-    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
-    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
-    # TODO (sayakpaul): Remove this check when gradient accumulation with two models is enabled in accelerate.
-    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
-        raise ValueError(
-            "Gradient accumulation is not supported when training the text encoder in distributed training. "
-            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
-        )
+    if args.train_text_encoder:
+        raise NotImplementedError("Text encoder training not yet supported.")

     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
@@ -742,50 +701,45 @@ def main(args):
         ).repo_id

-    # Load the tokenizer
-    if args.tokenizer_name:
-        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
-    elif args.pretrained_model_name_or_path:
-        tokenizer_one = AutoTokenizer.from_pretrained(
-            args.pretrained_model_name_or_path,
-            subfolder="tokenizer",
-            revision=args.revision,
-            use_fast=False,
-        )
-        tokenizer_two = AutoTokenizer.from_pretrained(
-            args.pretrained_model_name_or_path,
-            subfolder="tokenizer_2",
-            revision=args.revision,
-            use_fast=False,
-        )
+    tokenizer_one = AutoTokenizer.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="tokenizer",
+        revision=args.revision,
+        use_fast=False,
+    )
+    tokenizer_two = AutoTokenizer.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="tokenizer_2",
+        revision=args.revision,
+        use_fast=False,
+    )

-    # import correct text encoder class
-    text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
+    # import correct text encoder classes
+    text_encoder_cls_one = import_model_class_from_model_name_or_path(
+        args.pretrained_model_name_or_path, args.revision
+    )
+    text_encoder_cls_two = import_model_class_from_model_name_or_path(
+        args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+    )

     # Load scheduler and models
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
-    text_encoder_one = text_encoder_cls.from_pretrained(
+    text_encoder_one = text_encoder_cls_one.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
     )
-    text_encoder_two = text_encoder_cls.from_pretrained(
+    text_encoder_two = text_encoder_cls_two.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
     )
-    try:
-        vae = AutoencoderKL.from_pretrained(
-            args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
-        )
-    except OSError:
-        # IF does not have a VAE so let's just set it to None
-        # We don't have to error out here
-        vae = None
+    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)

     unet = UNet2DConditionModel.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
     )

     # We only train the additional adapter LoRA layers
-    if vae is not None:
-        vae.requires_grad_(False)
-    text_encoder.requires_grad_(False)
+    vae.requires_grad_(False)
+    text_encoder_one.requires_grad_(False)
+    text_encoder_two.requires_grad_(False)
     unet.requires_grad_(False)

     # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision
@@ -798,9 +752,9 @@ def main(args):

     # Move unet, vae and text_encoder to device and cast to weight_dtype
     unet.to(accelerator.device, dtype=weight_dtype)
-    if vae is not None:
-        vae.to(accelerator.device, dtype=weight_dtype)
-    text_encoder.to(accelerator.device, dtype=weight_dtype)
+    vae.to(accelerator.device, dtype=weight_dtype)
+    text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+    text_encoder_two.to(accelerator.device, dtype=weight_dtype)

     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
@@ -854,49 +808,17 @@ def main(args):
     unet.set_attn_processor(unet_lora_attn_procs)
     unet_lora_layers = AttnProcsLayers(unet.attn_processors)

-    # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
-    # So, instead, we monkey-patch the forward calls of its attention-blocks. For this,
-    # we first load a dummy pipeline with the text encoder and then do the monkey-patching.
-    text_encoder_lora_layers = None
-    if args.train_text_encoder:
-        text_lora_attn_procs = {}
-        for name, module in text_encoder.named_modules():
-            if name.endswith(TEXT_ENCODER_ATTN_MODULE):
-                text_lora_attn_procs[name] = LoRAAttnProcessor(
-                    hidden_size=module.out_proj.out_features, cross_attention_dim=None
-                )
-        text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
-        temp_pipeline = DiffusionPipeline.from_pretrained(
-            args.pretrained_model_name_or_path, text_encoder=text_encoder
-        )
-        temp_pipeline._modify_text_encoder(text_lora_attn_procs)
-        text_encoder = temp_pipeline.text_encoder
-        del temp_pipeline
-
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
-        # there are only two options here. Either are just the unet attn processor layers
-        # or there are the unet and text encoder atten layers
         unet_lora_layers_to_save = None
-        text_encoder_lora_layers_to_save = None

-        if args.train_text_encoder:
-            text_encoder_keys = accelerator.unwrap_model(text_encoder_lora_layers).state_dict().keys()
-        unet_keys = accelerator.unwrap_model(unet_lora_layers).state_dict().keys()
+        accelerator.unwrap_model(unet_lora_layers).state_dict().keys()

         for model in models:
             state_dict = model.state_dict()

-            if (
-                text_encoder_lora_layers is not None
-                and text_encoder_keys is not None
-                and state_dict.keys() == text_encoder_keys
-            ):
-                # text encoder
-                text_encoder_lora_layers_to_save = state_dict
-            elif state_dict.keys() == unet_keys:
-                # unet
-                unet_lora_layers_to_save = state_dict
+            # unet
+            unet_lora_layers_to_save = state_dict

             # make sure to pop weight so that corresponding model is not saved again
             weights.pop()
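Note: the `unet_lora_attn_procs` mapping consumed at the top of this hunk is built earlier in the script and not shown in this diff; a hedged sketch of the usual DreamBooth-LoRA pattern it follows:

from diffusers.models.attention_processor import LoRAAttnProcessor

unet_lora_attn_procs = {}
for name in unet.attn_processors.keys():
    # self-attention blocks ("attn1") take no cross-attention conditioning
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    unet_lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
unet.set_attn_processor(unet_lora_attn_procs)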
@@ -904,7 +826,7 @@ def main(args):
         LoraLoaderMixin.save_lora_weights(
             output_dir,
             unet_lora_layers=unet_lora_layers_to_save,
-            text_encoder_lora_layers=text_encoder_lora_layers_to_save,
+            text_encoder_lora_layers=None,
         )

     def load_model_hook(models, input_dir):
@@ -957,11 +879,7 @@ def main(args):
     optimizer_class = torch.optim.AdamW

     # Optimizer creation
-    params_to_optimize = (
-        itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
-        if args.train_text_encoder
-        else unet_lora_layers.parameters()
-    )
+    params_to_optimize = unet_lora_layers.parameters()
     optimizer = optimizer_class(
         params_to_optimize,
         lr=args.learning_rate,
@@ -970,28 +888,33 @@ def main(args):
         eps=args.adam_epsilon,
     )

-    # We always pre-compute the additional condition embeddings needed for SDXL
+    # We ALWAYS pre-compute the additional condition embeddings needed for SDXL
     # UNet as the model is already big and it uses two text encoders.
+    # TODO: when we add support for text encoder training, will revisit.
     tokenizers = [tokenizer_one, tokenizer_two]
     text_encoders = [text_encoder_one, text_encoder_two]

-    def compute_embeddings(prompt):
-        prompt_embeds = pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
-        add_text_embeds = pooled_prompt_embeds
-        crops_coords_top_left = (0, 0)
-        add_time_ids = torch.tensor(
-            [list(args.resolution + crops_coords_top_left + args.resolution)], dtype=torch.long
-        )
-        prompt_embeds = prompt_embeds.to(accelerator.device)
-        add_text_embeds = add_text_embeds.to(accelerator.device)
-        unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+    def compute_embeddings(prompt, text_encoders, tokenizers):
+        with torch.no_grad():
+            prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
+            add_text_embeds = pooled_prompt_embeds
+            crops_coords_top_left = (0, 0)
+            add_time_ids = torch.tensor(
+                [list(args.resolution + crops_coords_top_left + args.resolution)], dtype=torch.long
+            )
+            prompt_embeds = prompt_embeds.to(accelerator.device)
+            add_text_embeds = add_text_embeds.to(accelerator.device)
+            unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
         return prompt_embeds, unet_added_cond_kwargs

-    instance_prompt_hidden_states, instance_unet_added_conditions = (compute_embeddings(args.instance_prompt),)
+    instance_prompt_hidden_states, instance_unet_added_conditions = (
+        compute_embeddings(args.instance_prompt, text_encoders, tokenizers),
+    )
     class_prompt_hidden_states, class_unet_added_conditions = None, None
     if args.with_prior_preservation:
-        class_prompt_hidden_states, class_unet_added_conditions = compute_embeddings(args.class_prompt)
+        class_prompt_hidden_states, class_unet_added_conditions = compute_embeddings(
+            args.class_prompt, text_encoders, tokenizers
+        )

     del tokenizers, text_encoders

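Note: `args.resolution + crops_coords_top_left + args.resolution` concatenates tuples, so it assumes a tuple-valued resolution; with a plain integer `--resolution` argument this line would raise a TypeError. A hedged sketch of the six micro-conditioning values SDXL expects here:

import torch

resolution = (1024, 1024)      # stands in for a tuple-valued args.resolution
crops_coords_top_left = (0, 0)
# SDXL conditions on original_size + crop_top_left + target_size; this script
# uses the training resolution for both sizes.
add_time_ids = torch.tensor([list(resolution + crops_coords_top_left + resolution)], dtype=torch.long)
print(add_time_ids)  # tensor([[1024, 1024, 0, 0, 1024, 1024]])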
@@ -1005,7 +928,6 @@ def main(args):
         class_data_root=args.class_data_dir if args.with_prior_preservation else None,
         class_prompt=args.class_prompt,
         class_num=args.num_class_images,
-        tokenizer=tokenizer,
         size=args.resolution,
         center_crop=args.center_crop,
         instance_prompt_hidden_states=instance_prompt_hidden_states,
@@ -1039,14 +961,9 @@ def main(args):
     )

     # Prepare everything with our `accelerator`.
-    if args.train_text_encoder:
-        unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler
-        )
-    else:
-        unet_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            unet_lora_layers, optimizer, train_dataloader, lr_scheduler
-        )
+    unet_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        unet_lora_layers, optimizer, train_dataloader, lr_scheduler
+    )

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1105,8 +1022,6 @@ def main(args):

     for epoch in range(first_epoch, args.num_train_epochs):
         unet.train()
-        if args.train_text_encoder:
-            text_encoder.train()
         for step, batch in enumerate(train_dataloader):
             # Skip steps until we reach the resumed step
             if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
@@ -1137,36 +1052,11 @@ def main(args):
                 # (this is the forward diffusion process)
                 noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)

-                # Get the text embedding for conditioning
-                if args.pre_compute_text_embeddings:
-                    encoder_hidden_states = batch["input_ids"]
-                else:
-                    encoder_hidden_states = encode_prompt(
-                        text_encoder,
-                        batch["input_ids"],
-                        batch["attention_mask"],
-                        text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
-                    )
-
-                if accelerator.unwrap_model(unet).config.in_channels == channels * 2:
-                    noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
-
-                if args.class_labels_conditioning == "timesteps":
-                    class_labels = timesteps
-                else:
-                    class_labels = None
-
                 # Predict the noise residual
                 model_pred = unet(
-                    noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels
+                    noisy_model_input, timesteps, batch["input_ids"], added_cond_kwargs=batch["unet_added_conditions"]
                 ).sample

-                # if model predicts variance, throw away the prediction. we will only train on the
-                # simplified training objective. This means that all schedulers using the fine tuned
-                # model must be configured to use one of the fixed variance variance types.
-                if model_pred.shape[1] == 6:
-                    model_pred, _ = torch.chunk(model_pred, 2, dim=1)
-
                 # Get the target for loss depending on the prediction type
                 if noise_scheduler.config.prediction_type == "epsilon":
                     target = noise
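Note: after this change `batch["input_ids"]` holds the pre-computed encoder hidden states rather than token ids, and the SDXL micro-conditioning travels in `added_cond_kwargs`. A hedged shape check for the unet call above, assuming two 1024x1024 examples (exact dims depend on the checkpoint):

assert batch["input_ids"].shape == (2, 77, 2048)                         # concatenated embeddings of both encoders
assert batch["unet_added_conditions"]["text_embeds"].shape == (2, 1280)  # pooled embeddings
assert batch["unet_added_conditions"]["time_ids"].shape == (2, 6)        # size/crop conditioning
assert noisy_model_input.shape == (2, 4, 128, 128)                       # VAE latents at 1024px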
@@ -1193,11 +1083,7 @@ def main(args):

                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
-                    params_to_clip = (
-                        itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
-                        if args.train_text_encoder
-                        else unet_lora_layers.parameters()
-                    )
+                    params_to_clip = unet_lora_layers.parameters()
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                     optimizer.step()
                     lr_scheduler.step()
@@ -1251,7 +1137,6 @@ def main(args):
                 pipeline = DiffusionPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
                     unet=accelerator.unwrap_model(unet),
-                    text_encoder=None if args.pre_compute_text_embeddings else accelerator.unwrap_model(text_encoder),
                     revision=args.revision,
                     torch_dtype=weight_dtype,
                 )
@@ -1276,13 +1161,7 @@ def main(args):

                 # run inference
                 generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
-                if args.pre_compute_text_embeddings:
-                    pipeline_args = {
-                        "prompt_embeds": validation_prompt_encoder_hidden_states,
-                        "negative_prompt_embeds": validation_prompt_negative_prompt_embeds,
-                    }
-                else:
-                    pipeline_args = {"prompt": args.validation_prompt}
+                pipeline_args = {"prompt": args.validation_prompt}

                 if args.validation_images is None:
                     images = [
@@ -1319,14 +1198,10 @@ def main(args):
         unet = unet.to(torch.float32)
         unet_lora_layers = accelerator.unwrap_model(unet_lora_layers)

-        if text_encoder is not None:
-            text_encoder = text_encoder.to(torch.float32)
-            text_encoder_lora_layers = accelerator.unwrap_model(text_encoder_lora_layers)
-
         LoraLoaderMixin.save_lora_weights(
             save_directory=args.output_dir,
             unet_lora_layers=unet_lora_layers,
-            text_encoder_lora_layers=text_encoder_lora_layers,
+            text_encoder_lora_layers=None,
         )

     # Final inference

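Note: the weights written by `save_lora_weights` above can be pulled back into a pipeline through the same loader mixin; a hedged usage sketch, with an illustrative checkpoint id and prompt that are not part of this commit:

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights(args.output_dir)  # the directory save_lora_weights wrote to
image = pipe("a photo of sks dog", num_inference_steps=25).images[0]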
@@ -97,7 +97,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
             The dimension of the cross attention features.
         num_transformer_blocks (`int` or `Tuple[int]`, *optional*, defaults to 1):
-            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
         encoder_hid_dim (`int`, *optional*, defaults to None):
             If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
             dimension to `cross_attention_dim`.

@@ -47,7 +47,9 @@ EXAMPLE_DOC_STRING = """
         >>> import torch
         >>> from diffusers import StableDiffusionXLPipeline

-        >>> pipe = StableDiffusionXLPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+        >>> pipe = StableDiffusionXLPipeline.from_pretrained(
+        ...     "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+        ... )
         >>> pipe = pipe.to("cuda")

         >>> prompt = "a photo of an astronaut riding a horse on mars"
@@ -625,10 +627,10 @@ class StableDiffusionXLPipeline(DiffusionPipeline):

         Returns:
             [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
-            [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
-            When returning a tuple, the first element is a list with the generated images, and the second element is a
-            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
-            (nsfw) content, according to the `safety_checker`.
+            [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+            `tuple. When returning a tuple, the first element is a list with the generated images, and the second
+            element is a list of `bool`s denoting whether the corresponding generated image likely represents
+            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
         """
         # 0. Default height and width to unet
         height = height or self.unet.config.sample_size * self.vae_scale_factor

@@ -49,7 +49,9 @@ EXAMPLE_DOC_STRING = """
         >>> import torch
         >>> from diffusers import StableDiffusionXLPipeline

-        >>> pipe = StableDiffusionXLPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+        >>> pipe = StableDiffusionXLPipeline.from_pretrained(
+        ...     "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+        ... )
        >>> pipe = pipe.to("cuda")

        >>> prompt = "a photo of an astronaut riding a horse on mars"
@@ -683,10 +685,10 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline):

         Returns:
             [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
-            [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
-            When returning a tuple, the first element is a list with the generated images, and the second element is a
-            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
-            (nsfw) content, according to the `safety_checker`.
+            [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+            `tuple. When returning a tuple, the first element is a list with the generated images, and the second
+            element is a list of `bool`s denoting whether the corresponding generated image likely represents
+            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
         """
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)