1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Make Dreambooth SD Training Script torch.compile compatible (#6532)

* support compile

* make style

* move unwrap_model inside function

* change unwrap call

* run make style

* Update examples/dreambooth/train_dreambooth.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Revert "Update examples/dreambooth/train_dreambooth.py"

This reverts commit 70ab09732e.

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
Vinh H. Pham
2024-01-12 14:20:15 +07:00
committed by GitHub
parent 33d2b5b087
commit 7d631825b0

View File

@@ -55,6 +55,7 @@ from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
if is_wandb_available():
@@ -129,15 +130,12 @@ def log_validation(
if vae is not None:
pipeline_args["vae"] = vae
if text_encoder is not None:
text_encoder = accelerator.unwrap_model(text_encoder)
# create pipeline (note: unet and vae are loaded again in float32)
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
tokenizer=tokenizer,
text_encoder=text_encoder,
unet=accelerator.unwrap_model(unet),
unet=unet,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
@@ -794,6 +792,7 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte
prompt_embeds = text_encoder(
text_input_ids,
attention_mask=attention_mask,
return_dict=False,
)
prompt_embeds = prompt_embeds[0]
@@ -931,11 +930,16 @@ def main(args):
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
return model
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
for model in models:
sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder"
sub_dir = "unet" if isinstance(model, type(unwrap_model(unet))) else "text_encoder"
model.save_pretrained(os.path.join(output_dir, sub_dir))
# make sure to pop weight so that corresponding model is not saved again
@@ -946,7 +950,7 @@ def main(args):
# pop models so that they are not loaded again
model = models.pop()
if isinstance(model, type(accelerator.unwrap_model(text_encoder))):
if isinstance(model, type(unwrap_model(text_encoder))):
# load transformers style into model
load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
model.config = load_model.config
@@ -991,15 +995,12 @@ def main(args):
" doing mixed precision training. copy of the weights should still be float32."
)
if accelerator.unwrap_model(unet).dtype != torch.float32:
raise ValueError(
f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
)
if unwrap_model(unet).dtype != torch.float32:
raise ValueError(f"Unet loaded as datatype {unwrap_model(unet).dtype}. {low_precision_error_string}")
if args.train_text_encoder and accelerator.unwrap_model(text_encoder).dtype != torch.float32:
if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32:
raise ValueError(
f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}."
f" {low_precision_error_string}"
f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}"
)
# Enable TF32 for faster training on Ampere GPUs,
@@ -1246,7 +1247,7 @@ def main(args):
text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
)
if accelerator.unwrap_model(unet).config.in_channels == channels * 2:
if unwrap_model(unet).config.in_channels == channels * 2:
noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
if args.class_labels_conditioning == "timesteps":
@@ -1256,8 +1257,8 @@ def main(args):
# Predict the noise residual
model_pred = unet(
noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels
).sample
noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels, return_dict=False
)[0]
if model_pred.shape[1] == 6:
model_pred, _ = torch.chunk(model_pred, 2, dim=1)
@@ -1350,9 +1351,9 @@ def main(args):
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
images = log_validation(
text_encoder,
unwrap_model(text_encoder) if text_encoder is not None else text_encoder,
tokenizer,
unet,
unwrap_model(unet),
vae,
args,
accelerator,
@@ -1375,14 +1376,14 @@ def main(args):
pipeline_args = {}
if text_encoder is not None:
pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder)
pipeline_args["text_encoder"] = unwrap_model(text_encoder)
if args.skip_save_text_encoder:
pipeline_args["text_encoder"] = None
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet),
unet=unwrap_model(unet),
revision=args.revision,
variant=args.variant,
**pipeline_args,