mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
move unwrap_model inside function
This commit is contained in:
@@ -108,12 +108,6 @@ DreamBooth for the text encoder was enabled: {train_text_encoder}.
|
||||
f.write(yaml + model_card)
|
||||
|
||||
|
||||
def unwrap_model(accelerator, model):
|
||||
model = accelerator.unwrap_model(model)
|
||||
model = model._orig_mod if is_compiled_module(model) else model
|
||||
return model
|
||||
|
||||
|
||||
def log_validation(
|
||||
text_encoder,
|
||||
tokenizer,
|
||||
@@ -136,15 +130,12 @@ def log_validation(
|
||||
if vae is not None:
|
||||
pipeline_args["vae"] = vae
|
||||
|
||||
if text_encoder is not None:
|
||||
text_encoder = unwrap_model(accelerator, text_encoder)
|
||||
|
||||
# create pipeline (note: unet and vae are loaded again in float32)
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
unet=unwrap_model(accelerator, unet),
|
||||
unet=unet,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
@@ -939,6 +930,11 @@ def main(args):
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
|
||||
)
|
||||
|
||||
def unwrap_model(model):
|
||||
model = accelerator.unwrap_model(model)
|
||||
model = model._orig_mod if is_compiled_module(model) else model
|
||||
return model
|
||||
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
if accelerator.is_main_process:
|
||||
@@ -1358,9 +1354,9 @@ def main(args):
|
||||
|
||||
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
|
||||
images = log_validation(
|
||||
text_encoder,
|
||||
unwrap_model(text_encoder) if text_encoder is not None else text_encoder,
|
||||
tokenizer,
|
||||
unet,
|
||||
unwrap_model(unet),
|
||||
vae,
|
||||
args,
|
||||
accelerator,
|
||||
|
||||
Reference in New Issue
Block a user