Make ControlNet SD Training Script torch.compile compatible (#6525)
* update: make controlnet script torch compile compatible
* update: correct earlier mistakes for compilation
* update: fix code style issues

Signed-off-by: Suvaditya Mukherjee <suvadityamuk@gmail.com>
Commit f486d34b04 (parent e44b205e0b) in huggingface/diffusers, committed via GitHub.
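For background, torch.compile (PyTorch 2.x) does not modify a module in place; it wraps it in a torch._dynamo.OptimizedModule that keeps the original module on the _orig_mod attribute. The changes below exist so the script can see through that wrapper when checking dtypes and saving weights. A minimal sketch of the wrapping behavior, assuming PyTorch >= 2.0:

import torch

# torch.compile wraps the module rather than rewriting it; the original
# module stays reachable via `_orig_mod`.
net = torch.nn.Linear(4, 4)
compiled = torch.compile(net)

print(type(compiled).__name__)      # OptimizedModule
print(compiled._orig_mod is net)    # True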
@@ -50,6 +50,7 @@ from diffusers import (
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module


 if is_wandb_available():
@@ -787,6 +788,12 @@ def main(args):
         logger.info("Initializing controlnet weights from unet")
         controlnet = ControlNetModel.from_unet(unet)

+    # Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files)
+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     # `accelerate` 0.16.0 will have better support for customized saving
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
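A self-contained sketch of what the new unwrap_model helper does, using a toy nn.Linear in place of the ControlNet and skipping the accelerator.unwrap_model step (which only strips distributed-training wrappers):

import torch
from diffusers.utils.torch_utils import is_compiled_module

def unwrap_model_sketch(model):
    # Same logic as the helper in the hunk above, minus the accelerate step.
    return model._orig_mod if is_compiled_module(model) else model

net = torch.nn.Linear(4, 4)
assert unwrap_model_sketch(net) is net                 # plain module: unchanged
assert unwrap_model_sketch(torch.compile(net)) is net  # compile wrapper peeled off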
@@ -846,9 +853,9 @@ def main(args):
             " doing mixed precision training, copy of the weights should still be float32."
         )

-    if accelerator.unwrap_model(controlnet).dtype != torch.float32:
+    if unwrap_model(controlnet).dtype != torch.float32:
         raise ValueError(
-            f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}"
+            f"Controlnet loaded as datatype {unwrap_model(controlnet).dtype}. {low_precision_error_string}"
        )

     # Enable TF32 for faster training on Ampere GPUs,
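The TF32 comment in the surrounding context refers to a standard PyTorch switch that this commit leaves unchanged; for reference:

import torch

# On Ampere and newer GPUs, TF32 speeds up float32 matmuls at a small
# precision cost; PyTorch requires opting in explicitly.
torch.backends.cuda.matmul.allow_tf32 = True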
@@ -1015,7 +1022,7 @@ def main(args):
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                 # Get the text embedding for conditioning
-                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]

                 controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)

@@ -1036,7 +1043,8 @@ def main(args):
                         sample.to(dtype=weight_dtype) for sample in down_block_res_samples
                     ],
                     mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
-                ).sample
+                    return_dict=False,
+                )[0]

                 # Get the target for loss depending on the prediction type
                 if noise_scheduler.config.prediction_type == "epsilon":
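Both return_dict=False changes (the text encoder in the previous hunk, the UNet here) follow one pattern: by default these models return an output dataclass, and requesting a plain tuple instead avoids graph breaks under torch.compile, with [0] replacing attribute access such as .sample. A self-contained sketch using a hypothetical ToyModel, not the real UNet API:

from dataclasses import dataclass

import torch

@dataclass
class ToyOutput:
    sample: torch.Tensor

class ToyModel(torch.nn.Module):
    def forward(self, x, return_dict=True):
        out = x * 2.0
        if not return_dict:
            return (out,)               # plain tuple: traces cleanly
        return ToyOutput(sample=out)    # output dataclass: may graph-break

model = ToyModel()
x = torch.ones(2)
# `.sample` on the dataclass and `[0]` on the tuple are equivalent:
assert torch.equal(model(x).sample, model(x, return_dict=False)[0])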
@@ -1109,7 +1117,7 @@ def main(args):
     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        controlnet = accelerator.unwrap_model(controlnet)
+        controlnet = unwrap_model(controlnet)
         controlnet.save_pretrained(args.output_dir)

         if args.push_to_hub: