mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
committed by
GitHub
parent
0fc2fb71c1
commit
262d539a8a
@@ -1211,7 +1211,7 @@ def main(args):
|
||||
text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
|
||||
)
|
||||
|
||||
if unet.config.in_channels == channels * 2:
|
||||
if accelerator.unwrap_model(unet).config.in_channels == channels * 2:
|
||||
noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
|
||||
|
||||
if args.class_labels_conditioning == "timesteps":
|
||||
|
||||
@@ -1156,7 +1156,7 @@ def main(args):
|
||||
text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
|
||||
)
|
||||
|
||||
if unet.config.in_channels == channels * 2:
|
||||
if accelerator.unwrap_model(unet).config.in_channels == channels * 2:
|
||||
noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
|
||||
|
||||
if args.class_labels_conditioning == "timesteps":
|
||||
|
||||
Reference in New Issue
Block a user