1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

move unwrap_model inside function

This commit is contained in:
Pham Hong Vinh
2024-01-12 12:26:59 +07:00
parent c4468f5f01
commit 40b219135e

View File

@@ -108,12 +108,6 @@ DreamBooth for the text encoder was enabled: {train_text_encoder}.
f.write(yaml + model_card)
def unwrap_model(accelerator, model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
return model
def log_validation(
text_encoder,
tokenizer,
@@ -136,15 +130,12 @@ def log_validation(
if vae is not None:
pipeline_args["vae"] = vae
if text_encoder is not None:
text_encoder = unwrap_model(accelerator, text_encoder)
# create pipeline (note: unet and vae are loaded again in float32)
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
tokenizer=tokenizer,
text_encoder=text_encoder,
unet=unwrap_model(accelerator, unet),
unet=unet,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
@@ -939,6 +930,11 @@ def main(args):
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
return model
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
@@ -1358,9 +1354,9 @@ def main(args):
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
images = log_validation(
text_encoder,
unwrap_model(text_encoder) if text_encoder is not None else text_encoder,
tokenizer,
unet,
unwrap_model(unet),
vae,
args,
accelerator,