mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
add Min-SNR loss to Controlnet flax train script (#3016)
* add wandb team and min-snr loss * make style * apply feedbacks
This commit is contained in:
@@ -408,4 +408,8 @@ You can then start your training from this saved checkpoint with
|
||||
|
||||
```bash
|
||||
--controlnet_model_name_or_path="./control_out/500"
|
||||
```
|
||||
```
|
||||
|
||||
We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556) which helps to achieve faster convergence by rebalancing the loss. To use it, one needs to set the `--snr_gamma` argument. The recommended value when using it is `5.0`.
|
||||
|
||||
We also support gradient accumulation - it is a technique that lets you use a bigger batch size than your machine would normally be able to fit into memory. You can use `gradient_accumulation_steps` argument to set gradient accumulation steps. The ControlNet author recommends using gradient accumulation to achieve better convergence. Read more [here](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md#more-consideration-sudden-converge-phenomenon-and-gradient-accumulation).
|
||||
@@ -289,6 +289,13 @@ def parse_args():
|
||||
' "constant", "constant_with_warmup"]'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--snr_gamma",
|
||||
type=float,
|
||||
default=None,
|
||||
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
|
||||
"More details here: https://arxiv.org/abs/2303.09556.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataloader_num_workers",
|
||||
type=int,
|
||||
@@ -328,11 +335,8 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--report_to",
|
||||
type=str,
|
||||
default="tensorboard",
|
||||
help=(
|
||||
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
||||
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
||||
),
|
||||
default="wandb",
|
||||
help=('The integration to report the results and logs to. Currently only supported platforms are `"wandb"`'),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
@@ -442,6 +446,7 @@ def parse_args():
|
||||
" `args.validation_prompt` and logging the images."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--wandb_entity", type=str, default=None, help=("The wandb entity to use (for teams)."))
|
||||
parser.add_argument(
|
||||
"--tracker_project_name",
|
||||
type=str,
|
||||
@@ -668,6 +673,7 @@ def main():
|
||||
# wandb init
|
||||
if jax.process_index() == 0 and args.report_to == "wandb":
|
||||
wandb.init(
|
||||
entity=args.wandb_entity,
|
||||
project=args.tracker_project_name,
|
||||
job_type="train",
|
||||
config=args,
|
||||
@@ -806,6 +812,20 @@ def main():
|
||||
validation_rng, train_rngs = jax.random.split(rng)
|
||||
train_rngs = jax.random.split(train_rngs, jax.local_device_count())
|
||||
|
||||
def compute_snr(timesteps):
|
||||
"""
|
||||
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
|
||||
"""
|
||||
alphas_cumprod = noise_scheduler_state.common.alphas_cumprod
|
||||
sqrt_alphas_cumprod = alphas_cumprod**0.5
|
||||
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
|
||||
|
||||
alpha = sqrt_alphas_cumprod[timesteps]
|
||||
sigma = sqrt_one_minus_alphas_cumprod[timesteps]
|
||||
# Compute SNR.
|
||||
snr = (alpha / sigma) ** 2
|
||||
return snr
|
||||
|
||||
def train_step(state, unet_params, text_encoder_params, vae_params, batch, train_rng):
|
||||
# reshape batch, add grad_step_dim if gradient_accumulation_steps > 1
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
@@ -876,6 +896,12 @@ def main():
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
|
||||
loss = (target - model_pred) ** 2
|
||||
|
||||
if args.snr_gamma is not None:
|
||||
snr = jnp.array(compute_snr(timesteps))
|
||||
snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
|
||||
loss = loss * snr_loss_weights
|
||||
|
||||
loss = loss.mean()
|
||||
|
||||
return loss
|
||||
|
||||
Reference in New Issue
Block a user