support compile
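In brief: the DreamBooth training script in huggingface/diffusers gains a small unwrap_model helper that strips both Accelerate's wrappers and torch.compile's OptimizedModule wrapper, every direct accelerator.unwrap_model call is routed through it, and the text-encoder and UNet forward passes switch to return_dict=False so they return plain tuples.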
@@ -55,6 +55,7 @@ from diffusers.optimization import get_scheduler
 from diffusers.training_utils import compute_snr
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module


 if is_wandb_available():
@@ -106,6 +107,10 @@ DreamBooth for the text encoder was enabled: {train_text_encoder}.
     with open(os.path.join(repo_folder, "README.md"), "w") as f:
         f.write(yaml + model_card)

+def unwrap_model(accelerator, model):
+    model = accelerator.unwrap_model(model)
+    model = model._orig_mod if is_compiled_module(model) else model
+    return model

 def log_validation(
     text_encoder,
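The helper matters because torch.compile does not modify a module in place: it returns an OptimizedModule wrapper that keeps the real module on the _orig_mod attribute. A minimal, self-contained sketch of that behavior (not part of the commit; a plain nn.Linear stands in for the UNet):

import torch
import torch.nn as nn

net = nn.Linear(4, 4)
compiled = torch.compile(net)

print(type(compiled).__name__)          # OptimizedModule
print(isinstance(compiled, nn.Linear))  # False: the wrapper hides the original class
print(compiled._orig_mod is net)        # True: the original module is recoverable

That False isinstance result is what the hook changes below guard against: the type of a compiled UNet is OptimizedModule, not UNet2DConditionModel, until _orig_mod is unwrapped.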
@@ -130,14 +135,14 @@ def log_validation(
         pipeline_args["vae"] = vae

     if text_encoder is not None:
-        text_encoder = accelerator.unwrap_model(text_encoder)
+        text_encoder = unwrap_model(accelerator, text_encoder)

     # create pipeline (note: unet and vae are loaded again in float32)
     pipeline = DiffusionPipeline.from_pretrained(
         args.pretrained_model_name_or_path,
         tokenizer=tokenizer,
         text_encoder=text_encoder,
-        unet=accelerator.unwrap_model(unet),
+        unet=unwrap_model(accelerator, unet),
         revision=args.revision,
         variant=args.variant,
         torch_dtype=weight_dtype,
@@ -794,6 +799,7 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask
     prompt_embeds = text_encoder(
         text_input_ids,
         attention_mask=attention_mask,
+        return_dict=False,
     )
     prompt_embeds = prompt_embeds[0]

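With return_dict=False, the text encoder returns a plain tuple instead of a ModelOutput object, which tends to trace more cleanly under torch.compile; prompt_embeds[0] selects the last hidden state either way. Equivalently, as a one-liner (illustrative, using the script's own names):

prompt_embeds = text_encoder(text_input_ids, attention_mask=attention_mask, return_dict=False)[0]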
@@ -935,7 +941,7 @@ def main(args):
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
             for model in models:
-                sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder"
+                sub_dir = "unet" if isinstance(model, type(unwrap_model(accelerator, unet))) else "text_encoder"
                 model.save_pretrained(os.path.join(output_dir, sub_dir))

                 # make sure to pop weight so that corresponding model is not saved again
@@ -946,7 +952,7 @@ def main(args):
             # pop models so that they are not loaded again
             model = models.pop()

-            if isinstance(model, type(accelerator.unwrap_model(text_encoder))):
+            if isinstance(model, type(unwrap_model(accelerator, text_encoder))):
                 # load transformers style into model
                 load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
                 model.config = load_model.config
@@ -991,14 +997,14 @@ def main(args):
         " doing mixed precision training. copy of the weights should still be float32."
     )

-    if accelerator.unwrap_model(unet).dtype != torch.float32:
+    if unwrap_model(accelerator, unet).dtype != torch.float32:
         raise ValueError(
-            f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
+            f"Unet loaded as datatype {unwrap_model(accelerator, unet).dtype}. {low_precision_error_string}"
         )

-    if args.train_text_encoder and accelerator.unwrap_model(text_encoder).dtype != torch.float32:
+    if args.train_text_encoder and unwrap_model(accelerator, text_encoder).dtype != torch.float32:
         raise ValueError(
-            f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}."
+            f"Text encoder loaded as datatype {unwrap_model(accelerator, text_encoder).dtype}."
             f" {low_precision_error_string}"
         )

@@ -1246,7 +1252,7 @@ def main(args):
                         text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
                     )

-                if accelerator.unwrap_model(unet).config.in_channels == channels * 2:
+                if unwrap_model(accelerator, unet).config.in_channels == channels * 2:
                     noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)

                 if args.class_labels_conditioning == "timesteps":
@@ -1256,8 +1262,8 @@ def main(args):

                 # Predict the noise residual
                 model_pred = unet(
-                    noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels
-                ).sample
+                    noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels, return_dict=False
+                )[0]

                 if model_pred.shape[1] == 6:
                     model_pred, _ = torch.chunk(model_pred, 2, dim=1)
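Same pattern for the UNet: with return_dict=False, UNet2DConditionModel returns a tuple whose first element is the prediction, so tuple indexing replaces the .sample attribute access. Both forms yield the same tensor (illustrative, using the script's names):

# before: ModelOutput attribute access
model_pred = unet(noisy_model_input, timesteps, encoder_hidden_states).sample
# after: plain-tuple indexing, friendlier to torch.compile
model_pred = unet(noisy_model_input, timesteps, encoder_hidden_states, return_dict=False)[0]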
@@ -1375,14 +1381,14 @@ def main(args):
         pipeline_args = {}

         if text_encoder is not None:
-            pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder)
+            pipeline_args["text_encoder"] = unwrap_model(accelerator, text_encoder)

         if args.skip_save_text_encoder:
             pipeline_args["text_encoder"] = None

         pipeline = DiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
-            unet=accelerator.unwrap_model(unet),
+            unet=unwrap_model(accelerator, unet),
             revision=args.revision,
             variant=args.variant,
             **pipeline_args,
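Taken together, the pattern is straightforward to replicate outside this script. A self-contained sketch (the Accelerator and nn.Linear here are stand-ins, not from the diff) of why accelerator.unwrap_model alone was no longer enough once the model could be compiled:

import torch
import torch.nn as nn
from accelerate import Accelerator
from diffusers.utils.torch_utils import is_compiled_module


def unwrap_model(accelerator, model):
    # same helper as in the commit: strip Accelerate wrappers first,
    # then the torch.compile wrapper if present
    model = accelerator.unwrap_model(model)
    return model._orig_mod if is_compiled_module(model) else model


accelerator = Accelerator()
model = accelerator.prepare(torch.compile(nn.Linear(4, 4)))

# at the time of this commit, accelerator.unwrap_model only removed
# Accelerate's own wrappers, so the compile wrapper survived:
print(type(accelerator.unwrap_model(model)).__name__)  # OptimizedModule
print(type(unwrap_model(accelerator, model)).__name__)  # Linear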