mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Research] Latent Perceptual Loss (LPL) for Stable Diffusion XL (#11573)
* initial * added readme * fix formatting * added logging * formatting * use config * debug * better * handle SNR * floats have no item() * remove debug * formatting * add paper link * acknowledge reference source * rename script --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
157
examples/research_projects/lpl/README.md
Normal file
157
examples/research_projects/lpl/README.md
Normal file
@@ -0,0 +1,157 @@
|
||||
# Latent Perceptual Loss (LPL) for Stable Diffusion XL
|
||||
|
||||
This directory contains an implementation of Latent Perceptual Loss (LPL) for training Stable Diffusion XL models, based on the paper: [Boosting Latent Diffusion with Perceptual Objectives](https://huggingface.co/papers/2411.04873) (Berrada et al., 2025). LPL is a perceptual loss that operates in the latent space of a VAE, helping to improve the quality and consistency of generated images by bridging the disconnect between the diffusion model and the autoencoder decoder. The implementation is based on the reference implementation provided by Tariq Berrada.
|
||||
|
||||
## Overview
|
||||
|
||||
LPL addresses a key limitation in latent diffusion models (LDMs): the disconnect between the diffusion model training and the autoencoder decoder. While LDMs train in the latent space, they don't receive direct feedback about how well their outputs decode into high-quality images. This can lead to:
|
||||
|
||||
- Loss of fine details in generated images
|
||||
- Inconsistent image quality
|
||||
- Structural artifacts
|
||||
- Reduced sharpness and realism
|
||||
|
||||
LPL works by comparing intermediate features from the VAE decoder between the predicted and target latents. This helps the model learn better perceptual features and can lead to:
|
||||
|
||||
- Improved image quality and consistency (6-20% FID improvement)
|
||||
- Better preservation of fine details
|
||||
- More stable training, especially at high noise levels
|
||||
- Better handling of structural information
|
||||
- Sharper and more realistic textures
|
||||
|
||||
## Implementation Details
|
||||
|
||||
The LPL implementation follows the paper's methodology and includes several key features:
|
||||
|
||||
1. **Feature Extraction**: Extracts intermediate features from the VAE decoder, including:
|
||||
- Middle block features
|
||||
- Up block features (configurable number of blocks)
|
||||
- Proper gradient checkpointing for memory efficiency
|
||||
- Features are extracted only for timesteps below the threshold (high SNR)
|
||||
|
||||
2. **Feature Normalization**: Multiple normalization options as validated in the paper:
|
||||
- `default`: Normalize each feature map independently
|
||||
- `shared`: Cross-normalize features using target statistics (recommended)
|
||||
- `batch`: Batch-wise normalization
|
||||
|
||||
3. **Outlier Handling**: Optional removal of outliers in feature maps using:
|
||||
- Quantile-based filtering (2% quantiles)
|
||||
- Morphological operations (opening/closing)
|
||||
- Adaptive thresholding based on standard deviation
|
||||
|
||||
4. **Loss Types**:
|
||||
- MSE loss (default)
|
||||
- L1 loss
|
||||
- Optional power law weighting (2^(-i) for layer i)
|
||||
|
||||
## Usage
|
||||
|
||||
To use LPL in your training, add the following arguments to your training command:
|
||||
|
||||
```bash
|
||||
python examples/research_projects/lpl/train_sdxl_lpl.py \
|
||||
--use_lpl \
|
||||
--lpl_weight 1.0 \ # Weight for LPL loss (1.0-2.0 recommended)
|
||||
--lpl_t_threshold 200 \ # Apply LPL only for timesteps < threshold (high SNR)
|
||||
--lpl_loss_type mse \ # Loss type: "mse" or "l1"
|
||||
--lpl_norm_type shared \ # Normalization type: "default", "shared" (recommended), or "batch"
|
||||
--lpl_pow_law \ # Use power law weighting for layers
|
||||
--lpl_num_blocks 4 \ # Number of up blocks to use (1-4)
|
||||
--lpl_remove_outliers \ # Remove outliers in feature maps
|
||||
--lpl_scale \ # Scale LPL loss by noise level weights
|
||||
--lpl_start 0 \ # Step to start applying LPL
|
||||
# ... other training arguments ...
|
||||
```
|
||||
|
||||
### Key Parameters
|
||||
|
||||
- `lpl_weight`: Controls the strength of the LPL loss relative to the main diffusion loss. Higher values (1.0-2.0) improve quality but may slow training.
|
||||
- `lpl_t_threshold`: LPL is only applied for timesteps below this threshold (high SNR). Lower values (100-200) focus on more important timesteps.
|
||||
- `lpl_loss_type`: Choose between MSE (default) and L1 loss. MSE is recommended for most cases.
|
||||
- `lpl_norm_type`: Feature normalization strategy. "shared" is recommended as it showed best results in the paper.
|
||||
- `lpl_pow_law`: Whether to use power law weighting (2^(-i) for layer i). Recommended for better feature balance.
|
||||
- `lpl_num_blocks`: Number of up blocks to use for feature extraction (1-4). More blocks capture more features but use more memory.
|
||||
- `lpl_remove_outliers`: Whether to remove outliers in feature maps. Recommended for stable training.
|
||||
- `lpl_scale`: Whether to scale LPL loss by noise level weights. Helps focus on more important timesteps.
|
||||
- `lpl_start`: Training step to start applying LPL. Can be used to warm up training.
|
||||
|
||||
## Recommendations
|
||||
|
||||
1. **Starting Point** (based on paper results):
|
||||
```bash
|
||||
--use_lpl \
|
||||
--lpl_weight 1.0 \
|
||||
--lpl_t_threshold 200 \
|
||||
--lpl_loss_type mse \
|
||||
--lpl_norm_type shared \
|
||||
--lpl_pow_law \
|
||||
--lpl_num_blocks 4 \
|
||||
--lpl_remove_outliers \
|
||||
--lpl_scale
|
||||
```
|
||||
|
||||
2. **Memory Efficiency**:
|
||||
- Use `--gradient_checkpointing` for memory efficiency (enabled by default)
|
||||
- Reduce `lpl_num_blocks` if memory is constrained (2-3 blocks still give good results)
|
||||
- Consider using `--lpl_scale` to focus on more important timesteps
|
||||
- Features are extracted only for timesteps below threshold to save memory
|
||||
|
||||
3. **Quality vs Speed**:
|
||||
- Higher `lpl_weight` (1.0-2.0) for better quality
|
||||
- Lower `lpl_t_threshold` (100-200) for faster training
|
||||
- Use `lpl_remove_outliers` for more stable training
|
||||
- `lpl_norm_type shared` provides best quality/speed trade-off
|
||||
|
||||
## Technical Details
|
||||
|
||||
### Feature Extraction
|
||||
|
||||
The LPL implementation extracts features from the VAE decoder in the following order:
|
||||
1. Middle block output
|
||||
2. Up block outputs (configurable number of blocks)
|
||||
|
||||
Each feature map is processed with:
|
||||
1. Optional outlier removal (2% quantiles, morphological operations)
|
||||
2. Feature normalization (shared statistics recommended)
|
||||
3. Loss calculation (MSE or L1)
|
||||
4. Optional power law weighting (2^(-i) for layer i)
|
||||
|
||||
### Loss Calculation
|
||||
|
||||
For each feature map:
|
||||
1. Features are normalized according to the chosen strategy
|
||||
2. Loss is calculated between normalized features
|
||||
3. Outliers are masked out (if enabled)
|
||||
4. Loss is weighted by layer depth (if power law enabled)
|
||||
5. Final loss is averaged across all layers
|
||||
|
||||
### Memory Considerations
|
||||
|
||||
- Gradient checkpointing is used by default
|
||||
- Features are extracted only for timesteps below the threshold
|
||||
- Outlier removal is done in-place to save memory
|
||||
- Feature normalization is done efficiently using vectorized operations
|
||||
- Memory usage scales linearly with number of blocks used
|
||||
|
||||
## Results
|
||||
|
||||
Based on the paper's findings, LPL provides:
|
||||
- 6-20% improvement in FID scores
|
||||
- Better preservation of fine details
|
||||
- More realistic textures and structures
|
||||
- Improved consistency across different resolutions
|
||||
- Better performance on both small and large datasets
|
||||
|
||||
## Citation
|
||||
|
||||
If you use this implementation in your research, please cite:
|
||||
|
||||
```bibtex
|
||||
@inproceedings{berrada2025boosting,
|
||||
title={Boosting Latent Diffusion with Perceptual Objectives},
|
||||
author={Tariq Berrada and Pietro Astolfi and Melissa Hall and Marton Havasi and Yohann Benchetrit and Adriana Romero-Soriano and Karteek Alahari and Michal Drozdzal and Jakob Verbeek},
|
||||
booktitle={The Thirteenth International Conference on Learning Representations},
|
||||
year={2025},
|
||||
url={https://openreview.net/forum?id=y4DtzADzd1}
|
||||
}
|
||||
```
|
||||
215
examples/research_projects/lpl/lpl_loss.py
Normal file
215
examples/research_projects/lpl/lpl_loss.py
Normal file
@@ -0,0 +1,215 @@
|
||||
# Copyright 2025 Berrada et al.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def normalize_tensor(in_feat, eps=1e-10):
|
||||
norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1, keepdim=True))
|
||||
return in_feat / (norm_factor + eps)
|
||||
|
||||
|
||||
def cross_normalize(input, target, eps=1e-10):
|
||||
norm_factor = torch.sqrt(torch.sum(target**2, dim=1, keepdim=True))
|
||||
return input / (norm_factor + eps), target / (norm_factor + eps)
|
||||
|
||||
|
||||
def remove_outliers(feat, down_f=1, opening=5, closing=3, m=100, quant=0.02):
|
||||
opening = int(np.ceil(opening / down_f))
|
||||
closing = int(np.ceil(closing / down_f))
|
||||
if opening == 2:
|
||||
opening = 3
|
||||
if closing == 2:
|
||||
closing = 1
|
||||
|
||||
# replace quantile with kth value here.
|
||||
feat_flat = feat.flatten(-2, -1)
|
||||
k1, k2 = int(feat_flat.shape[-1] * quant), int(feat_flat.shape[-1] * (1 - quant))
|
||||
q1 = feat_flat.kthvalue(k1, dim=-1).values[..., None, None]
|
||||
q2 = feat_flat.kthvalue(k2, dim=-1).values[..., None, None]
|
||||
|
||||
m = 2 * feat_flat.std(-1)[..., None, None].detach()
|
||||
mask = (q1 - m < feat) * (feat < q2 + m)
|
||||
|
||||
# dilate the mask.
|
||||
mask = nn.MaxPool2d(kernel_size=closing, stride=1, padding=(closing - 1) // 2)(mask.float()) # closing
|
||||
mask = (-nn.MaxPool2d(kernel_size=opening, stride=1, padding=(opening - 1) // 2)(-mask)).bool() # opening
|
||||
feat = feat * mask
|
||||
return mask, feat
|
||||
|
||||
|
||||
class LatentPerceptualLoss(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vae,
|
||||
loss_type="mse",
|
||||
grad_ckpt=True,
|
||||
pow_law=False,
|
||||
norm_type="default",
|
||||
num_mid_blocks=4,
|
||||
feature_type="feature",
|
||||
remove_outliers=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.vae = vae
|
||||
self.decoder = self.vae.decoder
|
||||
# Store scaling factors as tensors on the correct device
|
||||
device = next(self.vae.parameters()).device
|
||||
|
||||
# Get scaling factors with proper defaults and handle None values
|
||||
scale_factor = getattr(self.vae.config, "scaling_factor", None)
|
||||
shift_factor = getattr(self.vae.config, "shift_factor", None)
|
||||
|
||||
# Convert to tensors with proper defaults
|
||||
self.scale = torch.tensor(1.0 if scale_factor is None else scale_factor, device=device)
|
||||
self.shift = torch.tensor(0.0 if shift_factor is None else shift_factor, device=device)
|
||||
|
||||
self.gradient_checkpointing = grad_ckpt
|
||||
self.pow_law = pow_law
|
||||
self.norm_type = norm_type.lower()
|
||||
self.outlier_mask = remove_outliers
|
||||
self.last_feature_stats = [] # Store feature statistics for logging
|
||||
|
||||
assert feature_type in ["feature", "image"]
|
||||
self.feature_type = feature_type
|
||||
|
||||
assert self.norm_type in ["default", "shared", "batch"]
|
||||
assert num_mid_blocks >= 0 and num_mid_blocks <= 4
|
||||
self.n_blocks = num_mid_blocks
|
||||
|
||||
assert loss_type in ["mse", "l1"]
|
||||
if loss_type == "mse":
|
||||
self.loss_fn = nn.MSELoss(reduction="none")
|
||||
elif loss_type == "l1":
|
||||
self.loss_fn = nn.L1Loss(reduction="none")
|
||||
|
||||
def get_features(self, z, latent_embeds=None, disable_grads=False):
|
||||
with torch.set_grad_enabled(not disable_grads):
|
||||
if self.gradient_checkpointing and not disable_grads:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
features = []
|
||||
upscale_dtype = next(iter(self.decoder.up_blocks.parameters())).dtype
|
||||
sample = z
|
||||
sample = self.decoder.conv_in(sample)
|
||||
|
||||
# middle
|
||||
sample = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.decoder.mid_block),
|
||||
sample,
|
||||
latent_embeds,
|
||||
use_reentrant=False,
|
||||
)
|
||||
sample = sample.to(upscale_dtype)
|
||||
features.append(sample)
|
||||
|
||||
# up
|
||||
for up_block in self.decoder.up_blocks[: self.n_blocks]:
|
||||
sample = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(up_block),
|
||||
sample,
|
||||
latent_embeds,
|
||||
use_reentrant=False,
|
||||
)
|
||||
features.append(sample)
|
||||
return features
|
||||
else:
|
||||
features = []
|
||||
upscale_dtype = next(iter(self.decoder.up_blocks.parameters())).dtype
|
||||
sample = z
|
||||
sample = self.decoder.conv_in(sample)
|
||||
|
||||
# middle
|
||||
sample = self.decoder.mid_block(sample, latent_embeds)
|
||||
sample = sample.to(upscale_dtype)
|
||||
features.append(sample)
|
||||
|
||||
# up
|
||||
for up_block in self.decoder.up_blocks[: self.n_blocks]:
|
||||
sample = up_block(sample, latent_embeds)
|
||||
features.append(sample)
|
||||
return features
|
||||
|
||||
def get_loss(self, input, target, get_hist=False):
|
||||
if self.feature_type == "feature":
|
||||
inp_f = self.get_features(self.shift + input / self.scale)
|
||||
tar_f = self.get_features(self.shift + target / self.scale, disable_grads=True)
|
||||
losses = []
|
||||
self.last_feature_stats = [] # Reset feature stats
|
||||
|
||||
for i, (x, y) in enumerate(zip(inp_f, tar_f, strict=False)):
|
||||
my = torch.ones_like(y).bool()
|
||||
outlier_ratio = 0.0
|
||||
|
||||
if self.outlier_mask:
|
||||
with torch.no_grad():
|
||||
if i == 2:
|
||||
my, y = remove_outliers(y, down_f=2)
|
||||
outlier_ratio = 1.0 - my.float().mean().item()
|
||||
elif i in [3, 4, 5]:
|
||||
my, y = remove_outliers(y, down_f=1)
|
||||
outlier_ratio = 1.0 - my.float().mean().item()
|
||||
|
||||
# Store feature statistics before normalization
|
||||
with torch.no_grad():
|
||||
stats = {
|
||||
"mean": y.mean().item(),
|
||||
"std": y.std().item(),
|
||||
"outlier_ratio": outlier_ratio,
|
||||
}
|
||||
self.last_feature_stats.append(stats)
|
||||
|
||||
# normalize feature tensors
|
||||
if self.norm_type == "default":
|
||||
x = normalize_tensor(x)
|
||||
y = normalize_tensor(y)
|
||||
elif self.norm_type == "shared":
|
||||
x, y = cross_normalize(x, y, eps=1e-6)
|
||||
|
||||
term_loss = self.loss_fn(x, y) * my
|
||||
# reduce loss term
|
||||
loss_f = 2 ** (-min(i, 3)) if self.pow_law else 1.0
|
||||
term_loss = term_loss.sum((2, 3)) * loss_f / my.sum((2, 3))
|
||||
losses.append(term_loss.mean((1,)))
|
||||
|
||||
if get_hist:
|
||||
return losses
|
||||
else:
|
||||
loss = sum(losses)
|
||||
return loss / len(inp_f)
|
||||
elif self.feature_type == "image":
|
||||
inp_f = self.vae.decode(input / self.scale).sample
|
||||
tar_f = self.vae.decode(target / self.scale).sample
|
||||
return F.mse_loss(inp_f, tar_f)
|
||||
|
||||
def get_first_conv(self, z):
|
||||
sample = self.decoder.conv_in(z)
|
||||
return sample
|
||||
|
||||
def get_first_block(self, z):
|
||||
sample = self.decoder.conv_in(z)
|
||||
sample = self.decoder.mid_block(sample)
|
||||
for resnet in self.decoder.up_blocks[0].resnets:
|
||||
sample = resnet(sample, None)
|
||||
return sample
|
||||
|
||||
def get_first_layer(self, input, target, target_layer="conv"):
|
||||
if target_layer == "conv":
|
||||
feat_in = self.get_first_conv(input)
|
||||
with torch.no_grad():
|
||||
feat_tar = self.get_first_conv(target)
|
||||
else:
|
||||
feat_in = self.get_first_block(input)
|
||||
with torch.no_grad():
|
||||
feat_tar = self.get_first_block(target)
|
||||
|
||||
feat_in, feat_tar = cross_normalize(feat_in, feat_tar)
|
||||
|
||||
return F.mse_loss(feat_in, feat_tar, reduction="mean")
|
||||
1622
examples/research_projects/lpl/train_sdxl_lpl.py
Normal file
1622
examples/research_projects/lpl/train_sdxl_lpl.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user