1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

add Min-SNR loss to Controlnet flax train script (#3016)

* add wandb team and min-snr loss

* make style

* apply feedbacks
This commit is contained in:
YiYi Xu
2023-04-09 16:26:54 -10:00
committed by Daniel Gu
parent 5326243241
commit 7ee2817ae2
2 changed files with 36 additions and 6 deletions

View File

@@ -408,4 +408,8 @@ You can then start your training from this saved checkpoint with
```bash
--controlnet_model_name_or_path="./control_out/500"
```
```
We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556) which helps to achieve faster convergence by rebalancing the loss. To use it, one needs to set the `--snr_gamma` argument. The recommended value when using it is `5.0`.
We also support gradient accumulation - it is a technique that lets you use a bigger batch size than your machine would normally be able to fit into memory. You can use `gradient_accumulation_steps` argument to set gradient accumulation steps. The ControlNet author recommends using gradient accumulation to achieve better convergence. Read more [here](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md#more-consideration-sudden-converge-phenomenon-and-gradient-accumulation).

View File

@@ -289,6 +289,13 @@ def parse_args():
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--snr_gamma",
type=float,
default=None,
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
"More details here: https://arxiv.org/abs/2303.09556.",
)
parser.add_argument(
"--dataloader_num_workers",
type=int,
@@ -328,11 +335,8 @@ def parse_args():
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
default="wandb",
help=('The integration to report the results and logs to. Currently only supported platforms are `"wandb"`'),
)
parser.add_argument(
"--mixed_precision",
@@ -442,6 +446,7 @@ def parse_args():
" `args.validation_prompt` and logging the images."
),
)
parser.add_argument("--wandb_entity", type=str, default=None, help=("The wandb entity to use (for teams)."))
parser.add_argument(
"--tracker_project_name",
type=str,
@@ -668,6 +673,7 @@ def main():
# wandb init
if jax.process_index() == 0 and args.report_to == "wandb":
wandb.init(
entity=args.wandb_entity,
project=args.tracker_project_name,
job_type="train",
config=args,
@@ -806,6 +812,20 @@ def main():
validation_rng, train_rngs = jax.random.split(rng)
train_rngs = jax.random.split(train_rngs, jax.local_device_count())
def compute_snr(timesteps):
"""
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
"""
alphas_cumprod = noise_scheduler_state.common.alphas_cumprod
sqrt_alphas_cumprod = alphas_cumprod**0.5
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
alpha = sqrt_alphas_cumprod[timesteps]
sigma = sqrt_one_minus_alphas_cumprod[timesteps]
# Compute SNR.
snr = (alpha / sigma) ** 2
return snr
def train_step(state, unet_params, text_encoder_params, vae_params, batch, train_rng):
# reshape batch, add grad_step_dim if gradient_accumulation_steps > 1
if args.gradient_accumulation_steps > 1:
@@ -876,6 +896,12 @@ def main():
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
loss = (target - model_pred) ** 2
if args.snr_gamma is not None:
snr = jnp.array(compute_snr(timesteps))
snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
loss = loss * snr_loss_weights
loss = loss.mean()
return loss