mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
apple mps: training support for SDXL (ControlNet, LoRA, Dreambooth, T2I) (#7447)
* apple mps: training support for SDXL LoRA * sdxl: support training lora, dreambooth, t2i, pix2pix, and controlnet on apple mps --------- Co-authored-by: bghira <bghira@users.github.com> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -125,7 +125,11 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
|
||||
)
|
||||
|
||||
image_logs = []
|
||||
inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast("cuda")
|
||||
inference_ctx = (
|
||||
contextlib.nullcontext()
|
||||
if (is_final_validation or torch.backends.mps.is_available())
|
||||
else torch.autocast("cuda")
|
||||
)
|
||||
|
||||
for validation_prompt, validation_image in zip(validation_prompts, validation_images):
|
||||
validation_image = Image.open(validation_image).convert("RGB")
|
||||
@@ -792,6 +796,12 @@ def main(args):
|
||||
|
||||
logging_dir = Path(args.output_dir, args.logging_dir)
|
||||
|
||||
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
|
||||
# due to pytorch#99272, MPS does not yet support bfloat16.
|
||||
raise ValueError(
|
||||
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
|
||||
)
|
||||
|
||||
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
||||
|
||||
accelerator = Accelerator(
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import gc
|
||||
import itertools
|
||||
import json
|
||||
@@ -208,11 +207,18 @@ def log_validation(
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
# Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
|
||||
# way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
|
||||
inference_ctx = (
|
||||
contextlib.nullcontext() if "playground" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast()
|
||||
)
|
||||
enable_autocast = True
|
||||
if torch.backends.mps.is_available() or (
|
||||
accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
|
||||
):
|
||||
enable_autocast = False
|
||||
if "playground" in args.pretrained_model_name_or_path:
|
||||
enable_autocast = False
|
||||
|
||||
with inference_ctx:
|
||||
with torch.autocast(
|
||||
accelerator.device.type,
|
||||
enabled=enable_autocast,
|
||||
):
|
||||
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
@@ -230,7 +236,8 @@ def log_validation(
|
||||
)
|
||||
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return images
|
||||
|
||||
@@ -967,6 +974,12 @@ def main(args):
|
||||
if args.do_edm_style_training and args.snr_gamma is not None:
|
||||
raise ValueError("Min-SNR formulation is not supported when conducting EDM-style training.")
|
||||
|
||||
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
|
||||
# due to pytorch#99272, MPS does not yet support bfloat16.
|
||||
raise ValueError(
|
||||
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
|
||||
)
|
||||
|
||||
logging_dir = Path(args.output_dir, args.logging_dir)
|
||||
|
||||
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
||||
@@ -1009,7 +1022,8 @@ def main(args):
|
||||
cur_class_images = len(list(class_images_dir.iterdir()))
|
||||
|
||||
if cur_class_images < args.num_class_images:
|
||||
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
|
||||
has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
|
||||
torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
|
||||
if args.prior_generation_precision == "fp32":
|
||||
torch_dtype = torch.float32
|
||||
elif args.prior_generation_precision == "fp16":
|
||||
@@ -1134,6 +1148,12 @@ def main(args):
|
||||
elif accelerator.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
|
||||
if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
|
||||
# due to pytorch#99272, MPS does not yet support bfloat16.
|
||||
raise ValueError(
|
||||
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
|
||||
)
|
||||
|
||||
# Move unet, vae and text_encoder to device and cast to weight_dtype
|
||||
unet.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
@@ -1278,7 +1298,7 @@ def main(args):
|
||||
|
||||
# Enable TF32 for faster training on Ampere GPUs,
|
||||
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
||||
if args.allow_tf32:
|
||||
if args.allow_tf32 and torch.cuda.is_available():
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
if args.scale_lr:
|
||||
@@ -1455,7 +1475,8 @@ def main(args):
|
||||
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
|
||||
del tokenizers, text_encoders
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
|
||||
# pack the statically computed variables appropriately here. This is so that we don't
|
||||
|
||||
@@ -71,12 +71,7 @@ TORCH_DTYPE_MAPPING = {"fp32": torch.float32, "fp16": torch.float16, "bf16": tor
|
||||
|
||||
|
||||
def log_validation(
|
||||
pipeline,
|
||||
args,
|
||||
accelerator,
|
||||
generator,
|
||||
global_step,
|
||||
is_final_validation=False,
|
||||
pipeline, args, accelerator, generator, global_step, is_final_validation=False, enable_autocast=True
|
||||
):
|
||||
logger.info(
|
||||
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
||||
@@ -96,7 +91,7 @@ def log_validation(
|
||||
else Image.open(image_url_or_path).convert("RGB")
|
||||
)(args.val_image_url_or_path)
|
||||
|
||||
with torch.autocast(str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"):
|
||||
with torch.autocast(accelerator.device.type, enabled=enable_autocast):
|
||||
edited_images = []
|
||||
# Run inference
|
||||
for val_img_idx in range(args.num_validation_images):
|
||||
@@ -497,6 +492,13 @@ def main():
|
||||
),
|
||||
)
|
||||
logging_dir = os.path.join(args.output_dir, args.logging_dir)
|
||||
|
||||
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
|
||||
# due to pytorch#99272, MPS does not yet support bfloat16.
|
||||
raise ValueError(
|
||||
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
|
||||
)
|
||||
|
||||
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
@@ -981,6 +983,13 @@ def main():
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("instruct-pix2pix-xl", config=vars(args))
|
||||
|
||||
# Some configurations require autocast to be disabled.
|
||||
enable_autocast = True
|
||||
if torch.backends.mps.is_available() or (
|
||||
accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
|
||||
):
|
||||
enable_autocast = False
|
||||
|
||||
# Train!
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
@@ -1193,6 +1202,7 @@ def main():
|
||||
generator,
|
||||
global_step,
|
||||
is_final_validation=False,
|
||||
enable_autocast=enable_autocast,
|
||||
)
|
||||
|
||||
if args.use_ema:
|
||||
@@ -1242,6 +1252,7 @@ def main():
|
||||
generator,
|
||||
global_step,
|
||||
is_final_validation=True,
|
||||
enable_autocast=enable_autocast,
|
||||
)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
@@ -501,6 +501,12 @@ def main(args):
|
||||
|
||||
logging_dir = Path(args.output_dir, args.logging_dir)
|
||||
|
||||
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
|
||||
# due to pytorch#99272, MPS does not yet support bfloat16.
|
||||
raise ValueError(
|
||||
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
|
||||
)
|
||||
|
||||
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
||||
kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
accelerator = Accelerator(
|
||||
@@ -973,6 +979,13 @@ def main(args):
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("text2image-fine-tune", config=vars(args))
|
||||
|
||||
# Some configurations require autocast to be disabled.
|
||||
enable_autocast = True
|
||||
if torch.backends.mps.is_available() or (
|
||||
accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
|
||||
):
|
||||
enable_autocast = False
|
||||
|
||||
# Train!
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
@@ -1199,7 +1212,10 @@ def main(args):
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
pipeline_args = {"prompt": args.validation_prompt}
|
||||
|
||||
with torch.cuda.amp.autocast():
|
||||
with torch.autocast(
|
||||
accelerator.device.type,
|
||||
enabled=enable_autocast,
|
||||
):
|
||||
images = [
|
||||
pipeline(**pipeline_args, generator=generator).images[0]
|
||||
for _ in range(args.num_validation_images)
|
||||
|
||||
@@ -590,6 +590,12 @@ def main(args):
|
||||
|
||||
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
||||
|
||||
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
|
||||
# due to pytorch#99272, MPS does not yet support bfloat16.
|
||||
raise ValueError(
|
||||
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
mixed_precision=args.mixed_precision,
|
||||
@@ -980,6 +986,13 @@ def main(args):
|
||||
model = model._orig_mod if is_compiled_module(model) else model
|
||||
return model
|
||||
|
||||
# Some configurations require autocast to be disabled.
|
||||
enable_autocast = True
|
||||
if torch.backends.mps.is_available() or (
|
||||
accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
|
||||
):
|
||||
enable_autocast = False
|
||||
|
||||
# Train!
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
@@ -1213,7 +1226,10 @@ def main(args):
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
pipeline_args = {"prompt": args.validation_prompt}
|
||||
|
||||
with torch.cuda.amp.autocast():
|
||||
with torch.autocast(
|
||||
accelerator.device.type,
|
||||
enabled=enable_autocast,
|
||||
):
|
||||
images = [
|
||||
pipeline(**pipeline_args, generator=generator, num_inference_steps=25).images[0]
|
||||
for _ in range(args.num_validation_images)
|
||||
@@ -1268,7 +1284,7 @@ def main(args):
|
||||
if args.validation_prompt and args.num_validation_images > 0:
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
with torch.cuda.amp.autocast():
|
||||
with torch.autocast(accelerator.device.type, enabled=enable_autocast):
|
||||
images = [
|
||||
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
|
||||
for _ in range(args.num_validation_images)
|
||||
|
||||
Reference in New Issue
Block a user