1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Improve LCM(-LoRA) Distillation Scripts (#6420)

* Make WDS pipeline interpolation type configurable.

* Make the VAE encoding batch size configurable.

* Make lora_alpha and lora_dropout configurable for LCM LoRA scripts.

* Generalize scalings_for_boundary_conditions function and make the timestep scaling configurable.

* Make LoRA target modules configurable for LCM-LoRA scripts.

* Move resolve_interpolation_mode to src/diffusers/training_utils.py and make interpolation type configurable in non-WDS script.

* Apply suggestions from review.
This commit is contained in:
dg845
2024-01-04 17:25:13 -08:00
committed by GitHub
parent acd926f4f2
commit f3d1333e02
6 changed files with 374 additions and 53 deletions

View File

@@ -5,6 +5,7 @@ from typing import Any, Dict, Iterable, Optional, Union
import numpy as np
import torch
from torchvision import transforms
from .models import UNet2DConditionModel
from .utils import deprecate, is_transformers_available
@@ -53,6 +54,45 @@ def compute_snr(noise_scheduler, timesteps):
return snr
def resolve_interpolation_mode(interpolation_type: str) -> "transforms.InterpolationMode":
    """
    Maps a string describing an interpolation function to the corresponding torchvision `InterpolationMode` enum. The
    full list of supported enums is documented at
    https://pytorch.org/vision/0.9/transforms.html#torchvision.transforms.functional.InterpolationMode.

    Args:
        interpolation_type (`str`):
            A string describing an interpolation method. Currently, `bilinear`, `bicubic`, `box`, `nearest`,
            `nearest_exact`, `hamming`, and `lanczos` are supported, corresponding to the supported interpolation modes
            in torchvision. Matching is exact and case-sensitive.

    Returns:
        `torchvision.transforms.InterpolationMode`: an `InterpolationMode` enum used by torchvision's `resize`
        transform.

    Raises:
        `ValueError`: If `interpolation_type` is not one of the supported strings.
    """
    # Single source of truth: supported string -> name of the matching `InterpolationMode` member. The error message
    # below is generated from this table so the two cannot drift apart.
    member_names = {
        "bilinear": "BILINEAR",
        "bicubic": "BICUBIC",
        "box": "BOX",
        "nearest": "NEAREST",
        "nearest_exact": "NEAREST_EXACT",
        "hamming": "HAMMING",
        "lanczos": "LANCZOS",
    }

    try:
        member_name = member_names[interpolation_type]
    except (KeyError, TypeError):
        # `TypeError` covers unhashable inputs; they are reported as unsupported like any other unknown value.
        *leading, last = (f"`{name}`" for name in member_names)
        supported = ", ".join(leading) + f", and {last}"
        raise ValueError(
            f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation"
            f" modes are {supported}."
        ) from None

    # Look the member up lazily so that only the requested mode has to exist on `transforms.InterpolationMode`.
    return getattr(transforms.InterpolationMode, member_name)
def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
r"""
Returns: