
change unwrap call

Author: Pham Hong Vinh
Date:   2024-01-12 12:32:33 +07:00
parent 40b219135e
commit 4191fc30f4
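
Every hunk in this commit drops the explicit `accelerator` argument from `unwrap_model`, so call sites go from `unwrap_model(accelerator, model)` to `unwrap_model(model)`. That implies the script now uses a helper that already holds a reference to the `Accelerator` instance. A minimal sketch of such a helper, assuming it closes over a module-level `accelerator` and uses `is_compiled_module` from `diffusers.utils.torch_utils` (the helper's actual definition is not part of the hunks shown):

    from accelerate import Accelerator
    from diffusers.utils.torch_utils import is_compiled_module

    accelerator = Accelerator()

    def unwrap_model(model):
        # Strip the accelerate wrapper (e.g. DistributedDataParallel) first,
        # then the torch.compile wrapper, if any, to reach the original module.
        model = accelerator.unwrap_model(model)
        return model._orig_mod if is_compiled_module(model) else model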


@@ -939,7 +939,7 @@ def main(args):
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
             for model in models:
-                sub_dir = "unet" if isinstance(model, type(unwrap_model(accelerator, unet))) else "text_encoder"
+                sub_dir = "unet" if isinstance(model, type(unwrap_model(unet))) else "text_encoder"
                 model.save_pretrained(os.path.join(output_dir, sub_dir))
 
                 # make sure to pop weight so that corresponding model is not saved again
@@ -950,7 +950,7 @@ def main(args):
             # pop models so that they are not loaded again
             model = models.pop()
 
-            if isinstance(model, type(unwrap_model(accelerator, text_encoder))):
+            if isinstance(model, type(unwrap_model(text_encoder))):
                 # load transformers style into model
                 load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
                 model.config = load_model.config
@@ -995,14 +995,14 @@ def main(args):
         " doing mixed precision training. copy of the weights should still be float32."
     )
 
-    if unwrap_model(accelerator, unet).dtype != torch.float32:
+    if unwrap_model(unet).dtype != torch.float32:
         raise ValueError(
-            f"Unet loaded as datatype {unwrap_model(accelerator, unet).dtype}. {low_precision_error_string}"
+            f"Unet loaded as datatype {unwrap_model(unet).dtype}. {low_precision_error_string}"
         )
 
-    if args.train_text_encoder and unwrap_model(accelerator, text_encoder).dtype != torch.float32:
+    if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32:
         raise ValueError(
-            f"Text encoder loaded as datatype {unwrap_model(accelerator, text_encoder).dtype}."
+            f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}."
             f" {low_precision_error_string}"
         )
@@ -1250,7 +1250,7 @@ def main(args):
                     text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
                 )
 
-                if unwrap_model(accelerator, unet).config.in_channels == channels * 2:
+                if unwrap_model(unet).config.in_channels == channels * 2:
                     noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
 
                 if args.class_labels_conditioning == "timesteps":
@@ -1379,14 +1379,14 @@ def main(args):
                 pipeline_args = {}
 
                 if text_encoder is not None:
-                    pipeline_args["text_encoder"] = unwrap_model(accelerator, text_encoder)
+                    pipeline_args["text_encoder"] = unwrap_model(text_encoder)
 
                 if args.skip_save_text_encoder:
                     pipeline_args["text_encoder"] = None
 
                 pipeline = DiffusionPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
-                    unet=unwrap_model(accelerator, unet),
+                    unet=unwrap_model(unet),
                     revision=args.revision,
                     variant=args.variant,
                     **pipeline_args,
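
Unwrapping matters at these call sites because `accelerator.prepare()` can wrap the modules (e.g. in `DistributedDataParallel` or a `torch.compile` wrapper), and `save_pretrained`, `.config` access, and `DiffusionPipeline.from_pretrained(..., unet=...)` all expect the bare underlying model rather than the wrapper.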