mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Improve LCM(-LoRA) Distillation Scripts (#6420)
* Make WDS pipeline interpolation type configurable.
* Make the VAE encoding batch size configurable.
* Make lora_alpha and lora_dropout configurable for LCM LoRA scripts.
* Generalize scalings_for_boundary_conditions function and make the timestep scaling configurable.
* Make LoRA target modules configurable for LCM-LoRA scripts.
* Move resolve_interpolation_mode to src/diffusers/training_utils.py and make interpolation type configurable in non-WDS script.
* Apply suggestions from review.
This commit is contained in:
@@ -5,6 +5,7 @@ from typing import Any, Dict, Iterable, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
|
||||
from .models import UNet2DConditionModel
|
||||
from .utils import deprecate, is_transformers_available
|
||||
@@ -53,6 +54,45 @@ def compute_snr(noise_scheduler, timesteps):
|
||||
return snr
|
||||
|
||||
|
||||
def resolve_interpolation_mode(interpolation_type: str):
    """
    Translate an interpolation name into the matching torchvision `InterpolationMode` enum member. The enum is
    documented at
    https://pytorch.org/vision/0.9/transforms.html#torchvision.transforms.functional.InterpolationMode.

    Args:
        interpolation_type (`str`):
            Name of the interpolation method. Accepted values are `bilinear`, `bicubic`, `box`, `nearest`,
            `nearest_exact`, `hamming`, and `lanczos`. Matching is exact (case-sensitive).

    Returns:
        `torchvision.transforms.InterpolationMode`: the enum member to hand to torchvision's `resize` transform.

    Raises:
        ValueError: if `interpolation_type` is not one of the accepted names.
    """
    # (accepted name, attribute name on `transforms.InterpolationMode`), checked in this order.
    known_modes = (
        ("bilinear", "BILINEAR"),
        ("bicubic", "BICUBIC"),
        ("box", "BOX"),
        ("nearest", "NEAREST"),
        ("nearest_exact", "NEAREST_EXACT"),
        ("hamming", "HAMMING"),
        ("lanczos", "LANCZOS"),
    )

    for accepted_name, member_name in known_modes:
        # Plain equality (rather than a dict lookup) so that any input, hashable or not, that matches nothing
        # falls through to the ValueError below.
        if interpolation_type == accepted_name:
            # The enum attribute is only touched once a name has matched.
            return getattr(transforms.InterpolationMode, member_name)

    raise ValueError(
        f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation"
        f" modes are `bilinear`, `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
    )
|
||||
|
||||
|
||||
def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
Reference in New Issue
Block a user