Improve LCM(-LoRA) Distillation Scripts (#6420)
* Make WDS pipeline interpolation type configurable.
* Make the VAE encoding batch size configurable.
* Make lora_alpha and lora_dropout configurable for LCM LoRA scripts.
* Generalize scalings_for_boundary_conditions function and make the timestep scaling configurable.
* Make LoRA target modules configurable for LCM-LoRA scripts.
* Move resolve_interpolation_mode to src/diffusers/training_utils.py and make interpolation type configurable in non-WDS script.
* apply suggestions from review
@@ -61,6 +61,7 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.optimization import get_scheduler
+from diffusers.training_utils import resolve_interpolation_mode
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available

@@ -165,6 +166,7 @@ class SDText2ImageDataset:
         global_batch_size: int,
         num_workers: int,
         resolution: int = 512,
+        interpolation_type: str = "bilinear",
         shuffle_buffer_size: int = 1000,
         pin_memory: bool = False,
         persistent_workers: bool = False,
@@ -174,10 +176,12 @@ class SDText2ImageDataset:
         # flatten list using itertools
         train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url))

+        interpolation_mode = resolve_interpolation_mode(interpolation_type)
+
         def transform(example):
             # resize image
             image = example["image"]
-            image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+            image = TF.resize(image, resolution, interpolation=interpolation_mode)

             # get crop coordinates and crop image
             c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))

@@ -353,8 +357,9 @@ def append_dims(x, target_dims):

 # From LCMScheduler.get_scalings_for_boundary_condition_discrete
 def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
-    c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2)
-    c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5
+    scaled_timestep = timestep_scaling * timestep
+    c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
+    c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
     return c_skip, c_out
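The generalized function reduces to the old hard-coded behavior at the default `timestep_scaling=10.0`, since `timestep / 0.1` is exactly `10.0 * timestep`. A minimal sanity-check sketch (the helper restates the function from this diff; the test values are illustrative):

import torch

def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
    scaled_timestep = timestep_scaling * timestep
    c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
    c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
    return c_skip, c_out

timesteps = torch.tensor([0.0, 250.0, 999.0])
c_skip, c_out = scalings_for_boundary_conditions(timesteps)

# Old formulation: timestep / 0.1 == 10.0 * timestep, so the results must match.
old_c_skip = 0.5**2 / ((timesteps / 0.1) ** 2 + 0.5**2)
assert torch.allclose(c_skip, old_c_skip)

# Boundary condition: at timestep 0, c_skip == 1 and c_out == 0, so the
# consistency model reduces to the identity at the data boundary.
print(c_skip[0].item(), c_out[0].item())  # 1.0 0.0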
@@ -572,6 +577,15 @@ def parse_args():
             " resolution"
         ),
     )
+    parser.add_argument(
+        "--interpolation_type",
+        type=str,
+        default="bilinear",
+        help=(
+            "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`,"
+            " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
+        ),
+    )
     parser.add_argument(
         "--center_crop",
         default=False,
@@ -710,6 +724,50 @@ def parse_args():
         default=64,
         help="The rank of the LoRA projection matrix.",
     )
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=64,
+        help=(
+            "The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight"
+            " update delta_W. No scaling will be performed if this value is equal to `lora_rank`."
+        ),
+    )
+    parser.add_argument(
+        "--lora_dropout",
+        type=float,
+        default=0.0,
+        help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.",
+    )
+    parser.add_argument(
+        "--lora_target_modules",
+        type=str,
+        default=None,
+        help=(
+            "A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will"
+            " be used. By default, LoRA will be applied to all conv and linear layers."
+        ),
+    )
+    parser.add_argument(
+        "--vae_encode_batch_size",
+        type=int,
+        default=32,
+        required=False,
+        help=(
+            "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE."
+            " Encoding or decoding the whole batch at once may run into OOM issues."
+        ),
+    )
+    parser.add_argument(
+        "--timestep_scaling_factor",
+        type=float,
+        default=10.0,
+        help=(
+            "The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The"
+            " higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically"
+            " suffice."
+        ),
+    )
     # ----Mixed Precision----
     parser.add_argument(
         "--mixed_precision",
@@ -915,9 +973,10 @@ def main(args):
     )

     # 8. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer.
-    lora_config = LoraConfig(
-        r=args.lora_rank,
-        target_modules=[
+    if args.lora_target_modules is not None:
+        lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(",")]
+    else:
+        lora_target_modules = [
             "to_q",
             "to_k",
             "to_v",
@@ -932,7 +991,12 @@ def main(args):
             "downsamplers.0.conv",
             "upsamplers.0.conv",
             "time_emb_proj",
-        ],
-    )
+        ]
+    lora_config = LoraConfig(
+        r=args.lora_rank,
+        target_modules=lora_target_modules,
+        lora_alpha=args.lora_alpha,
+        lora_dropout=args.lora_dropout,
+    )
     unet = get_peft_model(unet, lora_config)
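As a usage sketch of the new knobs (the CLI value and the rank/alpha numbers here are illustrative, not from the commit; `LoraConfig` and `get_peft_model` are the standard `peft` APIs these scripts already use):

from peft import LoraConfig, get_peft_model

# Hypothetical value of --lora_target_modules.
lora_target_modules_arg = "to_q,to_k,to_v,to_out.0"

if lora_target_modules_arg is not None:
    lora_target_modules = [key.strip() for key in lora_target_modules_arg.split(",")]
else:
    lora_target_modules = ["to_q", "to_k", "to_v"]  # abbreviated default list

lora_config = LoraConfig(
    r=64,              # --lora_rank
    target_modules=lora_target_modules,
    lora_alpha=64,     # --lora_alpha; alpha equal to rank means no extra scaling
    lora_dropout=0.0,  # --lora_dropout
)
# unet = get_peft_model(unet, lora_config)  # as in the WDS scripts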
@@ -1051,6 +1115,7 @@ def main(args):
         global_batch_size=args.train_batch_size * accelerator.num_processes,
         num_workers=args.dataloader_num_workers,
         resolution=args.resolution,
+        interpolation_type=args.interpolation_type,
         shuffle_buffer_size=1000,
         pin_memory=True,
         persistent_workers=True,
@@ -1162,10 +1227,10 @@ def main(args):
                 if vae.dtype != weight_dtype:
                     vae.to(dtype=weight_dtype)

-                # encode pixel values with batch size of at most 32
+                # encode pixel values with batch size of at most args.vae_encode_batch_size
                 latents = []
-                for i in range(0, pixel_values.shape[0], 32):
-                    latents.append(vae.encode(pixel_values[i : i + 32]).latent_dist.sample())
+                for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size):
+                    latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample())
                 latents = torch.cat(latents, dim=0)

                 latents = latents * vae.config.scaling_factor
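The chunked-encoding pattern is self-contained and generalizes to any AutoencoderKL-style model; a minimal sketch (the function name is ours, not the script's):

import torch

def encode_in_chunks(vae, pixel_values: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Encode images to latents `batch_size` images at a time to avoid OOM."""
    latents = []
    for i in range(0, pixel_values.shape[0], batch_size):
        latents.append(vae.encode(pixel_values[i : i + batch_size]).latent_dist.sample())
    latents = torch.cat(latents, dim=0)
    return latents * vae.config.scaling_factor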
@@ -1181,9 +1246,13 @@ def main(args):
                 timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)

                 # 3. Get boundary scalings for start_timesteps and (end) timesteps.
-                c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
+                c_skip_start, c_out_start = scalings_for_boundary_conditions(
+                    start_timesteps, timestep_scaling=args.timestep_scaling_factor
+                )
                 c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
-                c_skip, c_out = scalings_for_boundary_conditions(timesteps)
+                c_skip, c_out = scalings_for_boundary_conditions(
+                    timesteps, timestep_scaling=args.timestep_scaling_factor
+                )
                 c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]

                 # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
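`append_dims` itself is unchanged by this commit; a hedged restatement of the helper these scripts define, showing why it is needed here (the per-sample scalings must broadcast over the latent dimensions):

import torch

def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
    """Appends singleton dimensions to the end of `x` until it has `target_dims` dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
    return x[(...,) + (None,) * dims_to_append]

c_skip = torch.rand(4)               # one scaling per sample
latents = torch.randn(4, 4, 64, 64)
print(append_dims(c_skip, latents.ndim).shape)  # torch.Size([4, 1, 1, 1])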
@@ -51,6 +51,7 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.optimization import get_scheduler
+from diffusers.training_utils import resolve_interpolation_mode
 from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available

@@ -193,8 +194,9 @@ def append_dims(x, target_dims):

 # From LCMScheduler.get_scalings_for_boundary_condition_discrete
 def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
-    c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2)
-    c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5
+    scaled_timestep = timestep_scaling * timestep
+    c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
+    c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
     return c_skip, c_out

@@ -396,6 +398,15 @@ def parse_args():
             " resolution"
         ),
     )
+    parser.add_argument(
+        "--interpolation_type",
+        type=str,
+        default="bilinear",
+        help=(
+            "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`,"
+            " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
+        ),
+    )
     parser.add_argument(
         "--center_crop",
         default=False,
@@ -534,6 +545,50 @@ def parse_args():
         default=64,
         help="The rank of the LoRA projection matrix.",
     )
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=64,
+        help=(
+            "The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight"
+            " update delta_W. No scaling will be performed if this value is equal to `lora_rank`."
+        ),
+    )
+    parser.add_argument(
+        "--lora_dropout",
+        type=float,
+        default=0.0,
+        help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.",
+    )
+    parser.add_argument(
+        "--lora_target_modules",
+        type=str,
+        default=None,
+        help=(
+            "A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will"
+            " be used. By default, LoRA will be applied to all conv and linear layers."
+        ),
+    )
+    parser.add_argument(
+        "--vae_encode_batch_size",
+        type=int,
+        default=8,
+        required=False,
+        help=(
+            "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE."
+            " Encoding or decoding the whole batch at once may run into OOM issues."
+        ),
+    )
+    parser.add_argument(
+        "--timestep_scaling_factor",
+        type=float,
+        default=10.0,
+        help=(
+            "The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The"
+            " higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically"
+            " suffice."
+        ),
+    )
     # ----Mixed Precision----
     parser.add_argument(
         "--mixed_precision",
@@ -776,10 +831,10 @@ def main(args):
         text_encoder_two.to(accelerator.device, dtype=weight_dtype)

     # 9. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer.
-    lora_config = LoraConfig(
-        r=args.lora_rank,
-        lora_alpha=args.lora_rank,
-        target_modules=[
+    if args.lora_target_modules is not None:
+        lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(",")]
+    else:
+        lora_target_modules = [
             "to_q",
             "to_k",
             "to_v",
@@ -794,7 +849,12 @@ def main(args):
             "downsamplers.0.conv",
             "upsamplers.0.conv",
             "time_emb_proj",
-        ],
-    )
+        ]
+    lora_config = LoraConfig(
+        r=args.lora_rank,
+        target_modules=lora_target_modules,
+        lora_alpha=args.lora_alpha,
+        lora_dropout=args.lora_dropout,
+    )
     unet.add_adapter(lora_config)

@@ -929,7 +989,8 @@ def main(args):
     )

     # Preprocessing the datasets.
-    train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+    interpolation_mode = resolve_interpolation_mode(args.interpolation_type)
+    train_resize = transforms.Resize(args.resolution, interpolation=interpolation_mode)
     train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
     train_flip = transforms.RandomHorizontalFlip(p=1.0)
     train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
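A short usage sketch of the new helper in a non-WDS preprocessing pipeline (the resolution and mode are illustrative):

from torchvision import transforms
from diffusers.training_utils import resolve_interpolation_mode

interpolation_mode = resolve_interpolation_mode("lanczos")
train_transforms = transforms.Compose(
    [
        transforms.Resize(512, interpolation=interpolation_mode),
        transforms.CenterCrop(512),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)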
@@ -1121,11 +1182,11 @@ def main(args):

                 encoded_text = compute_embeddings_fn(text, orig_size, crop_coords)

-                # encode pixel values with batch size of at most 8
+                # encode pixel values with batch size of at most args.vae_encode_batch_size
                 pixel_values = pixel_values.to(dtype=vae.dtype)
                 latents = []
-                for i in range(0, pixel_values.shape[0], args.encode_batch_size):
-                    latents.append(vae.encode(pixel_values[i : i + args.encode_batch_size]).latent_dist.sample())
+                for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size):
+                    latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample())
                 latents = torch.cat(latents, dim=0)

                 latents = latents * vae.config.scaling_factor
@@ -1142,9 +1203,13 @@ def main(args):
                 timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)

                 # 3. Get boundary scalings for start_timesteps and (end) timesteps.
-                c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
+                c_skip_start, c_out_start = scalings_for_boundary_conditions(
+                    start_timesteps, timestep_scaling=args.timestep_scaling_factor
+                )
                 c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
-                c_skip, c_out = scalings_for_boundary_conditions(timesteps)
+                c_skip, c_out = scalings_for_boundary_conditions(
+                    timesteps, timestep_scaling=args.timestep_scaling_factor
+                )
                 c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]

                 # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
@@ -62,6 +62,7 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.optimization import get_scheduler
+from diffusers.training_utils import resolve_interpolation_mode
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available

@@ -171,6 +172,7 @@ class SDXLText2ImageDataset:
         global_batch_size: int,
         num_workers: int,
         resolution: int = 1024,
+        interpolation_type: str = "bilinear",
         shuffle_buffer_size: int = 1000,
         pin_memory: bool = False,
         persistent_workers: bool = False,
@@ -187,10 +189,12 @@ class SDXLText2ImageDataset:
             else:
                 return (int(json.get(WDS_JSON_WIDTH, 0.0)), int(json.get(WDS_JSON_HEIGHT, 0.0)))

+        interpolation_mode = resolve_interpolation_mode(interpolation_type)
+
         def transform(example):
             # resize image
             image = example["image"]
-            image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+            image = TF.resize(image, resolution, interpolation=interpolation_mode)

             # get crop coordinates and crop image
             c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))

@@ -340,8 +344,9 @@ def append_dims(x, target_dims):

 # From LCMScheduler.get_scalings_for_boundary_condition_discrete
 def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
-    c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2)
-    c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5
+    scaled_timestep = timestep_scaling * timestep
+    c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
+    c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
     return c_skip, c_out

@@ -546,6 +551,15 @@ def parse_args():
             " resolution"
         ),
     )
+    parser.add_argument(
+        "--interpolation_type",
+        type=str,
+        default="bilinear",
+        help=(
+            "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`,"
+            " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
+        ),
+    )
     parser.add_argument(
         "--use_fix_crop_and_size",
         action="store_true",
@@ -690,6 +704,50 @@ def parse_args():
         default=64,
         help="The rank of the LoRA projection matrix.",
     )
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=64,
+        help=(
+            "The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight"
+            " update delta_W. No scaling will be performed if this value is equal to `lora_rank`."
+        ),
+    )
+    parser.add_argument(
+        "--lora_dropout",
+        type=float,
+        default=0.0,
+        help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.",
+    )
+    parser.add_argument(
+        "--lora_target_modules",
+        type=str,
+        default=None,
+        help=(
+            "A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will"
+            " be used. By default, LoRA will be applied to all conv and linear layers."
+        ),
+    )
+    parser.add_argument(
+        "--vae_encode_batch_size",
+        type=int,
+        default=8,
+        required=False,
+        help=(
+            "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE."
+            " Encoding or decoding the whole batch at once may run into OOM issues."
+        ),
+    )
+    parser.add_argument(
+        "--timestep_scaling_factor",
+        type=float,
+        default=10.0,
+        help=(
+            "The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The"
+            " higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically"
+            " suffice."
+        ),
+    )
     # ----Mixed Precision----
     parser.add_argument(
         "--mixed_precision",
@@ -929,9 +987,10 @@ def main(args):
     )

     # 8. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer.
-    lora_config = LoraConfig(
-        r=args.lora_rank,
-        target_modules=[
+    if args.lora_target_modules is not None:
+        lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(",")]
+    else:
+        lora_target_modules = [
             "to_q",
             "to_k",
             "to_v",
@@ -946,7 +1005,12 @@ def main(args):
             "downsamplers.0.conv",
             "upsamplers.0.conv",
             "time_emb_proj",
-        ],
-    )
+        ]
+    lora_config = LoraConfig(
+        r=args.lora_rank,
+        target_modules=lora_target_modules,
+        lora_alpha=args.lora_alpha,
+        lora_dropout=args.lora_dropout,
+    )
     unet = get_peft_model(unet, lora_config)

@@ -1090,6 +1154,7 @@ def main(args):
         global_batch_size=args.train_batch_size * accelerator.num_processes,
         num_workers=args.dataloader_num_workers,
         resolution=args.resolution,
+        interpolation_type=args.interpolation_type,
         shuffle_buffer_size=1000,
         pin_memory=True,
         persistent_workers=True,
@@ -1214,10 +1279,10 @@ def main(args):
                 else:
                     pixel_values = image

-                # encode pixel values with batch size of at most 8
+                # encode pixel values with batch size of at most args.vae_encode_batch_size
                 latents = []
-                for i in range(0, pixel_values.shape[0], 8):
-                    latents.append(vae.encode(pixel_values[i : i + 8]).latent_dist.sample())
+                for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size):
+                    latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample())
                 latents = torch.cat(latents, dim=0)

                 latents = latents * vae.config.scaling_factor
@@ -1234,9 +1299,13 @@ def main(args):
                 timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)

                 # 3. Get boundary scalings for start_timesteps and (end) timesteps.
-                c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
+                c_skip_start, c_out_start = scalings_for_boundary_conditions(
+                    start_timesteps, timestep_scaling=args.timestep_scaling_factor
+                )
                 c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
-                c_skip, c_out = scalings_for_boundary_conditions(timesteps)
+                c_skip, c_out = scalings_for_boundary_conditions(
+                    timesteps, timestep_scaling=args.timestep_scaling_factor
+                )
                 c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]

                 # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
@@ -60,6 +60,7 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.optimization import get_scheduler
+from diffusers.training_utils import resolve_interpolation_mode
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available

@@ -147,6 +148,7 @@ class SDText2ImageDataset:
         global_batch_size: int,
         num_workers: int,
         resolution: int = 512,
+        interpolation_type: str = "bilinear",
         shuffle_buffer_size: int = 1000,
         pin_memory: bool = False,
         persistent_workers: bool = False,
@@ -156,10 +158,12 @@ class SDText2ImageDataset:
         # flatten list using itertools
         train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url))

+        interpolation_mode = resolve_interpolation_mode(interpolation_type)
+
         def transform(example):
             # resize image
             image = example["image"]
-            image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+            image = TF.resize(image, resolution, interpolation=interpolation_mode)

             # get crop coordinates and crop image
             c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))

@@ -330,8 +334,9 @@ def append_dims(x, target_dims):

 # From LCMScheduler.get_scalings_for_boundary_condition_discrete
 def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
-    c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2)
-    c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5
+    scaled_timestep = timestep_scaling * timestep
+    c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
+    c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
     return c_skip, c_out

@@ -549,6 +554,15 @@ def parse_args():
             " resolution"
         ),
     )
+    parser.add_argument(
+        "--interpolation_type",
+        type=str,
+        default="bilinear",
+        help=(
+            "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`,"
+            " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
+        ),
+    )
     parser.add_argument(
         "--center_crop",
         default=False,
@@ -690,6 +704,26 @@ def parse_args():
             " does not have `time_cond_proj_dim` set."
         ),
     )
+    parser.add_argument(
+        "--vae_encode_batch_size",
+        type=int,
+        default=32,
+        required=False,
+        help=(
+            "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE."
+            " Encoding or decoding the whole batch at once may run into OOM issues."
+        ),
+    )
+    parser.add_argument(
+        "--timestep_scaling_factor",
+        type=float,
+        default=10.0,
+        help=(
+            "The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The"
+            " higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically"
+            " suffice."
+        ),
+    )
     # ----Exponential Moving Average (EMA)----
     parser.add_argument(
         "--ema_decay",
@@ -1034,6 +1068,7 @@ def main(args):
         global_batch_size=args.train_batch_size * accelerator.num_processes,
         num_workers=args.dataloader_num_workers,
         resolution=args.resolution,
+        interpolation_type=args.interpolation_type,
         shuffle_buffer_size=1000,
         pin_memory=True,
         persistent_workers=True,
@@ -1145,10 +1180,10 @@ def main(args):
                 if vae.dtype != weight_dtype:
                     vae.to(dtype=weight_dtype)

-                # encode pixel values with batch size of at most 32
+                # encode pixel values with batch size of at most args.vae_encode_batch_size
                 latents = []
-                for i in range(0, pixel_values.shape[0], 32):
-                    latents.append(vae.encode(pixel_values[i : i + 32]).latent_dist.sample())
+                for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size):
+                    latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample())
                 latents = torch.cat(latents, dim=0)

                 latents = latents * vae.config.scaling_factor
@@ -1164,9 +1199,13 @@ def main(args):
                 timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)

                 # 3. Get boundary scalings for start_timesteps and (end) timesteps.
-                c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
+                c_skip_start, c_out_start = scalings_for_boundary_conditions(
+                    start_timesteps, timestep_scaling=args.timestep_scaling_factor
+                )
                 c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
-                c_skip, c_out = scalings_for_boundary_conditions(timesteps)
+                c_skip, c_out = scalings_for_boundary_conditions(
+                    timesteps, timestep_scaling=args.timestep_scaling_factor
+                )
                 c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]

                 # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
@@ -61,6 +61,7 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.optimization import get_scheduler
+from diffusers.training_utils import resolve_interpolation_mode
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available

@@ -153,6 +154,7 @@ class SDXLText2ImageDataset:
         global_batch_size: int,
         num_workers: int,
         resolution: int = 1024,
+        interpolation_type: str = "bilinear",
         shuffle_buffer_size: int = 1000,
         pin_memory: bool = False,
         persistent_workers: bool = False,
@@ -169,10 +171,12 @@ class SDXLText2ImageDataset:
             else:
                 return (int(json.get(WDS_JSON_WIDTH, 0.0)), int(json.get(WDS_JSON_HEIGHT, 0.0)))

+        interpolation_mode = resolve_interpolation_mode(interpolation_type)
+
         def transform(example):
             # resize image
             image = example["image"]
-            image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+            image = TF.resize(image, resolution, interpolation=interpolation_mode)

             # get crop coordinates and crop image
             c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))

@@ -318,8 +322,9 @@ def append_dims(x, target_dims):

 # From LCMScheduler.get_scalings_for_boundary_condition_discrete
 def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
-    c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2)
-    c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5
+    scaled_timestep = timestep_scaling * timestep
+    c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
+    c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
     return c_skip, c_out

@@ -568,6 +573,15 @@ def parse_args():
             " resolution"
         ),
     )
+    parser.add_argument(
+        "--interpolation_type",
+        type=str,
+        default="bilinear",
+        help=(
+            "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`,"
+            " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
+        ),
+    )
     parser.add_argument(
         "--use_fix_crop_and_size",
         action="store_true",
@@ -715,6 +729,26 @@ def parse_args():
             " does not have `time_cond_proj_dim` set."
         ),
     )
+    parser.add_argument(
+        "--vae_encode_batch_size",
+        type=int,
+        default=8,
+        required=False,
+        help=(
+            "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE."
+            " Encoding or decoding the whole batch at once may run into OOM issues."
+        ),
+    )
+    parser.add_argument(
+        "--timestep_scaling_factor",
+        type=float,
+        default=10.0,
+        help=(
+            "The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The"
+            " higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically"
+            " suffice."
+        ),
+    )
     # ----Exponential Moving Average (EMA)----
     parser.add_argument(
         "--ema_decay",
@@ -1118,6 +1152,7 @@ def main(args):
         global_batch_size=args.train_batch_size * accelerator.num_processes,
         num_workers=args.dataloader_num_workers,
         resolution=args.resolution,
+        interpolation_type=args.interpolation_type,
         shuffle_buffer_size=1000,
         pin_memory=True,
         persistent_workers=True,
@@ -1242,10 +1277,10 @@ def main(args):
                 else:
                     pixel_values = image

-                # encode pixel values with batch size of at most 8
+                # encode pixel values with batch size of at most args.vae_encode_batch_size
                 latents = []
-                for i in range(0, pixel_values.shape[0], 8):
-                    latents.append(vae.encode(pixel_values[i : i + 8]).latent_dist.sample())
+                for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size):
+                    latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample())
                 latents = torch.cat(latents, dim=0)

                 latents = latents * vae.config.scaling_factor
@@ -1262,9 +1297,13 @@ def main(args):
                 timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)

                 # 3. Get boundary scalings for start_timesteps and (end) timesteps.
-                c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
+                c_skip_start, c_out_start = scalings_for_boundary_conditions(
+                    start_timesteps, timestep_scaling=args.timestep_scaling_factor
+                )
                 c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
-                c_skip, c_out = scalings_for_boundary_conditions(timesteps)
+                c_skip, c_out = scalings_for_boundary_conditions(
+                    timesteps, timestep_scaling=args.timestep_scaling_factor
+                )
                 c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]

                 # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
@@ -5,6 +5,7 @@ from typing import Any, Dict, Iterable, Optional, Union

 import numpy as np
 import torch
+from torchvision import transforms

 from .models import UNet2DConditionModel
 from .utils import deprecate, is_transformers_available
@@ -53,6 +54,45 @@ def compute_snr(noise_scheduler, timesteps):
     return snr


+def resolve_interpolation_mode(interpolation_type: str):
+    """
+    Maps a string describing an interpolation function to the corresponding torchvision `InterpolationMode` enum. The
+    full list of supported enums is documented at
+    https://pytorch.org/vision/0.9/transforms.html#torchvision.transforms.functional.InterpolationMode.
+
+    Args:
+        interpolation_type (`str`):
+            A string describing an interpolation method. Currently, `bilinear`, `bicubic`, `box`, `nearest`,
+            `nearest_exact`, `hamming`, and `lanczos` are supported, corresponding to the supported interpolation modes
+            in torchvision.
+
+    Returns:
+        `torchvision.transforms.InterpolationMode`: an `InterpolationMode` enum used by torchvision's `resize`
+        transform.
+    """
+    if interpolation_type == "bilinear":
+        interpolation_mode = transforms.InterpolationMode.BILINEAR
+    elif interpolation_type == "bicubic":
+        interpolation_mode = transforms.InterpolationMode.BICUBIC
+    elif interpolation_type == "box":
+        interpolation_mode = transforms.InterpolationMode.BOX
+    elif interpolation_type == "nearest":
+        interpolation_mode = transforms.InterpolationMode.NEAREST
+    elif interpolation_type == "nearest_exact":
+        interpolation_mode = transforms.InterpolationMode.NEAREST_EXACT
+    elif interpolation_type == "hamming":
+        interpolation_mode = transforms.InterpolationMode.HAMMING
+    elif interpolation_type == "lanczos":
+        interpolation_mode = transforms.InterpolationMode.LANCZOS
+    else:
+        raise ValueError(
+            f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation"
+            f" modes are `bilinear`, `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
+        )
+
+    return interpolation_mode
+
+
 def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
     r"""
     Returns:
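A quick sanity check of the new utility (a sketch; "area" is deliberately outside the supported set to show the error path):

from diffusers.training_utils import resolve_interpolation_mode

print(resolve_interpolation_mode("bicubic"))  # InterpolationMode.BICUBIC

try:
    resolve_interpolation_mode("area")
except ValueError as err:
    print(err)  # lists the supported mode strings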