Refactor instructpix2pix lora to support peft (#10205)
* make base code changes referred from train_instructpix2pix script in examples
* change code to use PEFT as discussed in issue 10062
* update README training command
* update README training command
* refactor variable name and freezing unet
* Update examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py

  Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* update README installation instructions.
* cleanup code using make style and quality

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
@@ -2,6 +2,34 @@
This extended LoRA training script was authored by [Aiden-Frost](https://github.com/Aiden-Frost).
This is an experimental LoRA extension of [this example](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py). It additionally adds support for training LoRA layers on the UNet model.

## Running locally with PyTorch

### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

**Important**

To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:

```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
```

Then cd into the example folder and run:

```bash
pip install -r requirements.txt
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```

Note also that we use the PEFT library as the backend for LoRA training, so make sure to have `peft>=0.6.0` installed in your environment.
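
For example, you can install or upgrade it with:

```bash
pip install -U "peft>=0.6.0"
```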

## Training script example

```bash
@@ -9,7 +37,7 @@ export MODEL_ID="timbrooks/instruct-pix2pix"
export DATASET_ID="instruction-tuning-sd/cartoonization"
export OUTPUT_DIR="instructPix2Pix-cartoonization"

accelerate launch finetune_instruct_pix2pix.py \
accelerate launch train_instruct_pix2pix_lora.py \
--pretrained_model_name_or_path=$MODEL_ID \
--dataset_name=$DATASET_ID \
--enable_xformers_memory_efficient_attention \
@@ -24,7 +52,10 @@ accelerate launch finetune_instruct_pix2pix.py \
--rank=4 \
--output_dir=$OUTPUT_DIR \
--report_to=wandb \
--push_to_hub
--push_to_hub \
--original_image_column="original_image" \
--edited_image_column="cartoonized_image" \
--edit_prompt_column="edit_prompt"
```

## Inference
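
For reference, here is a minimal inference sketch. It assumes the LoRA weights were saved to (or pushed from) the `OUTPUT_DIR` used above; the image URL and edit prompt below are placeholders to replace with your own.

```python
import requests
import torch
from PIL import Image

from diffusers import StableDiffusionInstructPix2PixPipeline

pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
).to("cuda")  # assumes a CUDA GPU is available

# Load the LoRA layers trained by this script (local directory or Hub repo id).
pipeline.load_lora_weights("instructPix2Pix-cartoonization")

# Placeholder image; use any RGB image you want to edit.
url = "https://example.com/your-image.png"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

edited_image = pipeline(
    "Cartoonize the following image",  # example edit prompt for the cartoonization dataset
    image=image,
    num_inference_steps=20,
    image_guidance_scale=1.5,
    guidance_scale=7,
).images[0]
edited_image.save("edited_image.png")
```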
@@ -14,7 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Script to fine-tune Stable Diffusion for InstructPix2Pix."""
"""
Script to fine-tune Stable Diffusion for LORA InstructPix2Pix.
Base code referred from: https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py
"""

import argparse
import logging
@@ -30,6 +33,7 @@ import numpy as np
import PIL
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
@@ -39,21 +43,28 @@ from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2PixPipeline, UNet2DConditionModel
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, deprecate, is_wandb_available
from diffusers.training_utils import EMAModel, cast_training_params
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, deprecate, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module


if is_wandb_available():
import wandb


# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0")
check_min_version("0.32.0.dev0")

logger = get_logger(__name__, log_level="INFO")

@@ -63,6 +74,92 @@ DATASET_NAME_MAPPING = {
WANDB_TABLE_COL_NAMES = ["original_image", "edited_image", "edit_prompt"]


def save_model_card(
repo_id: str,
images: list = None,
base_model: str = None,
dataset_name: str = None,
repo_folder: str = None,
):
img_str = ""
if images is not None:
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"![img_{i}](./image_{i}.png)\n"

model_description = f"""
# LoRA text2image fine-tuning - {repo_id}
These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
{img_str}
"""

model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="creativeml-openrail-m",
base_model=base_model,
model_description=model_description,
inference=True,
)

tags = [
"stable-diffusion",
"stable-diffusion-diffusers",
"text-to-image",
"instruct-pix2pix",
"diffusers",
"diffusers-training",
"lora",
]
model_card = populate_model_card(model_card, tags=tags)

model_card.save(os.path.join(repo_folder, "README.md"))


def log_validation(
pipeline,
args,
accelerator,
generator,
):
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)

# run inference
original_image = download_image(args.val_image_url)
edited_images = []
if torch.backends.mps.is_available():
autocast_ctx = nullcontext()
else:
autocast_ctx = torch.autocast(accelerator.device.type)

with autocast_ctx:
for _ in range(args.num_validation_images):
edited_images.append(
pipeline(
args.validation_prompt,
image=original_image,
num_inference_steps=20,
image_guidance_scale=1.5,
guidance_scale=7,
generator=generator,
).images[0]
)

for tracker in accelerator.trackers:
if tracker.name == "wandb":
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
for edited_image in edited_images:
wandb_table.add_data(wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt)
tracker.log({"validation": wandb_table})

return edited_images


def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script for InstructPix2Pix.")
parser.add_argument(
@@ -417,11 +514,6 @@ def main():

generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)

if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
import wandb

# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -467,49 +559,58 @@ def main():
args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
)

# InstructPix2Pix uses an additional image for conditioning. To accommodate that,
# it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. This UNet is
# then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized
# from the pre-trained checkpoints. For the extra channels added to the first layer, they are
# initialized to zero.
logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.")
in_channels = 8
out_channels = unet.conv_in.out_channels
unet.register_to_config(in_channels=in_channels)

with torch.no_grad():
new_conv_in = nn.Conv2d(
in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
)
new_conv_in.weight.zero_()
new_conv_in.weight[:, :in_channels, :, :].copy_(unet.conv_in.weight)
unet.conv_in = new_conv_in

# Freeze vae, text_encoder and unet
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.requires_grad_(False)

# referred to https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py
unet_lora_parameters = []
for attn_processor_name, attn_processor in unet.attn_processors.items():
# Parse the attention module.
attn_module = unet
for n in attn_processor_name.split(".")[:-1]:
attn_module = getattr(attn_module, n)
# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16

# Set the `lora_layer` attribute of the attention-related matrices.
attn_module.to_q.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
)
)
attn_module.to_k.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
)
)
# Freeze the unet parameters before adding adapters
unet.requires_grad_(False)

attn_module.to_v.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
)
)
attn_module.to_out[0].set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_out[0].in_features,
out_features=attn_module.to_out[0].out_features,
rank=args.rank,
)
)
unet_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)

# Accumulate the LoRA params to optimize.
unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
# Move unet, vae and text_encoder to device and cast to weight_dtype
unet.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype)

# Add adapter and make sure the trainable params are in float32.
unet.add_adapter(unet_lora_config)
if args.mixed_precision == "fp16":
# only upcast trainable parameters (LoRA) into fp32
cast_training_params(unet, dtype=torch.float32)

# Create EMA for the unet.
if args.use_ema:
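
For reference, the PEFT-based setup above boils down to attaching a small LoRA adapter to the frozen UNet. A minimal standalone sketch follows; the model id and rank are example values, not ones fixed by the script.

```python
# Minimal sketch of the PEFT-based LoRA setup; model id and rank are examples only.
from diffusers import UNet2DConditionModel
from peft import LoraConfig

unet = UNet2DConditionModel.from_pretrained("timbrooks/instruct-pix2pix", subfolder="unet")
unet.requires_grad_(False)  # freeze all base weights

unet_lora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
unet.add_adapter(unet_lora_config)  # injects LoRA layers into the attention projections

# Only the injected LoRA matrices remain trainable.
trainable_params = [p for p in unet.parameters() if p.requires_grad]
print(f"trainable LoRA parameters: {sum(p.numel() for p in trainable_params):,}")
```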
@@ -528,6 +629,13 @@ def main():
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")

trainable_params = filter(lambda p: p.requires_grad, unet.parameters())

def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
return model

# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@@ -540,7 +648,8 @@ def main():
model.save_pretrained(os.path.join(output_dir, "unet"))

# make sure to pop weight so that corresponding model is not saved again
weights.pop()
if weights:
weights.pop()

def load_model_hook(models, input_dir):
if args.use_ema:
@@ -589,9 +698,9 @@ def main():
else:
optimizer_cls = torch.optim.AdamW

# train on only unet_lora_parameters
# train on only lora_layers
optimizer = optimizer_cls(
unet_lora_parameters,
trainable_params,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
@@ -730,22 +839,27 @@ def main():
)

# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
num_training_steps_for_scheduler = (
args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
)
else:
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes

lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_warmup_steps=num_warmup_steps_for_scheduler,
num_training_steps=num_training_steps_for_scheduler,
)

# Prepare everything with our `accelerator`.
unet, unet_lora_parameters, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, unet_lora_parameters, optimizer, train_dataloader, lr_scheduler
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, optimizer, train_dataloader, lr_scheduler
)

if args.use_ema:
@@ -765,8 +879,14 @@ def main():

# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
logger.warning(
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
f"This inconsistency may result in the learning rate scheduler not functioning properly."
)
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
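
For reference, a small worked example of the scheduler-step bookkeeping above; all numbers below are made up.

```python
# Worked example of the scheduler-step math; every value here is a placeholder.
import math

len_train_dataloader = 1000          # batches in the dataloader before accelerator.prepare (not yet sharded)
num_processes = 2                    # e.g. two GPUs
gradient_accumulation_steps = 4
num_train_epochs = 10
lr_warmup_steps = 500

len_train_dataloader_after_sharding = math.ceil(len_train_dataloader / num_processes)  # 500
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / gradient_accumulation_steps)  # 125

# accelerate steps the scheduler once per process, hence the multiplication by num_processes.
num_warmup_steps_for_scheduler = lr_warmup_steps * num_processes  # 1000
num_training_steps_for_scheduler = num_train_epochs * num_update_steps_per_epoch * num_processes  # 2500

print(num_warmup_steps_for_scheduler, num_training_steps_for_scheduler)  # 1000 2500
```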
@@ -885,7 +1005,7 @@ def main():
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

# Predict the noise residual and compute loss
model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states).sample
model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

# Gather the losses across all processes for logging (if we use distributed training).
@@ -895,7 +1015,7 @@ def main():
# Backpropagate
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(unet_lora_parameters, args.max_grad_norm)
accelerator.clip_grad_norm_(trainable_params, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
@@ -903,7 +1023,7 @@ def main():
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
if args.use_ema:
ema_unet.step(unet_lora_parameters)
ema_unet.step(trainable_params)
progress_bar.update(1)
global_step += 1
accelerator.log({"train_loss": train_loss}, step=global_step)
@@ -933,6 +1053,16 @@ def main():

save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
unwrapped_unet = unwrap_model(unet)
unet_lora_state_dict = convert_state_dict_to_diffusers(
get_peft_model_state_dict(unwrapped_unet)
)

StableDiffusionInstructPix2PixPipeline.save_lora_weights(
save_directory=save_path,
unet_lora_layers=unet_lora_state_dict,
safe_serialization=True,
)
logger.info(f"Saved state to {save_path}")

logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
@@ -959,45 +1089,22 @@ def main():
# The models need unwrapping because for compatibility in distributed training mode.
pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet),
text_encoder=accelerator.unwrap_model(text_encoder),
vae=accelerator.unwrap_model(vae),
unet=unwrap_model(unet),
text_encoder=unwrap_model(text_encoder),
vae=unwrap_model(vae),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)

# run inference
original_image = download_image(args.val_image_url)
edited_images = []
if torch.backends.mps.is_available():
autocast_ctx = nullcontext()
else:
autocast_ctx = torch.autocast(accelerator.device.type)
log_validation(
pipeline,
args,
accelerator,
generator,
)

with autocast_ctx:
for _ in range(args.num_validation_images):
edited_images.append(
pipeline(
args.validation_prompt,
image=original_image,
num_inference_steps=20,
image_guidance_scale=1.5,
guidance_scale=7,
generator=generator,
).images[0]
)

for tracker in accelerator.trackers:
if tracker.name == "wandb":
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
for edited_image in edited_images:
wandb_table.add_data(
wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
)
tracker.log({"validation": wandb_table})
if args.use_ema:
# Switch back to the original UNet parameters.
ema_unet.restore(unet.parameters())
@@ -1008,22 +1115,47 @@ def main():
# Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
if args.use_ema:
ema_unet.copy_to(unet.parameters())

# store only LORA layers
unet = unet.to(torch.float32)

unwrapped_unet = unwrap_model(unet)
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet))
StableDiffusionInstructPix2PixPipeline.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_state_dict,
safe_serialization=True,
)

pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
args.pretrained_model_name_or_path,
text_encoder=accelerator.unwrap_model(text_encoder),
vae=accelerator.unwrap_model(vae),
unet=unet,
text_encoder=unwrap_model(text_encoder),
vae=unwrap_model(vae),
unet=unwrap_model(unet),
revision=args.revision,
variant=args.variant,
)
# store only LORA layers
unet.save_attn_procs(args.output_dir)
pipeline.load_lora_weights(args.output_dir)

images = None
if (args.val_image_url is not None) and (args.validation_prompt is not None):
images = log_validation(
pipeline,
args,
accelerator,
generator,
)

if args.push_to_hub:
save_model_card(
repo_id,
images=images,
base_model=args.pretrained_model_name_or_path,
dataset_name=args.dataset_name,
repo_folder=args.output_dir,
)
upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
@@ -1031,31 +1163,6 @@ def main():
ignore_patterns=["step_*", "epoch_*"],
)

if args.validation_prompt is not None:
edited_images = []
pipeline = pipeline.to(accelerator.device)
with torch.autocast(str(accelerator.device).replace(":0", "")):
for _ in range(args.num_validation_images):
edited_images.append(
pipeline(
args.validation_prompt,
image=original_image,
num_inference_steps=20,
image_guidance_scale=1.5,
guidance_scale=7,
generator=generator,
).images[0]
)

for tracker in accelerator.trackers:
if tracker.name == "wandb":
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
for edited_image in edited_images:
wandb_table.add_data(
wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
)
tracker.log({"test": wandb_table})

accelerator.end_training()