1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Correct multi gpu dreambooth (#3673)

Correct multi gpu
This commit is contained in:
Patrick von Platen
2023-06-05 12:03:11 +02:00
committed by GitHub
parent 0fc2fb71c1
commit 262d539a8a
2 changed files with 2 additions and 2 deletions

View File

@@ -1211,7 +1211,7 @@ def main(args):
text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
)
if unet.config.in_channels == channels * 2:
if accelerator.unwrap_model(unet).config.in_channels == channels * 2:
noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
if args.class_labels_conditioning == "timesteps":

View File

@@ -1156,7 +1156,7 @@ def main(args):
text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
)
if unet.config.in_channels == channels * 2:
if accelerator.unwrap_model(unet).config.in_channels == channels * 2:
noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
if args.class_labels_conditioning == "timesteps":