From d78acdedc1d248406455c6abcfa4224cd59f9ae7 Mon Sep 17 00:00:00 2001
From: Bagheera <59658056+bghira@users.noreply.github.com>
Date: Thu, 28 Mar 2024 02:56:18 -0600
Subject: [PATCH] apple mps: training support for SDXL (ControlNet, LoRA,
 Dreambooth, T2I) (#7447)

* apple mps: training support for SDXL LoRA

* sdxl: support training lora, dreambooth, t2i, pix2pix, and controlnet on apple mps

---------

Co-authored-by: bghira
Co-authored-by: Sayak Paul
---
 examples/controlnet/train_controlnet_sdxl.py  | 12 +++++-
 .../dreambooth/train_dreambooth_lora_sdxl.py  | 39 ++++++++++++++-----
 .../train_instruct_pix2pix_sdxl.py            | 25 ++++++++----
 .../train_text_to_image_lora_sdxl.py          | 18 ++++++++-
 .../text_to_image/train_text_to_image_sdxl.py | 20 +++++++++-
 5 files changed, 94 insertions(+), 20 deletions(-)

diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py
index 47ed405af3..b602805235 100644
--- a/examples/controlnet/train_controlnet_sdxl.py
+++ b/examples/controlnet/train_controlnet_sdxl.py
@@ -125,7 +125,11 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
     )

     image_logs = []
-    inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast("cuda")
+    inference_ctx = (
+        contextlib.nullcontext()
+        if (is_final_validation or torch.backends.mps.is_available())
+        else torch.autocast("cuda")
+    )

     for validation_prompt, validation_image in zip(validation_prompts, validation_images):
         validation_image = Image.open(validation_image).convert("RGB")
@@ -792,6 +796,12 @@ def main(args):

     logging_dir = Path(args.output_dir, args.logging_dir)

+    if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+        # due to pytorch#99272, MPS does not yet support bfloat16.
+        raise ValueError(
+            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+        )
+
     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

     accelerator = Accelerator(
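Note on the controlnet change above: validation inference previously always wrapped generation in torch.autocast("cuda"), which cannot work on machines without CUDA. The patch instead falls back to a no-op context on MPS, mirroring the existing final-validation path. A minimal standalone sketch of the same pattern (the helper name make_inference_ctx is illustrative, not part of the patch):

import contextlib

import torch


def make_inference_ctx(is_final_validation: bool):
    # Use a no-op context when CUDA autocast is unavailable (MPS) or
    # unwanted (final validation); otherwise autocast on CUDA.
    if is_final_validation or torch.backends.mps.is_available():
        return contextlib.nullcontext()
    return torch.autocast("cuda")
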
diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 364e1423f0..1da83ff731 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and

 import argparse
-import contextlib
 import gc
 import itertools
 import json
@@ -208,11 +207,18 @@ def log_validation(
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
     # Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
     # way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
-    inference_ctx = (
-        contextlib.nullcontext() if "playground" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast()
-    )
+    enable_autocast = True
+    if torch.backends.mps.is_available() or (
+        accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
+    ):
+        enable_autocast = False
+    if "playground" in args.pretrained_model_name_or_path:
+        enable_autocast = False

-    with inference_ctx:
+    with torch.autocast(
+        accelerator.device.type,
+        enabled=enable_autocast,
+    ):
         images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]

     for tracker in accelerator.trackers:
@@ -230,7 +236,8 @@ def log_validation(
         )

     del pipeline
-    torch.cuda.empty_cache()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()

     return images

@@ -967,6 +974,12 @@ def main(args):
     if args.do_edm_style_training and args.snr_gamma is not None:
         raise ValueError("Min-SNR formulation is not supported when conducting EDM-style training.")

+    if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+        # due to pytorch#99272, MPS does not yet support bfloat16.
+        raise ValueError(
+            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+        )
+
     logging_dir = Path(args.output_dir, args.logging_dir)

     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
@@ -1009,7 +1022,8 @@ def main(args):
         cur_class_images = len(list(class_images_dir.iterdir()))

         if cur_class_images < args.num_class_images:
-            torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+            has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
+            torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
             if args.prior_generation_precision == "fp32":
                 torch_dtype = torch.float32
             elif args.prior_generation_precision == "fp16":
@@ -1134,6 +1148,12 @@ def main(args):
     elif accelerator.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16

+    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
+        # due to pytorch#99272, MPS does not yet support bfloat16.
+        raise ValueError(
+            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+        )
+
     # Move unet, vae and text_encoder to device and cast to weight_dtype
     unet.to(accelerator.device, dtype=weight_dtype)

@@ -1278,7 +1298,7 @@ def main(args):

     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
-    if args.allow_tf32:
+    if args.allow_tf32 and torch.cuda.is_available():
         torch.backends.cuda.matmul.allow_tf32 = True

     if args.scale_lr:
@@ -1455,7 +1475,8 @@ def main(args):
     if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
         del tokenizers, text_encoders
         gc.collect()
-        torch.cuda.empty_cache()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()

     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
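Note on the dreambooth change above: the ad-hoc inference_ctx is replaced by an explicit enable_autocast flag fed to torch.autocast(device_type, enabled=...). The decision logic, factored into a free function for clarity (a sketch only; the script inlines this, and should_enable_autocast is a hypothetical name):

import torch


def should_enable_autocast(mixed_precision: str, model_path: str) -> bool:
    # Autocast is redundant when Accelerate already runs fp16/bf16 mixed
    # precision, is unsupported on MPS, and is skipped for Playground
    # models (see the PR #7126 comment referenced in the diff).
    if torch.backends.mps.is_available():
        return False
    if mixed_precision in ("fp16", "bf16"):
        return False
    if "playground" in model_path:
        return False
    return True
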
diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
index 36517e8ff6..aff279963a 100644
--- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
+++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
@@ -71,12 +71,7 @@ TORCH_DTYPE_MAPPING = {"fp32": torch.float32, "fp16": torch.float16, "bf16": tor


 def log_validation(
-    pipeline,
-    args,
-    accelerator,
-    generator,
-    global_step,
-    is_final_validation=False,
+    pipeline, args, accelerator, generator, global_step, is_final_validation=False, enable_autocast=True
 ):
     logger.info(
         f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
@@ -96,7 +91,7 @@ def log_validation(
         else Image.open(image_url_or_path).convert("RGB")
     )(args.val_image_url_or_path)

-    with torch.autocast(str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"):
+    with torch.autocast(accelerator.device.type, enabled=enable_autocast):
         edited_images = []
         # Run inference
         for val_img_idx in range(args.num_validation_images):
@@ -497,6 +492,13 @@ def main():
         ),
     )
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
+
+    if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+        # due to pytorch#99272, MPS does not yet support bfloat16.
+        raise ValueError(
+            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+        )
+
     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
@@ -981,6 +983,13 @@ def main():
     if accelerator.is_main_process:
         accelerator.init_trackers("instruct-pix2pix-xl", config=vars(args))

+    # Some configurations require autocast to be disabled.
+    enable_autocast = True
+    if torch.backends.mps.is_available() or (
+        accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
+    ):
+        enable_autocast = False
+
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

@@ -1193,6 +1202,7 @@ def main():
                             generator,
                             global_step,
                             is_final_validation=False,
+                            enable_autocast=enable_autocast,
                         )

                     if args.use_ema:
@@ -1242,6 +1252,7 @@ def main():
                 generator,
                 global_step,
                 is_final_validation=True,
+                enable_autocast=enable_autocast,
             )

     accelerator.end_training()
diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py
index f1d8e1b093..c9860b744c 100644
--- a/examples/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -501,6 +501,12 @@ def main(args):

     logging_dir = Path(args.output_dir, args.logging_dir)

+    if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+        # due to pytorch#99272, MPS does not yet support bfloat16.
+        raise ValueError(
+            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+        )
+
     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
     kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
     accelerator = Accelerator(
@@ -973,6 +979,13 @@ def main(args):
     if accelerator.is_main_process:
         accelerator.init_trackers("text2image-fine-tune", config=vars(args))

+    # Some configurations require autocast to be disabled.
+    enable_autocast = True
+    if torch.backends.mps.is_available() or (
+        accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
+    ):
+        enable_autocast = False
+
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

@@ -1199,7 +1212,10 @@ def main(args):
                 generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
                 pipeline_args = {"prompt": args.validation_prompt}

-                with torch.cuda.amp.autocast():
+                with torch.autocast(
+                    accelerator.device.type,
+                    enabled=enable_autocast,
+                ):
                     images = [
                         pipeline(**pipeline_args, generator=generator).images[0]
                         for _ in range(args.num_validation_images)
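Note: the same bf16-on-MPS guard now appears near the top of main() in all five scripts. Because diffusers examples are intentionally standalone, the check is duplicated inline rather than shared; factored out it would amount to this sketch (validate_mps_precision is a hypothetical name, not part of the patch):

import torch


def validate_mps_precision(mixed_precision: str) -> None:
    # pytorch#99272: the MPS backend does not yet support bfloat16.
    if torch.backends.mps.is_available() and mixed_precision == "bf16":
        raise ValueError(
            "Mixed precision training with bfloat16 is not supported on MPS. "
            "Please use fp16 (recommended) or fp32 instead."
        )
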
diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py
index cb1feb806c..c141f5bdd7 100644
--- a/examples/text_to_image/train_text_to_image_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_sdxl.py
@@ -590,6 +590,12 @@ def main(args):

     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

+    if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+        # due to pytorch#99272, MPS does not yet support bfloat16.
+        raise ValueError(
+            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+        )
+
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
@@ -980,6 +986,13 @@ def main(args):
             model = model._orig_mod if is_compiled_module(model) else model
             return model

+    # Some configurations require autocast to be disabled.
+    enable_autocast = True
+    if torch.backends.mps.is_available() or (
+        accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
+    ):
+        enable_autocast = False
+
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

@@ -1213,7 +1226,10 @@ def main(args):
                 generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
                 pipeline_args = {"prompt": args.validation_prompt}

-                with torch.cuda.amp.autocast():
+                with torch.autocast(
+                    accelerator.device.type,
+                    enabled=enable_autocast,
+                ):
                     images = [
                         pipeline(**pipeline_args, generator=generator, num_inference_steps=25).images[0]
                         for _ in range(args.num_validation_images)
@@ -1268,7 +1284,7 @@ def main(args):
         if args.validation_prompt and args.num_validation_images > 0:
             pipeline = pipeline.to(accelerator.device)
             generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
-            with torch.cuda.amp.autocast():
+            with torch.autocast(accelerator.device.type, enabled=enable_autocast):
                 images = [
                     pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
                     for _ in range(args.num_validation_images)
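Closing note: a recurring theme in the patch is guarding CUDA-only calls (torch.cuda.empty_cache(), the TF32 flag) behind torch.cuda.is_available() so the scripts run cleanly on MPS-only machines. A device-agnostic cleanup helper in the same spirit (a sketch; free_accelerator_memory is a hypothetical name, and torch.mps.empty_cache() is assumed to exist in your PyTorch build, having shipped around PyTorch 2.0):

import torch


def free_accelerator_memory() -> None:
    # Release cached allocator blocks on whichever backend is active.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available():
        torch.mps.empty_cache()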