mirror of https://github.com/huggingface/diffusers.git
Fix InstructPix2Pix training in multi-GPU mode (#2978)
* fix: norm group test for UNet3D.
* fix: unet rejig.
* fix: unwrapping when running validation inputs.
* unwrapping the unet too.
* fix: device.
* better unwrapping.
* unwrapping before ema.
* unwrapping.
@@ -451,19 +451,18 @@ def main():
     # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized
     # from the pre-trained checkpoints. For the extra channels added to the first layer, they are
     # initialized to zero.
-    if accelerator.is_main_process:
-        logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.")
-        in_channels = 8
-        out_channels = unet.conv_in.out_channels
-        unet.register_to_config(in_channels=in_channels)
+    logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.")
+    in_channels = 8
+    out_channels = unet.conv_in.out_channels
+    unet.register_to_config(in_channels=in_channels)
 
-        with torch.no_grad():
-            new_conv_in = nn.Conv2d(
-                in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
-            )
-            new_conv_in.weight.zero_()
-            new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
-            unet.conv_in = new_conv_in
+    with torch.no_grad():
+        new_conv_in = nn.Conv2d(
+            in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
+        )
+        new_conv_in.weight.zero_()
+        new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
+        unet.conv_in = new_conv_in
 
     # Freeze vae and text_encoder
     vae.requires_grad_(False)
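Note: the hunk above drops the `accelerator.is_main_process` guard so that every rank rebuilds `conv_in` identically; if only rank 0 modified the UNet, the model graphs would diverge across processes under distributed training. The following is a minimal standalone sketch of the same zero-init channel expansion, assuming a stock 4-latent-channel Stable Diffusion UNet (the checkpoint id is illustrative, not taken from this commit):

import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel

# Illustrative checkpoint; any SD 1.x UNet with a 4-channel conv_in works.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

in_channels = 8  # 4 noisy-latent channels + 4 conditioning-image channels
out_channels = unet.conv_in.out_channels
unet.register_to_config(in_channels=in_channels)  # keep the saved config in sync

with torch.no_grad():
    new_conv_in = nn.Conv2d(
        in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
    )
    new_conv_in.weight.zero_()  # the extra input channels contribute nothing at init
    new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)  # pretrained weights preserved
    unet.conv_in = new_conv_in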
@@ -892,9 +891,12 @@ def main():
                     # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
                     ema_unet.store(unet.parameters())
                     ema_unet.copy_to(unet.parameters())
+                # The models need unwrapping for compatibility in distributed training mode.
                 pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
-                    unet=unet,
+                    unet=accelerator.unwrap_model(unet),
+                    text_encoder=accelerator.unwrap_model(text_encoder),
+                    vae=accelerator.unwrap_model(vae),
                     revision=args.revision,
                     torch_dtype=weight_dtype,
                 )
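The unwrapping matters because, under a multi-GPU launch, `accelerator.prepare()` can wrap each model in `torch.nn.parallel.DistributedDataParallel`, and `from_pretrained(unet=...)` needs the bare module back. A tiny sketch, with an `nn.Linear` standing in for the UNet:

import torch.nn as nn
from accelerate import Accelerator

accelerator = Accelerator()
model = accelerator.prepare(nn.Linear(4, 4))  # may come back wrapped in DDP
bare = accelerator.unwrap_model(model)        # always the underlying nn.Module
assert isinstance(bare, nn.Linear)            # holds with or without a distributed launch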
@@ -904,7 +906,9 @@ def main():
                 # run inference
                 original_image = download_image(args.val_image_url)
                 edited_images = []
-                with torch.autocast(str(accelerator.device), enabled=accelerator.mixed_precision == "fp16"):
+                with torch.autocast(
+                    str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
+                ):
                     for _ in range(args.num_validation_images):
                         edited_images.append(
                             pipeline(
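The `.replace(":0", "")` exists because `torch.autocast` accepts a device type such as "cuda" or "cpu", while `str(accelerator.device)` can render as "cuda:0" on multi-GPU setups and make autocast raise. A quick sketch of the normalization, assuming the process sits on device index 0:

import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device_type = str(device).replace(":0", "")  # "cuda:0" -> "cuda"; "cpu" is unchanged
with torch.autocast(device_type, enabled=torch.cuda.is_available()):
    pass  # validation inference would run here

Equivalently, `device.type` yields the device type for any index, which avoids the string surgery on ranks other than 0.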
@@ -959,7 +963,7 @@ def main():
         if args.validation_prompt is not None:
             edited_images = []
             pipeline = pipeline.to(accelerator.device)
-            with torch.autocast(str(accelerator.device)):
+            with torch.autocast(str(accelerator.device).replace(":0", "")):
                 for _ in range(args.num_validation_images):
                     edited_images.append(
                         pipeline(