mirror of https://github.com/huggingface/diffusers.git
guard save model hooks to only execute on main process (#4929)
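Every hunk below makes the same change: the body of `save_model_hook` is wrapped in an `if accelerator.is_main_process:` guard so that only the main process serializes checkpoints when `accelerator.save_state(...)` runs, while `load_model_hook` stays unguarded so every process restores the weights. The guard is applied across the ControlNet, UNet/text-encoder, LoRA, and EMA variants of the example training scripts. A minimal sketch of how such hooks are wired up, assuming accelerate >= 0.16.0 (the version the scripts check for); the "controlnet" sub-directory and the loading comment are illustrative, not the verbatim script code:

# Minimal sketch of the guarded hook pattern used across the example scripts.
import os

from accelerate import Accelerator

accelerator = Accelerator()


def save_model_hook(models, weights, output_dir):
    # Guard added by this commit: only the main process serializes files.
    if accelerator.is_main_process:
        for model in models:
            # `models` holds the modules that went through `accelerator.prepare(...)`.
            model.save_pretrained(os.path.join(output_dir, "controlnet"))
            # Pop the corresponding weight so accelerate does not save this model
            # again in its default format.
            weights.pop()


def load_model_hook(models, input_dir):
    # Loading stays unguarded so every process restores the same weights.
    while len(models) > 0:
        model = models.pop()
        # Each script reloads `model` from `input_dir` in its own format here,
        # e.g. via `from_pretrained(input_dir, subfolder="controlnet")`.


# Hook registration is available from accelerate 0.16.0, which is why the
# scripts check the installed version before defining the hooks.
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)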
@@ -785,16 +785,17 @@ def main(args):
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
         def save_model_hook(models, weights, output_dir):
-            i = len(weights) - 1
+            if accelerator.is_main_process:
+                i = len(weights) - 1

-            while len(weights) > 0:
-                weights.pop()
-                model = models[i]
+                while len(weights) > 0:
+                    weights.pop()
+                    model = models[i]

-                sub_dir = "controlnet"
-                model.save_pretrained(os.path.join(output_dir, sub_dir))
+                    sub_dir = "controlnet"
+                    model.save_pretrained(os.path.join(output_dir, sub_dir))

-                i -= 1
+                    i -= 1

         def load_model_hook(models, input_dir):
             while len(models) > 0:
@@ -840,16 +840,17 @@ def main(args):
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
         def save_model_hook(models, weights, output_dir):
-            i = len(weights) - 1
+            if accelerator.is_main_process:
+                i = len(weights) - 1

-            while len(weights) > 0:
-                weights.pop()
-                model = models[i]
+                while len(weights) > 0:
+                    weights.pop()
+                    model = models[i]

-                sub_dir = "controlnet"
-                model.save_pretrained(os.path.join(output_dir, sub_dir))
+                    sub_dir = "controlnet"
+                    model.save_pretrained(os.path.join(output_dir, sub_dir))

-                i -= 1
+                    i -= 1

         def load_model_hook(models, input_dir):
             while len(models) > 0:
@@ -920,12 +920,13 @@ def main(args):

     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
-        for model in models:
-            sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder"
-            model.save_pretrained(os.path.join(output_dir, sub_dir))
+        if accelerator.is_main_process:
+            for model in models:
+                sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder"
+                model.save_pretrained(os.path.join(output_dir, sub_dir))

-            # make sure to pop weight so that corresponding model is not saved again
-            weights.pop()
+                # make sure to pop weight so that corresponding model is not saved again
+                weights.pop()

     def load_model_hook(models, input_dir):
         while len(models) > 0:
@@ -894,27 +894,28 @@ def main(args):

     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
-        # there are only two options here. Either are just the unet attn processor layers
-        # or there are the unet and text encoder atten layers
-        unet_lora_layers_to_save = None
-        text_encoder_lora_layers_to_save = None
+        if accelerator.is_main_process:
+            # there are only two options here. Either are just the unet attn processor layers
+            # or there are the unet and text encoder atten layers
+            unet_lora_layers_to_save = None
+            text_encoder_lora_layers_to_save = None

-        for model in models:
-            if isinstance(model, type(accelerator.unwrap_model(unet))):
-                unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
-            elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
-                text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
-            else:
-                raise ValueError(f"unexpected save model: {model.__class__}")
+            for model in models:
+                if isinstance(model, type(accelerator.unwrap_model(unet))):
+                    unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
+                elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
+                    text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                else:
+                    raise ValueError(f"unexpected save model: {model.__class__}")

-            # make sure to pop weight so that corresponding model is not saved again
-            weights.pop()
+                # make sure to pop weight so that corresponding model is not saved again
+                weights.pop()

-        LoraLoaderMixin.save_lora_weights(
-            output_dir,
-            unet_lora_layers=unet_lora_layers_to_save,
-            text_encoder_lora_layers=text_encoder_lora_layers_to_save,
-        )
+            LoraLoaderMixin.save_lora_weights(
+                output_dir,
+                unet_lora_layers=unet_lora_layers_to_save,
+                text_encoder_lora_layers=text_encoder_lora_layers_to_save,
+            )

     def load_model_hook(models, input_dir):
         unet_ = None
@@ -798,31 +798,32 @@ def main(args):

     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
-        # there are only two options here. Either are just the unet attn processor layers
-        # or there are the unet and text encoder atten layers
-        unet_lora_layers_to_save = None
-        text_encoder_one_lora_layers_to_save = None
-        text_encoder_two_lora_layers_to_save = None
+        if accelerator.is_main_process:
+            # there are only two options here. Either are just the unet attn processor layers
+            # or there are the unet and text encoder atten layers
+            unet_lora_layers_to_save = None
+            text_encoder_one_lora_layers_to_save = None
+            text_encoder_two_lora_layers_to_save = None

-        for model in models:
-            if isinstance(model, type(accelerator.unwrap_model(unet))):
-                unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
-            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
-                text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
-            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
-                text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
-            else:
-                raise ValueError(f"unexpected save model: {model.__class__}")
+            for model in models:
+                if isinstance(model, type(accelerator.unwrap_model(unet))):
+                    unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
+                elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+                    text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
+                    text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                else:
+                    raise ValueError(f"unexpected save model: {model.__class__}")

-            # make sure to pop weight so that corresponding model is not saved again
-            weights.pop()
+                # make sure to pop weight so that corresponding model is not saved again
+                weights.pop()

-        StableDiffusionXLPipeline.save_lora_weights(
-            output_dir,
-            unet_lora_layers=unet_lora_layers_to_save,
-            text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
-            text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
-        )
+            StableDiffusionXLPipeline.save_lora_weights(
+                output_dir,
+                unet_lora_layers=unet_lora_layers_to_save,
+                text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
+                text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
+            )

     def load_model_hook(models, input_dir):
         unet_ = None
@@ -485,14 +485,15 @@ def main():
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
         def save_model_hook(models, weights, output_dir):
-            if args.use_ema:
-                ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+            if accelerator.is_main_process:
+                if args.use_ema:
+                    ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))

-            for i, model in enumerate(models):
-                model.save_pretrained(os.path.join(output_dir, "unet"))
+                for i, model in enumerate(models):
+                    model.save_pretrained(os.path.join(output_dir, "unet"))

-                # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                    # make sure to pop weight so that corresponding model is not saved again
+                    weights.pop()

         def load_model_hook(models, input_dir):
             if args.use_ema:
@@ -528,14 +528,15 @@ def main():
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
         def save_model_hook(models, weights, output_dir):
-            if args.use_ema:
-                ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+            if accelerator.is_main_process:
+                if args.use_ema:
+                    ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))

-            for i, model in enumerate(models):
-                model.save_pretrained(os.path.join(output_dir, "unet"))
+                for i, model in enumerate(models):
+                    model.save_pretrained(os.path.join(output_dir, "unet"))

-                # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                    # make sure to pop weight so that corresponding model is not saved again
+                    weights.pop()

         def load_model_hook(models, input_dir):
             if args.use_ema:
@@ -1010,16 +1010,17 @@ def main(args):
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
        def save_model_hook(models, weights, output_dir):
-            i = len(weights) - 1
+            if accelerator.is_main_process:
+                i = len(weights) - 1

-            while len(weights) > 0:
-                weights.pop()
-                model = models[i]
+                while len(weights) > 0:
+                    weights.pop()
+                    model = models[i]

-                sub_dir = "controlnet"
-                model.save_pretrained(os.path.join(output_dir, sub_dir))
+                    sub_dir = "controlnet"
+                    model.save_pretrained(os.path.join(output_dir, sub_dir))

-                i -= 1
+                    i -= 1

         def load_model_hook(models, input_dir):
             while len(models) > 0:
@@ -552,14 +552,15 @@ def main():
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
         def save_model_hook(models, weights, output_dir):
-            if args.use_ema:
-                ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+            if accelerator.is_main_process:
+                if args.use_ema:
+                    ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))

-            for i, model in enumerate(models):
-                model.save_pretrained(os.path.join(output_dir, "unet"))
+                for i, model in enumerate(models):
+                    model.save_pretrained(os.path.join(output_dir, "unet"))

-                # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                    # make sure to pop weight so that corresponding model is not saved again
+                    weights.pop()

         def load_model_hook(models, input_dir):
             if args.use_ema:
@@ -313,14 +313,15 @@ def main(args):
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
         def save_model_hook(models, weights, output_dir):
-            if args.use_ema:
-                ema_model.save_pretrained(os.path.join(output_dir, "unet_ema"))
+            if accelerator.is_main_process:
+                if args.use_ema:
+                    ema_model.save_pretrained(os.path.join(output_dir, "unet_ema"))

-            for i, model in enumerate(models):
-                model.save_pretrained(os.path.join(output_dir, "unet"))
+                for i, model in enumerate(models):
+                    model.save_pretrained(os.path.join(output_dir, "unet"))

-                # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                    # make sure to pop weight so that corresponding model is not saved again
+                    weights.pop()

         def load_model_hook(models, input_dir):
             if args.use_ema:
@@ -629,14 +629,15 @@ def main():
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
         def save_model_hook(models, weights, output_dir):
-            if args.use_ema:
-                ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+            if accelerator.is_main_process:
+                if args.use_ema:
+                    ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))

-            for i, model in enumerate(models):
-                model.save_pretrained(os.path.join(output_dir, "unet"))
+                for i, model in enumerate(models):
+                    model.save_pretrained(os.path.join(output_dir, "unet"))

-                # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                    # make sure to pop weight so that corresponding model is not saved again
+                    weights.pop()

         def load_model_hook(models, input_dir):
             if args.use_ema:
@@ -669,31 +669,32 @@ def main(args):

     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
-        # there are only two options here. Either are just the unet attn processor layers
-        # or there are the unet and text encoder atten layers
-        unet_lora_layers_to_save = None
-        text_encoder_one_lora_layers_to_save = None
-        text_encoder_two_lora_layers_to_save = None
+        if accelerator.is_main_process:
+            # there are only two options here. Either are just the unet attn processor layers
+            # or there are the unet and text encoder atten layers
+            unet_lora_layers_to_save = None
+            text_encoder_one_lora_layers_to_save = None
+            text_encoder_two_lora_layers_to_save = None

-        for model in models:
-            if isinstance(model, type(accelerator.unwrap_model(unet))):
-                unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
-            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
-                text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
-            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
-                text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
-            else:
-                raise ValueError(f"unexpected save model: {model.__class__}")
+            for model in models:
+                if isinstance(model, type(accelerator.unwrap_model(unet))):
+                    unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
+                elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+                    text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
+                    text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                else:
+                    raise ValueError(f"unexpected save model: {model.__class__}")

-            # make sure to pop weight so that corresponding model is not saved again
-            weights.pop()
+                # make sure to pop weight so that corresponding model is not saved again
+                weights.pop()

-        StableDiffusionXLPipeline.save_lora_weights(
-            output_dir,
-            unet_lora_layers=unet_lora_layers_to_save,
-            text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
-            text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
-        )
+            StableDiffusionXLPipeline.save_lora_weights(
+                output_dir,
+                unet_lora_layers=unet_lora_layers_to_save,
+                text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
+                text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
+            )

     def load_model_hook(models, input_dir):
         unet_ = None
@@ -651,14 +651,15 @@ def main(args):
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
         def save_model_hook(models, weights, output_dir):
-            if args.use_ema:
-                ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+            if accelerator.is_main_process:
+                if args.use_ema:
+                    ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))

-            for i, model in enumerate(models):
-                model.save_pretrained(os.path.join(output_dir, "unet"))
+                for i, model in enumerate(models):
+                    model.save_pretrained(os.path.join(output_dir, "unet"))

-                # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                    # make sure to pop weight so that corresponding model is not saved again
+                    weights.pop()

         def load_model_hook(models, input_dir):
             if args.use_ema:
@@ -309,14 +309,15 @@ def main(args):
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
         def save_model_hook(models, weights, output_dir):
-            if args.use_ema:
-                ema_model.save_pretrained(os.path.join(output_dir, "unet_ema"))
+            if accelerator.is_main_process:
+                if args.use_ema:
+                    ema_model.save_pretrained(os.path.join(output_dir, "unet_ema"))

-            for i, model in enumerate(models):
-                model.save_pretrained(os.path.join(output_dir, "unet"))
+                for i, model in enumerate(models):
+                    model.save_pretrained(os.path.join(output_dir, "unet"))

-                # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                    # make sure to pop weight so that corresponding model is not saved again
+                    weights.pop()

         def load_model_hook(models, input_dir):
             if args.use_ema:
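For context, a hedged sketch of the kind of call site that fires these hooks in the example training loops; `args`, `global_step`, and `checkpointing_steps` are assumed names taken from the scripts, not part of this diff:

import os


def maybe_checkpoint(accelerator, args, global_step):
    # accelerator.save_state(...) invokes the registered save_model_hook on every
    # process; with the guard added in this commit, only the main process writes files.
    if global_step % args.checkpointing_steps == 0:
        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
        accelerator.save_state(save_path)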