mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Fix the issue in SD3 DreamBooth training with/without LoRA (#9419)

* Fix dtype error

* [bugfix] Fixed the issue on sd3 dreambooth training

---------

Co-authored-by: 蒋硕 <jiangshuo9@h-partners.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Author: Leo Jiang
Date: 2024-09-14 18:59:38 +08:00
Committer: GitHub
Parent: 48e36353d8
Commit: e2ead7cdcc
6 changed files with 24 additions and 6 deletions
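In short: under mixed-precision training the validation pipeline was previously moved with pipeline.to(accelerator.device) only, so its components could end up in mismatched dtypes and validation could fail with a dtype error. log_validation now takes a torch_dtype argument and every call site passes the training weight_dtype, so the whole pipeline is cast consistently before inference. Below is a minimal standalone sketch of the failure mode and the cast; it is not code from the changed files, and the model id, the hard-coded float16, and the "cuda" device are illustrative assumptions.

import torch
from diffusers import StableDiffusion3Pipeline

# Placeholder assumptions: the model id, float16, and "cuda" stand in for
# whatever --pretrained_model_name_or_path, --mixed_precision, and
# accelerator.device resolve to in an actual training run.
model_id = "stabilityai/stable-diffusion-3-medium-diffusers"
weight_dtype = torch.float16

pipeline = StableDiffusion3Pipeline.from_pretrained(model_id, torch_dtype=weight_dtype)

# Simulate the mixed-precision training state: the trainable transformer is
# kept in float32 while the frozen components are in weight_dtype.
pipeline.transformer.to(torch.float32)

# pipeline.to("cuda") alone would keep that mismatch, and inference can then
# fail with a dtype error; passing dtype casts every component in one go,
# which is what the added torch_dtype argument makes log_validation do.
pipeline = pipeline.to("cuda", dtype=weight_dtype)

image = pipeline("a photo of sks dog", num_inference_steps=25).images[0]
image.save("sd3_validation.png")

In the training scripts the equivalent line is pipeline = pipeline.to(accelerator.device, dtype=torch_dtype), as the hunks below show.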

View File

@@ -154,13 +154,14 @@ def log_validation(
     accelerator,
     pipeline_args,
     epoch,
+    torch_dtype,
     is_final_validation=False,
 ):
     logger.info(
         f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
         f" {args.validation_prompt}."
     )
-    pipeline = pipeline.to(accelerator.device)
+    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
     pipeline.set_progress_bar_config(disable=True)
     # run inference
@@ -1717,6 +1718,7 @@ def main(args):
                     accelerator=accelerator,
                     pipeline_args=pipeline_args,
                     epoch=epoch,
+                    torch_dtype=weight_dtype,
                 )
                 if not args.train_text_encoder:
                     del text_encoder_one, text_encoder_two
@@ -1761,6 +1763,7 @@ def main(args):
                 pipeline_args=pipeline_args,
                 epoch=epoch,
                 is_final_validation=True,
+                torch_dtype=weight_dtype,
             )
         if args.push_to_hub:

View File

@@ -122,6 +122,7 @@ def log_validation(
     accelerator,
     pipeline_args,
     epoch,
+    torch_dtype,
     is_final_validation=False,
 ):
     logger.info(
@@ -141,7 +142,7 @@ def log_validation(
     pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
-    pipeline = pipeline.to(accelerator.device)
+    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
     pipeline.set_progress_bar_config(disable=True)
     # run inference
@@ -1360,6 +1361,7 @@ def main(args):
                     accelerator,
                     pipeline_args,
                     epoch,
+                    torch_dtype=weight_dtype,
                 )
     # Save the lora layers
@@ -1402,6 +1404,7 @@ def main(args):
                 pipeline_args,
                 epoch,
                 is_final_validation=True,
+                torch_dtype=weight_dtype,
             )
         if args.push_to_hub:

View File

@@ -170,13 +170,14 @@ def log_validation(
     accelerator,
     pipeline_args,
     epoch,
+    torch_dtype,
     is_final_validation=False,
 ):
     logger.info(
         f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
         f" {args.validation_prompt}."
     )
-    pipeline = pipeline.to(accelerator.device)
+    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
     pipeline.set_progress_bar_config(disable=True)
     # run inference
@@ -1785,6 +1786,7 @@ def main(args):
                     accelerator=accelerator,
                     pipeline_args=pipeline_args,
                     epoch=epoch,
+                    torch_dtype=weight_dtype,
                 )
                 if not args.train_text_encoder:
                     del text_encoder_one, text_encoder_two
@@ -1832,6 +1834,7 @@ def main(args):
                 pipeline_args=pipeline_args,
                 epoch=epoch,
                 is_final_validation=True,
+                torch_dtype=weight_dtype,
             )
         if args.push_to_hub:

View File

@@ -179,13 +179,14 @@ def log_validation(
     accelerator,
     pipeline_args,
     epoch,
+    torch_dtype,
     is_final_validation=False,
 ):
     logger.info(
         f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
         f" {args.validation_prompt}."
     )
-    pipeline = pipeline.to(accelerator.device)
+    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
     pipeline.set_progress_bar_config(disable=True)
     # run inference
@@ -1788,6 +1789,7 @@ def main(args):
                     accelerator=accelerator,
                     pipeline_args=pipeline_args,
                     epoch=epoch,
+                    torch_dtype=weight_dtype,
                 )
                 objs = []
                 if not args.train_text_encoder:
@@ -1840,6 +1842,7 @@ def main(args):
                 pipeline_args=pipeline_args,
                 epoch=epoch,
                 is_final_validation=True,
+                torch_dtype=weight_dtype,
             )
         if args.push_to_hub:

View File

@@ -180,6 +180,7 @@ def log_validation(
     accelerator,
     pipeline_args,
     epoch,
+    torch_dtype,
     is_final_validation=False,
 ):
     logger.info(
@@ -201,7 +202,7 @@ def log_validation(
     pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
-    pipeline = pipeline.to(accelerator.device)
+    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
     pipeline.set_progress_bar_config(disable=True)
     # run inference
@@ -1890,6 +1891,7 @@ def main(args):
                     accelerator,
                     pipeline_args,
                     epoch,
+                    torch_dtype=weight_dtype,
                 )
     # Save the lora layers
@@ -1955,6 +1957,7 @@ def main(args):
                 pipeline_args,
                 epoch,
                 is_final_validation=True,
+                torch_dtype=weight_dtype,
             )
         if args.push_to_hub:

View File

@@ -157,13 +157,14 @@ def log_validation(
     accelerator,
     pipeline_args,
     epoch,
+    torch_dtype,
     is_final_validation=False,
 ):
     logger.info(
         f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
         f" {args.validation_prompt}."
     )
-    pipeline = pipeline.to(accelerator.device)
+    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
     pipeline.set_progress_bar_config(disable=True)
     # run inference
@@ -1725,6 +1726,7 @@ def main(args):
                     accelerator=accelerator,
                     pipeline_args=pipeline_args,
                     epoch=epoch,
+                    torch_dtype=weight_dtype,
                 )
                 if not args.train_text_encoder:
                     del text_encoder_one, text_encoder_two, text_encoder_three
@@ -1775,6 +1777,7 @@ def main(args):
                 pipeline_args=pipeline_args,
                 epoch=epoch,
                 is_final_validation=True,
+                torch_dtype=weight_dtype,
             )
         if args.push_to_hub: