From 262d539a8a8f505dc72958f7ea50915a4b56dfac Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 5 Jun 2023 12:03:11 +0200 Subject: [PATCH] Correct multi gpu dreambooth (#3673) Correct multi gpu --- examples/dreambooth/train_dreambooth.py | 2 +- examples/dreambooth/train_dreambooth_lora.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index ad03829fd1..97b7f334bc 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -1211,7 +1211,7 @@ def main(args): text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, ) - if unet.config.in_channels == channels * 2: + if accelerator.unwrap_model(unet).config.in_channels == channels * 2: noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1) if args.class_labels_conditioning == "timesteps": diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 49aef1cc4a..ca25152fcb 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -1156,7 +1156,7 @@ def main(args): text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, ) - if unet.config.in_channels == channels * 2: + if accelerator.unwrap_model(unet).config.in_channels == channels * 2: noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1) if args.class_labels_conditioning == "timesteps":