Mirror of https://github.com/huggingface/diffusers.git
if dreambooth lora (#3360)
* update IF stage I pipelines
  add fixed variance schedulers and lora loading
* added kv lora attn processor
* allow loading into alternative lora attn processor
* make vae optional
* throw away predicted variance
* allow loading into added kv lora layer
* allow load T5
* allow pre compute text embeddings
* set new variance type in schedulers
* fix copies
* refactor all prompt embedding code
  class prompts are now included in pre-encoding code
  max tokenizer length is now configurable
  embedding attention mask is now configurable
* fix for when variance type is not defined on scheduler
* do not pre compute validation prompt if not present
* add example test for if lora dreambooth
* add check for train text encoder and pre compute text embeddings
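As an illustration of what this enables (the checkpoint id, output path and prompt below are examples, not taken from this commit), a DreamBooth LoRA trained by examples/dreambooth/train_dreambooth_lora.py against an IF stage I model can be loaded back through the LoraLoaderMixin support added to the IF pipelines:

# Hedged usage sketch: load LoRA weights produced by the training script into the
# stage I pipeline, which now inherits from LoraLoaderMixin.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16
)
pipe.to("cuda")
pipe.load_lora_weights("path/to/lora/output_dir")  # reads the pytorch_lora_weights.bin saved by the script
image = pipe("a photo of sks dog", num_inference_steps=50).images[0]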
examples/dreambooth/train_dreambooth_lora.py:

@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and

 import argparse
+import gc
 import hashlib
 import itertools
 import logging
@@ -30,7 +31,7 @@ import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
-from huggingface_hub import create_repo, upload_folder
+from huggingface_hub import create_repo, model_info, upload_folder
 from packaging import version
 from PIL import Image
 from torch.utils.data import Dataset
@@ -48,7 +49,13 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
-from diffusers.models.attention_processor import LoRAAttnProcessor
+from diffusers.models.attention_processor import (
+    AttnAddedKVProcessor,
+    AttnAddedKVProcessor2_0,
+    LoRAAttnAddedKVProcessor,
+    LoRAAttnProcessor,
+    SlicedAttnAddedKVProcessor,
+)
 from diffusers.optimization import get_scheduler
 from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
@@ -108,6 +115,10 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
         from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

         return RobertaSeriesModelWithTransformation
+    elif model_class == "T5EncoderModel":
+        from transformers import T5EncoderModel
+
+        return T5EncoderModel
     else:
         raise ValueError(f"{model_class} is not supported.")

@@ -387,6 +398,24 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
+    parser.add_argument(
+        "--pre_compute_text_embeddings",
+        action="store_true",
+        help="Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.",
+    )
+    parser.add_argument(
+        "--tokenizer_max_length",
+        type=int,
+        default=None,
+        required=False,
+        help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.",
+    )
+    parser.add_argument(
+        "--text_encoder_use_attention_mask",
+        action="store_true",
+        required=False,
+        help="Whether to use attention mask for the text encoder",
+    )

     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -409,6 +438,9 @@ def parse_args(input_args=None):
         if args.class_prompt is not None:
             warnings.warn("You need not use --class_prompt without --with_prior_preservation.")

+    if args.train_text_encoder and args.pre_compute_text_embeddings:
+        raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`")
+
     return args

@@ -428,10 +460,16 @@ class DreamBoothDataset(Dataset):
         class_num=None,
         size=512,
         center_crop=False,
+        encoder_hidden_states=None,
+        instance_prompt_encoder_hidden_states=None,
+        tokenizer_max_length=None,
     ):
         self.size = size
         self.center_crop = center_crop
         self.tokenizer = tokenizer
+        self.encoder_hidden_states = encoder_hidden_states
+        self.instance_prompt_encoder_hidden_states = instance_prompt_encoder_hidden_states
+        self.tokenizer_max_length = tokenizer_max_length

         self.instance_data_root = Path(instance_data_root)
         if not self.instance_data_root.exists():
@@ -473,39 +511,50 @@ class DreamBoothDataset(Dataset):
         if not instance_image.mode == "RGB":
             instance_image = instance_image.convert("RGB")
         example["instance_images"] = self.image_transforms(instance_image)
-        example["instance_prompt_ids"] = self.tokenizer(
-            self.instance_prompt,
-            truncation=True,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            return_tensors="pt",
-        ).input_ids
+
+        if self.encoder_hidden_states is not None:
+            example["instance_prompt_ids"] = self.encoder_hidden_states
+        else:
+            text_inputs = tokenize_prompt(
+                self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length
+            )
+            example["instance_prompt_ids"] = text_inputs.input_ids
+            example["instance_attention_mask"] = text_inputs.attention_mask

         if self.class_data_root:
             class_image = Image.open(self.class_images_path[index % self.num_class_images])
             if not class_image.mode == "RGB":
                 class_image = class_image.convert("RGB")
             example["class_images"] = self.image_transforms(class_image)
-            example["class_prompt_ids"] = self.tokenizer(
-                self.class_prompt,
-                truncation=True,
-                padding="max_length",
-                max_length=self.tokenizer.model_max_length,
-                return_tensors="pt",
-            ).input_ids
+
+            if self.instance_prompt_encoder_hidden_states is not None:
+                example["class_prompt_ids"] = self.instance_prompt_encoder_hidden_states
+            else:
+                class_text_inputs = tokenize_prompt(
+                    self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length
+                )
+                example["class_prompt_ids"] = class_text_inputs.input_ids
+                example["class_attention_mask"] = class_text_inputs.attention_mask

         return example


 def collate_fn(examples, with_prior_preservation=False):
+    has_attention_mask = "instance_attention_mask" in examples[0]
+
     input_ids = [example["instance_prompt_ids"] for example in examples]
     pixel_values = [example["instance_images"] for example in examples]

+    if has_attention_mask:
+        attention_mask = [example["instance_attention_mask"] for example in examples]
+
     # Concat class and instance examples for prior preservation.
     # We do this to avoid doing two forward passes.
     if with_prior_preservation:
         input_ids += [example["class_prompt_ids"] for example in examples]
         pixel_values += [example["class_images"] for example in examples]
+        if has_attention_mask:
+            attention_mask += [example["class_attention_mask"] for example in examples]

     pixel_values = torch.stack(pixel_values)
     pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
@@ -516,6 +565,10 @@ def collate_fn(examples, with_prior_preservation=False):
         "input_ids": input_ids,
         "pixel_values": pixel_values,
     }

+    if has_attention_mask:
+        batch["attention_mask"] = attention_mask
+
     return batch

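As an aside, a rough sketch of what a collated batch looks like once --pre_compute_text_embeddings is in use: "instance_prompt_ids" already holds encoder hidden states, so the batch's "input_ids" entry is a float tensor of embeddings rather than integer token ids (shapes are invented for a tiny text encoder, and the concatenation over input_ids happens in an unchanged part of collate_fn):

import torch

encoder_hidden_states = torch.randn(1, 77, 32)        # pre-computed prompt embedding (illustrative shape)
examples = [
    {"instance_prompt_ids": encoder_hidden_states, "instance_images": torch.randn(3, 64, 64)}
    for _ in range(2)
]
batch = collate_fn(examples)                           # no "instance_attention_mask" -> has_attention_mask is False
print(batch["input_ids"].shape)                        # expected torch.Size([2, 77, 32]) -- embeddings, not token ids
print(batch["pixel_values"].shape)                     # expected torch.Size([2, 3, 64, 64])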
@@ -536,6 +589,50 @@ class PromptDataset(Dataset):
         return example


+def model_has_vae(args):
+    config_file_name = os.path.join("vae", AutoencoderKL.config_name)
+    if os.path.isdir(args.pretrained_model_name_or_path):
+        config_file_name = os.path.join(args.pretrained_model_name_or_path, config_file_name)
+        return os.path.isfile(config_file_name)
+    else:
+        files_in_repo = model_info(args.pretrained_model_name_or_path, revision=args.revision).siblings
+        return any(file.rfilename == config_file_name for file in files_in_repo)
+
+
+def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
+    if tokenizer_max_length is not None:
+        max_length = tokenizer_max_length
+    else:
+        max_length = tokenizer.model_max_length
+
+    text_inputs = tokenizer(
+        prompt,
+        truncation=True,
+        padding="max_length",
+        max_length=max_length,
+        return_tensors="pt",
+    )
+
+    return text_inputs
+
+
+def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None):
+    text_input_ids = input_ids.to(text_encoder.device)
+
+    if text_encoder_use_attention_mask:
+        attention_mask = attention_mask.to(text_encoder.device)
+    else:
+        attention_mask = None
+
+    prompt_embeds = text_encoder(
+        text_input_ids,
+        attention_mask=attention_mask,
+    )
+    prompt_embeds = prompt_embeds[0]
+
+    return prompt_embeds
+
+
 def main(args):
     logging_dir = Path(args.output_dir, args.logging_dir)

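To make the intent of model_has_vae concrete, a hedged example with illustrative checkpoint ids (not part of this commit): the DeepFloyd IF stage I repository ships no "vae" subfolder, so training later stays in pixel space, while a Stable Diffusion repository does and keeps the latent-space path:

from argparse import Namespace

# IF stage I: pixel-space UNet, no "vae" subfolder on the Hub -> expected False
print(model_has_vae(Namespace(pretrained_model_name_or_path="DeepFloyd/IF-I-XL-v1.0", revision=None)))
# Stable Diffusion: ships vae/config.json -> expected True, images get encoded to latents
print(model_has_vae(Namespace(pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5", revision=None)))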
@@ -656,13 +753,20 @@ def main(args):
     text_encoder = text_encoder_cls.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
     )
-    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+    if model_has_vae(args):
+        vae = AutoencoderKL.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
+        )
+    else:
+        vae = None
+
     unet = UNet2DConditionModel.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
     )

     # We only train the additional adapter LoRA layers
-    vae.requires_grad_(False)
+    if vae is not None:
+        vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
     unet.requires_grad_(False)

@@ -676,7 +780,8 @@ def main(args):

     # Move unet, vae and text_encoder to device and cast to weight_dtype
     unet.to(accelerator.device, dtype=weight_dtype)
-    vae.to(accelerator.device, dtype=weight_dtype)
+    if vae is not None:
+        vae.to(accelerator.device, dtype=weight_dtype)
     text_encoder.to(accelerator.device, dtype=weight_dtype)

     if args.enable_xformers_memory_efficient_attention:
@@ -707,7 +812,7 @@ def main(args):

     # Set correct lora layers
     unet_lora_attn_procs = {}
-    for name in unet.attn_processors.keys():
+    for name, attn_processor in unet.attn_processors.items():
         cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
         if name.startswith("mid_block"):
             hidden_size = unet.config.block_out_channels[-1]
@@ -718,7 +823,12 @@ def main(args):
             block_id = int(name[len("down_blocks.")])
             hidden_size = unet.config.block_out_channels[block_id]

-        unet_lora_attn_procs[name] = LoRAAttnProcessor(
+        if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
+            lora_attn_processor_class = LoRAAttnAddedKVProcessor
+        else:
+            lora_attn_processor_class = LoRAAttnProcessor
+
+        unet_lora_attn_procs[name] = lora_attn_processor_class(
             hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
         )

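For context, the per-layer dict built above is consumed a few lines further down in the unchanged part of the script, roughly like this (shown only to make the class selection concrete):

# Install the LoRA processors on the UNet and wrap them so the optimizer
# only sees the new low-rank parameters; everything else stays frozen.
unet.set_attn_processor(unet_lora_attn_procs)
unet_lora_layers = AttnProcsLayers(unet.attn_processors)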
@@ -783,6 +893,44 @@ def main(args):
         eps=args.adam_epsilon,
     )

+    if args.pre_compute_text_embeddings:
+
+        def compute_text_embeddings(prompt):
+            with torch.no_grad():
+                text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length)
+                prompt_embeds = encode_prompt(
+                    text_encoder,
+                    text_inputs.input_ids,
+                    text_inputs.attention_mask,
+                    text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
+                )
+
+            return prompt_embeds
+
+        pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)
+        validation_prompt_negative_prompt_embeds = compute_text_embeddings("")
+
+        if args.validation_prompt is not None:
+            validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt)
+        else:
+            validation_prompt_encoder_hidden_states = None
+
+        if args.instance_prompt is not None:
+            pre_computed_instance_prompt_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)
+        else:
+            pre_computed_instance_prompt_encoder_hidden_states = None
+
+        text_encoder = None
+        tokenizer = None
+
+        gc.collect()
+        torch.cuda.empty_cache()
+    else:
+        pre_computed_encoder_hidden_states = None
+        validation_prompt_encoder_hidden_states = None
+        validation_prompt_negative_prompt_embeds = None
+        pre_computed_instance_prompt_encoder_hidden_states = None
+
     # Dataset and DataLoaders creation:
     train_dataset = DreamBoothDataset(
         instance_data_root=args.instance_data_dir,
@@ -793,6 +941,9 @@ def main(args):
         tokenizer=tokenizer,
         size=args.resolution,
         center_crop=args.center_crop,
+        encoder_hidden_states=pre_computed_encoder_hidden_states,
+        instance_prompt_encoder_hidden_states=pre_computed_instance_prompt_encoder_hidden_states,
+        tokenizer_max_length=args.tokenizer_max_length,
     )

     train_dataloader = torch.utils.data.DataLoader(
@@ -896,32 +1047,53 @@ def main(args):
                 continue

             with accelerator.accumulate(unet):
-                # Convert images to latent space
-                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
-                latents = latents * vae.config.scaling_factor
+                pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
+
+                if vae is not None:
+                    # Convert images to latent space
+                    model_input = vae.encode(pixel_values).latent_dist.sample()
+                    model_input = model_input * vae.config.scaling_factor
+                else:
+                    model_input = pixel_values

                 # Sample noise that we'll add to the latents
-                noise = torch.randn_like(latents)
-                bsz = latents.shape[0]
+                noise = torch.randn_like(model_input)
+                bsz = model_input.shape[0]
                 # Sample a random timestep for each image
-                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+                timesteps = torch.randint(
+                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+                )
                 timesteps = timesteps.long()

-                # Add noise to the latents according to the noise magnitude at each timestep
+                # Add noise to the model input according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
-                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+                noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)

                 # Get the text embedding for conditioning
-                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                if args.pre_compute_text_embeddings:
+                    encoder_hidden_states = batch["input_ids"]
+                else:
+                    encoder_hidden_states = encode_prompt(
+                        text_encoder,
+                        batch["input_ids"],
+                        batch["attention_mask"],
+                        text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
+                    )

                 # Predict the noise residual
-                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                model_pred = unet(noisy_model_input, timesteps, encoder_hidden_states).sample
+
+                # if model predicts variance, throw away the prediction. we will only train on the
+                # simplified training objective. This means that all schedulers using the fine tuned
+                # model must be configured to use one of the fixed variance variance types.
+                if model_pred.shape[1] == 6:
+                    model_pred, _ = torch.chunk(model_pred, 2, dim=1)

                 # Get the target for loss depending on the prediction type
                 if noise_scheduler.config.prediction_type == "epsilon":
                     target = noise
                 elif noise_scheduler.config.prediction_type == "v_prediction":
-                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                    target = noise_scheduler.get_velocity(model_input, noise, timesteps)
                 else:
                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

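A minimal shape-level illustration of the variance-stripping step above (tensor sizes are examples): the IF UNet stacks predicted noise and predicted variance along the channel axis, and only the noise half is kept for the simplified loss:

import torch

model_pred = torch.randn(4, 6, 64, 64)            # (batch, 2 * image_channels, height, width)
if model_pred.shape[1] == 6:
    model_pred, _ = torch.chunk(model_pred, 2, dim=1)
print(model_pred.shape)                            # torch.Size([4, 3, 64, 64])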
@@ -988,19 +1160,40 @@ def main(args):
                    pipeline = DiffusionPipeline.from_pretrained(
                        args.pretrained_model_name_or_path,
                        unet=accelerator.unwrap_model(unet),
-                        text_encoder=accelerator.unwrap_model(text_encoder),
+                        text_encoder=None if args.pre_compute_text_embeddings else accelerator.unwrap_model(text_encoder),
                        revision=args.revision,
                        torch_dtype=weight_dtype,
                    )
-                    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+
+                    # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+                    scheduler_args = {}
+
+                    if "variance_type" in pipeline.scheduler.config:
+                        variance_type = pipeline.scheduler.config.variance_type
+
+                        if variance_type in ["learned", "learned_range"]:
+                            variance_type = "fixed_small"
+
+                        scheduler_args["variance_type"] = variance_type
+
+                    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
+                        pipeline.scheduler.config, **scheduler_args
+                    )
+
                    pipeline = pipeline.to(accelerator.device)
                    pipeline.set_progress_bar_config(disable=True)

                    # run inference
-                    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+                    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+                    if args.pre_compute_text_embeddings:
+                        pipeline_args = {
+                            "prompt_embeds": validation_prompt_encoder_hidden_states,
+                            "negative_prompt_embeds": validation_prompt_negative_prompt_embeds,
+                        }
+                    else:
+                        pipeline_args = {"prompt": args.validation_prompt}
                    images = [
-                        pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
-                        for _ in range(args.num_validation_images)
+                        pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)
                    ]

                    for tracker in accelerator.trackers:
@@ -1024,7 +1217,8 @@ def main(args):
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         unet = unet.to(torch.float32)
-        text_encoder = text_encoder.to(torch.float32)
+        if text_encoder is not None:
+            text_encoder = text_encoder.to(torch.float32)
         LoraLoaderMixin.save_lora_weights(
             save_directory=args.output_dir,
             unet_lora_layers=unet_lora_layers,
@@ -1036,7 +1230,20 @@ def main(args):
         pipeline = DiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
         )
-        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+
+        # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+        scheduler_args = {}
+
+        if "variance_type" in pipeline.scheduler.config:
+            variance_type = pipeline.scheduler.config.variance_type
+
+            if variance_type in ["learned", "learned_range"]:
+                variance_type = "fixed_small"
+
+            scheduler_args["variance_type"] = variance_type
+
+        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+
         pipeline = pipeline.to(accelerator.device)

         # load attention processors
examples/test_examples.py:

@@ -292,6 +292,41 @@ class ExamplesTestsAccelerate(unittest.TestCase):
         is_correct_naming = all(k.startswith("unet") or k.startswith("text_encoder") for k in keys)
         self.assertTrue(is_correct_naming)

+    def test_dreambooth_lora_if_model(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/dreambooth/train_dreambooth_lora.py
+                --pretrained_model_name_or_path hf-internal-testing/tiny-if-pipe
+                --instance_data_dir docs/source/en/imgs
+                --instance_prompt photo
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                --pre_compute_text_embeddings
+                --tokenizer_max_length=77
+                --text_encoder_use_attention_mask
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
+
+            # make sure the state_dict has the correct naming in the parameters.
+            lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
+            is_lora = all("lora" in k for k in lora_state_dict.keys())
+            self.assertTrue(is_lora)
+
+            # when not training the text encoder, all the parameters in the state dict should start
+            # with `"unet"` in their names.
+            starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
+            self.assertTrue(starts_with_unet)
+
     def test_custom_diffusion(self):
         with tempfile.TemporaryDirectory() as tmpdir:
             test_args = f"""
src/diffusers/loaders.py:

@@ -21,9 +21,13 @@ import torch
 from huggingface_hub import hf_hub_download

 from .models.attention_processor import (
+    AttnAddedKVProcessor,
+    AttnAddedKVProcessor2_0,
     CustomDiffusionAttnProcessor,
     CustomDiffusionXFormersAttnProcessor,
+    LoRAAttnAddedKVProcessor,
     LoRAAttnProcessor,
+    SlicedAttnAddedKVProcessor,
 )
 from .utils import (
     DIFFUSERS_CACHE,
@@ -250,10 +254,22 @@ class UNet2DConditionLoadersMixin:

             for key, value_dict in lora_grouped_dict.items():
                 rank = value_dict["to_k_lora.down.weight"].shape[0]
-                cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
                 hidden_size = value_dict["to_k_lora.up.weight"].shape[0]

-                attn_processors[key] = LoRAAttnProcessor(
+                attn_processor = self
+                for sub_key in key.split("."):
+                    attn_processor = getattr(attn_processor, sub_key)
+
+                if isinstance(
+                    attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)
+                ):
+                    cross_attention_dim = value_dict["add_k_proj_lora.down.weight"].shape[1]
+                    attn_processor_class = LoRAAttnAddedKVProcessor
+                else:
+                    cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
+                    attn_processor_class = LoRAAttnProcessor
+
+                attn_processors[key] = attn_processor_class(
                     hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
                 )
                 attn_processors[key].load_state_dict(value_dict)
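The module-path walk added above resolves a dotted attn_processors key to the live processor object so its class can be inspected; a small sketch on a hypothetical key (nn.ModuleList children are reachable through getattr with their string index, which is what makes the loop work):

key = "mid_block.attentions.0.transformer_blocks.0.attn1.processor"
attn_processor = unet  # `self` inside the mixin, i.e. the UNet being loaded into
for sub_key in key.split("."):
    attn_processor = getattr(attn_processor, sub_key)
# isinstance(attn_processor, AttnAddedKVProcessor / SlicedAttnAddedKVProcessor / AttnAddedKVProcessor2_0)
# then decides whether to rebuild a LoRAAttnAddedKVProcessor or a plain LoRAAttnProcessor.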
src/diffusers/models/attention_processor.py:

@@ -669,6 +669,73 @@ class AttnAddedKVProcessor2_0:
         return hidden_states


+class LoRAAttnAddedKVProcessor(nn.Module):
+    def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.rank = rank
+
+        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
+        self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
+        self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
+        self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
+        self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
+        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
+        residual = hidden_states
+        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
+        query = attn.head_to_batch_dim(query)
+
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + scale * self.add_k_proj_lora(
+            encoder_hidden_states
+        )
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + scale * self.add_v_proj_lora(
+            encoder_hidden_states
+        )
+        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
+        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
+
+        if not attn.only_cross_attention:
+            key = attn.to_k(hidden_states) + scale * self.to_k_lora(hidden_states)
+            value = attn.to_v(hidden_states) + scale * self.to_v_lora(hidden_states)
+            key = attn.head_to_batch_dim(key)
+            value = attn.head_to_batch_dim(value)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+        else:
+            key = encoder_hidden_states_key_proj
+            value = encoder_hidden_states_value_proj
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+        hidden_states = hidden_states + residual
+
+        return hidden_states
+
+
 class XFormersAttnProcessor:
     def __init__(self, attention_op: Optional[Callable] = None):
         self.attention_op = attention_op
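The pattern repeated throughout the processor above is a frozen base projection plus a scaled low-rank update; a self-contained sketch of just that idea, with LoRALinearLayer approximated by an explicit down/up pair:

import torch
import torch.nn as nn

hidden_size, rank, scale = 320, 4, 1.0
to_q = nn.Linear(hidden_size, hidden_size)                     # frozen base weight, e.g. attn.to_q
down = nn.Linear(hidden_size, rank, bias=False)                # trainable
up = nn.Linear(rank, hidden_size, bias=False)                  # trainable

hidden_states = torch.randn(2, 77, hidden_size)
query = to_q(hidden_states) + scale * up(down(hidden_states))  # same shape; only down/up receive gradients
print(query.shape)                                             # torch.Size([2, 77, 320])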
@@ -1022,6 +1089,7 @@ AttentionProcessor = Union[
     AttnAddedKVProcessor2_0,
     LoRAAttnProcessor,
     LoRAXFormersAttnProcessor,
+    LoRAAttnAddedKVProcessor,
     CustomDiffusionAttnProcessor,
     CustomDiffusionXFormersAttnProcessor,
 ]

src/diffusers/pipelines/deepfloyd_if/pipeline_if.py:

@@ -7,6 +7,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
 import torch
 from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer

+from ...loaders import LoraLoaderMixin
 from ...models import UNet2DConditionModel
 from ...schedulers import DDPMScheduler
 from ...utils import (
@@ -85,7 +86,7 @@ EXAMPLE_DOC_STRING = """
 """


-class IFPipeline(DiffusionPipeline):
+class IFPipeline(DiffusionPipeline, LoraLoaderMixin):
     tokenizer: T5Tokenizer
     text_encoder: T5EncoderModel

@@ -804,6 +805,9 @@ class IFPipeline(DiffusionPipeline):
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                     noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

+                if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
+                    noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1)
+
                 # compute the previous noisy sample x_t -> x_t-1
                 intermediate_images = self.scheduler.step(
                     noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
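Shape-wise, the new guard drops the predicted-variance channels whenever the scheduler has been switched to a fixed variance type (as the training script now configures it); a small illustration with example channel counts:

import torch

model_input = torch.randn(1, 3, 64, 64)
noise_pred = torch.randn(1, 6, 64, 64)                 # noise + predicted variance, concatenated above
variance_type = "fixed_small"                          # e.g. after the remap done at the end of LoRA training
if variance_type not in ["learned", "learned_range"]:
    noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1)
print(noise_pred.shape)                                # torch.Size([1, 3, 64, 64])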
src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py:

@@ -9,6 +9,7 @@ import PIL
 import torch
 from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer

+from ...loaders import LoraLoaderMixin
 from ...models import UNet2DConditionModel
 from ...schedulers import DDPMScheduler
 from ...utils import (
@@ -109,7 +110,7 @@ EXAMPLE_DOC_STRING = """
 """


-class IFImg2ImgPipeline(DiffusionPipeline):
+class IFImg2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
     tokenizer: T5Tokenizer
     text_encoder: T5EncoderModel

@@ -929,6 +930,9 @@ class IFImg2ImgPipeline(DiffusionPipeline):
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                     noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

+                if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
+                    noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1)
+
                 # compute the previous noisy sample x_t -> x_t-1
                 intermediate_images = self.scheduler.step(
                     noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py:

@@ -9,6 +9,7 @@ import PIL
 import torch
 from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer

+from ...loaders import LoraLoaderMixin
 from ...models import UNet2DConditionModel
 from ...schedulers import DDPMScheduler
 from ...utils import (
@@ -112,7 +113,7 @@ EXAMPLE_DOC_STRING = """
 """


-class IFInpaintingPipeline(DiffusionPipeline):
+class IFInpaintingPipeline(DiffusionPipeline, LoraLoaderMixin):
     tokenizer: T5Tokenizer
     text_encoder: T5EncoderModel

@@ -1044,6 +1045,9 @@ class IFInpaintingPipeline(DiffusionPipeline):
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                     noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

+                if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
+                    noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1)
+
                 # compute the previous noisy sample x_t -> x_t-1
                 prev_intermediate_images = intermediate_images