Support EDM-style training in DreamBooth LoRA SDXL script (#7126)
* add: dreambooth lora script for Playground v2.5
* fix: kwarg
* address suraj's comments.
* Apply suggestions from code review
* apply suraj's suggestion
* incorporate changes in the canonical script.
* tracker naming
* fix: schedule determination
* add: two simple tests
* remove playground script
* note about edm-style training
* address pedro's comments.
* address part of Suraj's comments.
* remove guidance_scale.
* use mse_loss.
* add comments for preconditioning.
* quality
* Update examples/dreambooth/train_dreambooth_lora_sdxl.py
* tackle v-pred.
* support edm for sdxl too.
* address suraj's comments.

Co-authored-by: Suraj Patil <surajp815@gmail.com>
examples/dreambooth/README_sdxl.md
@@ -206,3 +206,40 @@ You can explore the results from a couple of our internal experiments by checking
## Running on a free-tier Colab Notebook

Check out [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb).

## Conducting EDM-style training

It's now possible to perform EDM-style training as proposed in [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364).

For the SDXL model, simply set:

```diff
+ --do_edm_style_training \
```
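
For intuition, EDM-style training wraps the raw network prediction in the preconditioning of Section 5 of the paper. Below is a minimal sketch of that preconditioning, not the script's code; `sigma_data` is the scheduler's `sigma_data` config (0.5 by default in `EDMEulerScheduler`), and the function names mirror the scheduler's `precondition_inputs` / `precondition_outputs`:

```python
import torch

sigma_data = 0.5  # EDMEulerScheduler's default `sigma_data`


def precondition_inputs(noisy_latents, sigmas):
    # c_in(sigma) = 1 / sqrt(sigma^2 + sigma_data^2): scale the noisy input to roughly unit variance
    return noisy_latents / (sigmas**2 + sigma_data**2) ** 0.5


def precondition_outputs(noisy_latents, model_pred, sigmas):
    # D(x; sigma) = c_skip(sigma) * x + c_out(sigma) * F(x) -- Eq. (7) of the EDM paper
    c_skip = sigma_data**2 / (sigmas**2 + sigma_data**2)
    c_out = sigmas * sigma_data / (sigmas**2 + sigma_data**2) ** 0.5
    return c_skip * noisy_latents + c_out * model_pred
```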

Other SDXL-like models that use the EDM formulation, such as [playgroundai/playground-v2.5-1024px-aesthetic](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic), can also be DreamBooth'd with the script. Below is an example command:

```bash
accelerate launch train_dreambooth_lora_sdxl.py \
  --pretrained_model_name_or_path="playgroundai/playground-v2.5-1024px-aesthetic" \
  --instance_data_dir="dog" \
  --output_dir="dog-playground-lora" \
  --mixed_precision="fp16" \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-4 \
  --use_8bit_adam \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=25 \
  --seed="0" \
  --push_to_hub
```

> [!CAUTION]
> Min-SNR gamma is not supported with EDM-style training yet. When training with the PlaygroundAI model, it's recommended not to pass any `variant`.
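
Once trained, the LoRA can be loaded on top of the base pipeline for inference. A minimal sketch, assuming the `dog-playground-lora` output directory (or Hub repo) from the command above:

```python
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "playgroundai/playground-v2.5-1024px-aesthetic", torch_dtype=torch.float16
).to("cuda")
# load the DreamBooth'd LoRA produced by the training command above
pipeline.load_lora_weights("dog-playground-lora")
image = pipeline("A photo of sks dog in a bucket", num_inference_steps=25).images[0]
image.save("sks_dog.png")
```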

examples/dreambooth/test_dreambooth_lora_edm.py (new file, 99 lines)
@@ -0,0 +1,99 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import sys
import tempfile

import safetensors


sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402


logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)


class DreamBoothLoRASDXLWithEDM(ExamplesTestsAccelerate):
    def test_dreambooth_lora_sdxl_with_edm(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                examples/dreambooth/train_dreambooth_lora_sdxl.py
                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
                --do_edm_style_training
                --instance_data_dir docs/source/en/imgs
                --instance_prompt photo
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"unet"` in their names.
            starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_unet)

    def test_dreambooth_lora_playground(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                examples/dreambooth/train_dreambooth_lora_sdxl.py
                --pretrained_model_name_or_path hf-internal-testing/tiny-playground-v2-5-pipe
                --instance_data_dir docs/source/en/imgs
                --instance_prompt photo
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"unet"` in their names.
            starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_unet)
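
The `run_command` helper imported above isn't part of this diff; a rough sketch of what it is assumed to do (the real implementation lives in `test_examples_utils.py`):

```python
import subprocess
from typing import List


def run_command(command: List[str], return_stdout: bool = False):
    # shell out to the training script and raise if it exits non-zero,
    # so a crashing example fails the test immediately
    output = subprocess.check_output(command, stderr=subprocess.STDOUT)
    if return_stdout:
        return output.decode("utf-8")
```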

examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -14,8 +14,10 @@
# See the License for the specific language governing permissions and

import argparse
+import contextlib
import gc
import itertools
+import json
import logging
import math
import os
@@ -32,7 +34,7 @@ import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
-from huggingface_hub import create_repo, upload_folder
+from huggingface_hub import create_repo, hf_hub_download, upload_folder
from huggingface_hub.utils import insecure_hashlib
from packaging import version
from peft import LoraConfig, set_peft_model_state_dict
@@ -50,6 +52,8 @@ from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DPMSolverMultistepScheduler,
+    EDMEulerScheduler,
+    EulerDiscreteScheduler,
    StableDiffusionXLPipeline,
    UNet2DConditionModel,
)
@@ -76,6 +80,20 @@ check_min_version("0.27.0.dev0")
logger = get_logger(__name__)


+def determine_scheduler_type(pretrained_model_name_or_path, revision):
+    model_index_filename = "model_index.json"
+    if os.path.isdir(pretrained_model_name_or_path):
+        model_index = os.path.join(pretrained_model_name_or_path, model_index_filename)
+    else:
+        model_index = hf_hub_download(
+            repo_id=pretrained_model_name_or_path, filename=model_index_filename, revision=revision
+        )
+
+    with open(model_index, "r") as f:
+        scheduler_type = json.load(f)["scheduler"][1]
+    return scheduler_type
+
+
def save_model_card(
    repo_id: str,
    images=None,
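
`determine_scheduler_type` only inspects `model_index.json`, where each pipeline component is stored as a `["library", "ClassName"]` pair. A toy illustration (the excerpt is hypothetical; Playground v2.5 ships an EDM scheduler class, so the `"EDM" in scheduler_type` check used later flips EDM mode on automatically):

```python
import json

# hypothetical excerpt of a pipeline's model_index.json
model_index = json.loads('{"scheduler": ["diffusers", "EDMDPMSolverMultistepScheduler"]}')
scheduler_type = model_index["scheduler"][1]
assert "EDM" in scheduler_type  # the training script keys off this substring
```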
@@ -95,7 +113,7 @@ def save_model_card(
    )

    model_description = f"""
-# SDXL LoRA DreamBooth - {repo_id}
+# {'SDXL' if 'playgroundai' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id}

<Gallery />

@@ -119,11 +137,17 @@ Weights for this model are available in Safetensors format.

[Download]({repo_id}/tree/main) them in the Files & versions tab.

"""
+    if "playgroundai" in args.pretrained_model_name_or_path:
+        model_description += """\n
+## License
+
+Please adhere to the licensing terms as described [here](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic/blob/main/LICENSE.md).
+"""
    model_card = load_or_create_model_card(
        repo_id_or_path=repo_id,
        from_training=True,
-        license="openrail++",
+        license="openrail++" if "playgroundai" not in base_model else "playground-v2dot5-community",
        base_model=base_model,
        prompt=instance_prompt,
        model_description=model_description,
@@ -131,15 +155,17 @@ Weights for this model are available in Safetensors format.
    )
    tags = [
        "text-to-image",
-        "stable-diffusion-xl",
-        "stable-diffusion-xl-diffusers",
        "text-to-image",
        "diffusers",
        "lora",
        "template:sd-lora",
    ]
-    model_card = populate_model_card(model_card, tags=tags)
+    if "playgroundai" in base_model:
+        tags.extend(["playground", "playground-diffusers"])
+    else:
+        tags.extend(["stable-diffusion-xl", "stable-diffusion-xl-diffusers"])
+
+    model_card = populate_model_card(model_card, tags=tags)
    model_card.save(os.path.join(repo_folder, "README.md"))

@@ -159,23 +185,29 @@ def log_validation(
    # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
    scheduler_args = {}

-    if "variance_type" in pipeline.scheduler.config:
-        variance_type = pipeline.scheduler.config.variance_type
+    if not args.do_edm_style_training:
+        if "variance_type" in pipeline.scheduler.config:
+            variance_type = pipeline.scheduler.config.variance_type

-        if variance_type in ["learned", "learned_range"]:
-            variance_type = "fixed_small"
+            if variance_type in ["learned", "learned_range"]:
+                variance_type = "fixed_small"

-        scheduler_args["variance_type"] = variance_type
+            scheduler_args["variance_type"] = variance_type

-    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)

    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    # run inference
    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+    # Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
+    # way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
+    inference_ctx = (
+        contextlib.nullcontext() if "playgroundai" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast()
+    )

-    with torch.cuda.amp.autocast():
+    with inference_ctx:
        images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]

    for tracker in accelerator.trackers:
@@ -334,6 +366,12 @@ def parse_args(input_args=None):
            " `args.validation_prompt` multiple times: `args.num_validation_images`."
        ),
    )
+    parser.add_argument(
+        "--do_edm_style_training",
+        default=False,
+        action="store_true",
+        help="Flag to conduct training using the EDM formulation as introduced in https://arxiv.org/abs/2206.00364.",
+    )
    parser.add_argument(
        "--with_prior_preservation",
        default=False,
@@ -905,6 +943,9 @@ def main(args):
            " Please use `huggingface-cli login` to authenticate with the Hub."
        )

+    if args.do_edm_style_training and args.snr_gamma is not None:
+        raise ValueError("Min-SNR formulation is not supported when conducting EDM-style training.")
+
    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
@@ -1018,7 +1059,19 @@ def main(args):
    )

    # Load scheduler and models
-    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+    scheduler_type = determine_scheduler_type(args.pretrained_model_name_or_path, args.revision)
+    if "EDM" in scheduler_type:
+        args.do_edm_style_training = True
+        noise_scheduler = EDMEulerScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+        logger.info("Performing EDM-style training!")
+    elif args.do_edm_style_training:
+        noise_scheduler = EulerDiscreteScheduler.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="scheduler"
+        )
+        logger.info("Performing EDM-style training!")
+    else:
+        noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

    text_encoder_one = text_encoder_cls_one.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
    )
@@ -1036,6 +1089,12 @@ def main(args):
        revision=args.revision,
        variant=args.variant,
    )
+    latents_mean = latents_std = None
+    if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None:
+        latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1)
+    if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None:
+        latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1)
+
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
    )
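
These statistics are used further down to normalize latents the Playground way. A toy sketch of that normalization and its inverse (all values assumed except `scaling_factor`, which is the SDXL VAE default):

```python
import torch

latents_mean = torch.tensor([0.5, -0.1, 0.2, 0.0]).view(1, 4, 1, 1)  # assumed values
latents_std = torch.tensor([1.2, 0.9, 1.1, 1.0]).view(1, 4, 1, 1)    # assumed values
scaling_factor = 0.13025  # SDXL VAE default
latents = torch.randn(1, 4, 8, 8)

normed = (latents - latents_mean) * scaling_factor / latents_std
# inference pipelines undo this before vae.decode:
recovered = normed * latents_std / scaling_factor + latents_mean
assert torch.allclose(recovered, latents, atol=1e-6)
```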
@@ -1433,7 +1492,12 @@ def main(args):
    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
-        accelerator.init_trackers("dreambooth-lora-sd-xl", config=vars(args))
+        tracker_name = (
+            "dreambooth-lora-sd-xl"
+            if "playgroundai" not in args.pretrained_model_name_or_path
+            else "dreambooth-lora-playground"
+        )
+        accelerator.init_trackers(tracker_name, config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -1485,6 +1549,18 @@ def main(args):
        disable=not accelerator.is_local_main_process,
    )

+    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
+        sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype)
+        schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device)
+        timesteps = timesteps.to(accelerator.device)
+
+        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+        sigma = sigmas[step_indices].flatten()
+        while len(sigma.shape) < n_dim:
+            sigma = sigma.unsqueeze(-1)
+        return sigma
+
    for epoch in range(first_epoch, args.num_train_epochs):
        unet.train()
        if args.train_text_encoder:
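
A toy sketch of the index-matching trick `get_sigmas` relies on: each sampled timestep is located in the schedule, the matching sigma is pulled out, and singleton dims are appended so it broadcasts against `(bsz, C, H, W)` latents (the schedule values below are made up):

```python
import torch

schedule_timesteps = torch.tensor([999, 749, 499, 249, 0])  # made-up schedule
sigmas = torch.tensor([80.0, 20.0, 5.0, 1.0, 0.1])          # made-up noise levels
timesteps = torch.tensor([499, 999])                         # a sampled batch

step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while sigma.ndim < 4:
    sigma = sigma.unsqueeze(-1)
print(sigma.shape)  # torch.Size([2, 1, 1, 1])
```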
@@ -1512,22 +1588,46 @@ def main(args):

                # Convert images to latent space
                model_input = vae.encode(pixel_values).latent_dist.sample()
-                model_input = model_input * vae.config.scaling_factor
-                if args.pretrained_vae_model_name_or_path is None:
-                    model_input = model_input.to(weight_dtype)
+
+                if latents_mean is None and latents_std is None:
+                    model_input = model_input * vae.config.scaling_factor
+                    if args.pretrained_vae_model_name_or_path is None:
+                        model_input = model_input.to(weight_dtype)
+                else:
+                    latents_mean = latents_mean.to(device=model_input.device, dtype=model_input.dtype)
+                    latents_std = latents_std.to(device=model_input.device, dtype=model_input.dtype)
+                    model_input = (model_input - latents_mean) * vae.config.scaling_factor / latents_std
+                    model_input = model_input.to(dtype=weight_dtype)

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(model_input)
                bsz = model_input.shape[0]

                # Sample a random timestep for each image
-                timesteps = torch.randint(
-                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
-                )
-                timesteps = timesteps.long()
+                if not args.do_edm_style_training:
+                    timesteps = torch.randint(
+                        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+                    )
+                    timesteps = timesteps.long()
+                else:
+                    # in EDM formulation, the model is conditioned on the pre-conditioned noise levels
+                    # instead of discrete timesteps, so here we sample indices to get the noise levels
+                    # from `scheduler.timesteps`
+                    indices = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,))
+                    timesteps = noise_scheduler.timesteps[indices].to(device=model_input.device)

                # Add noise to the model input according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
+                # For EDM-style training, we first obtain the sigmas based on the continuous timesteps.
+                # We then precondition the final model inputs based on these sigmas instead of the timesteps.
+                # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
+                if args.do_edm_style_training:
+                    sigmas = get_sigmas(timesteps, len(noisy_model_input.shape), noisy_model_input.dtype)
+                    if "EDM" in scheduler_type:
+                        inp_noisy_latents = noise_scheduler.precondition_inputs(noisy_model_input, sigmas)
+                    else:
+                        inp_noisy_latents = noisy_model_input / ((sigmas**2 + 1) ** 0.5)

                # time ids
                add_time_ids = torch.cat(
@@ -1551,7 +1651,7 @@ def main(args):
                    }
                    prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
                    model_pred = unet(
-                        noisy_model_input,
+                        inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
                        timesteps,
                        prompt_embeds_input,
                        added_cond_kwargs=unet_added_conditions,
@@ -1570,18 +1670,43 @@ def main(args):
                    )
                    prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
                    model_pred = unet(
-                        noisy_model_input,
+                        inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
                        timesteps,
                        prompt_embeds_input,
                        added_cond_kwargs=unet_added_conditions,
                        return_dict=False,
                    )[0]

+                weighting = None
+                if args.do_edm_style_training:
+                    # Similar to the input preconditioning, the model predictions are also preconditioned
+                    # on noised model inputs (before preconditioning) and the sigmas.
+                    # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
+                    if "EDM" in scheduler_type:
+                        model_pred = noise_scheduler.precondition_outputs(noisy_model_input, model_pred, sigmas)
+                    else:
+                        if noise_scheduler.config.prediction_type == "epsilon":
+                            model_pred = model_pred * (-sigmas) + noisy_model_input
+                        elif noise_scheduler.config.prediction_type == "v_prediction":
+                            model_pred = model_pred * (-sigmas / (sigmas**2 + 1) ** 0.5) + (
+                                noisy_model_input / (sigmas**2 + 1)
+                            )
+                    # We are not doing weighting here because it tends to result in numerical problems.
+                    # See: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
+                    # There might be other alternatives for weighting as well:
+                    # https://github.com/huggingface/diffusers/pull/7126#discussion_r1505404686
+                    if "EDM" not in scheduler_type:
+                        weighting = (sigmas**-2.0).float()

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
-                    target = noise
+                    target = model_input if args.do_edm_style_training else noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
-                    target = noise_scheduler.get_velocity(model_input, noise, timesteps)
+                    target = (
+                        model_input
+                        if args.do_edm_style_training
+                        else noise_scheduler.get_velocity(model_input, noise, timesteps)
+                    )
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
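
A quick numeric check (toy tensors) of why the non-EDM epsilon branch above recovers `x0`: Euler-style schedulers add noise as `noisy = x0 + sigma * eps`, so `eps * (-sigma) + noisy == x0`:

```python
import torch

x0 = torch.randn(2, 4, 8, 8)
eps = torch.randn_like(x0)
sigma = torch.full((2, 1, 1, 1), 5.0)

noisy = x0 + sigma * eps           # EulerDiscreteScheduler-style add_noise
x0_hat = eps * (-sigma) + noisy    # mirrors model_pred * (-sigmas) + noisy_model_input
assert torch.allclose(x0_hat, x0, atol=1e-5)
```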
@@ -1591,10 +1716,28 @@ def main(args):
                    target, target_prior = torch.chunk(target, 2, dim=0)

                    # Compute prior loss
-                    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+                    if weighting is not None:
+                        prior_loss = torch.mean(
+                            (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
+                                target_prior.shape[0], -1
+                            ),
+                            1,
+                        )
+                        prior_loss = prior_loss.mean()
+                    else:
+                        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

                if args.snr_gamma is None:
-                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+                    if weighting is not None:
+                        loss = torch.mean(
+                            (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(
+                                target.shape[0], -1
+                            ),
+                            1,
+                        )
+                        loss = loss.mean()
+                    else:
+                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                else:
                    # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.
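
Sanity check for the weighted-loss shape gymnastics above: the reshape-then-mean reduces to plain `F.mse_loss` when the weighting is all ones (toy tensors):

```python
import torch
import torch.nn.functional as F

pred, target = torch.randn(2, 4, 8, 8), torch.randn(2, 4, 8, 8)
weighting = torch.ones(2, 1, 1, 1)

loss = torch.mean((weighting * (pred - target) ** 2).reshape(target.shape[0], -1), 1).mean()
assert torch.allclose(loss, F.mse_loss(pred, target))
```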
@@ -1696,7 +1839,6 @@ def main(args):
        variant=args.variant,
        torch_dtype=weight_dtype,
    )
-
    pipeline_args = {"prompt": args.validation_prompt}

    images = log_validation(