mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix the issue on sd3 dreambooth w./w.t. lora training (#9419)
* Fix dtype error * [bugfix] Fixed the issue on sd3 dreambooth training * [bugfix] Fixed the issue on sd3 dreambooth training --------- Co-authored-by: 蒋硕 <jiangshuo9@h-partners.com> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -154,13 +154,14 @@ def log_validation(
|
||||
accelerator,
|
||||
pipeline_args,
|
||||
epoch,
|
||||
torch_dtype,
|
||||
is_final_validation=False,
|
||||
):
|
||||
logger.info(
|
||||
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
||||
f" {args.validation_prompt}."
|
||||
)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
@@ -1717,6 +1718,7 @@ def main(args):
|
||||
accelerator=accelerator,
|
||||
pipeline_args=pipeline_args,
|
||||
epoch=epoch,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
if not args.train_text_encoder:
|
||||
del text_encoder_one, text_encoder_two
|
||||
@@ -1761,6 +1763,7 @@ def main(args):
|
||||
pipeline_args=pipeline_args,
|
||||
epoch=epoch,
|
||||
is_final_validation=True,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
|
||||
if args.push_to_hub:
|
||||
|
||||
@@ -122,6 +122,7 @@ def log_validation(
|
||||
accelerator,
|
||||
pipeline_args,
|
||||
epoch,
|
||||
torch_dtype,
|
||||
is_final_validation=False,
|
||||
):
|
||||
logger.info(
|
||||
@@ -141,7 +142,7 @@ def log_validation(
|
||||
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
|
||||
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
@@ -1360,6 +1361,7 @@ def main(args):
|
||||
accelerator,
|
||||
pipeline_args,
|
||||
epoch,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
|
||||
# Save the lora layers
|
||||
@@ -1402,6 +1404,7 @@ def main(args):
|
||||
pipeline_args,
|
||||
epoch,
|
||||
is_final_validation=True,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
|
||||
if args.push_to_hub:
|
||||
|
||||
@@ -170,13 +170,14 @@ def log_validation(
|
||||
accelerator,
|
||||
pipeline_args,
|
||||
epoch,
|
||||
torch_dtype,
|
||||
is_final_validation=False,
|
||||
):
|
||||
logger.info(
|
||||
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
||||
f" {args.validation_prompt}."
|
||||
)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
@@ -1785,6 +1786,7 @@ def main(args):
|
||||
accelerator=accelerator,
|
||||
pipeline_args=pipeline_args,
|
||||
epoch=epoch,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
if not args.train_text_encoder:
|
||||
del text_encoder_one, text_encoder_two
|
||||
@@ -1832,6 +1834,7 @@ def main(args):
|
||||
pipeline_args=pipeline_args,
|
||||
epoch=epoch,
|
||||
is_final_validation=True,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
|
||||
if args.push_to_hub:
|
||||
|
||||
@@ -179,13 +179,14 @@ def log_validation(
|
||||
accelerator,
|
||||
pipeline_args,
|
||||
epoch,
|
||||
torch_dtype,
|
||||
is_final_validation=False,
|
||||
):
|
||||
logger.info(
|
||||
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
||||
f" {args.validation_prompt}."
|
||||
)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
@@ -1788,6 +1789,7 @@ def main(args):
|
||||
accelerator=accelerator,
|
||||
pipeline_args=pipeline_args,
|
||||
epoch=epoch,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
objs = []
|
||||
if not args.train_text_encoder:
|
||||
@@ -1840,6 +1842,7 @@ def main(args):
|
||||
pipeline_args=pipeline_args,
|
||||
epoch=epoch,
|
||||
is_final_validation=True,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
|
||||
if args.push_to_hub:
|
||||
|
||||
@@ -180,6 +180,7 @@ def log_validation(
|
||||
accelerator,
|
||||
pipeline_args,
|
||||
epoch,
|
||||
torch_dtype,
|
||||
is_final_validation=False,
|
||||
):
|
||||
logger.info(
|
||||
@@ -201,7 +202,7 @@ def log_validation(
|
||||
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
|
||||
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
@@ -1890,6 +1891,7 @@ def main(args):
|
||||
accelerator,
|
||||
pipeline_args,
|
||||
epoch,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
|
||||
# Save the lora layers
|
||||
@@ -1955,6 +1957,7 @@ def main(args):
|
||||
pipeline_args,
|
||||
epoch,
|
||||
is_final_validation=True,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
|
||||
if args.push_to_hub:
|
||||
|
||||
@@ -157,13 +157,14 @@ def log_validation(
|
||||
accelerator,
|
||||
pipeline_args,
|
||||
epoch,
|
||||
torch_dtype,
|
||||
is_final_validation=False,
|
||||
):
|
||||
logger.info(
|
||||
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
||||
f" {args.validation_prompt}."
|
||||
)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
@@ -1725,6 +1726,7 @@ def main(args):
|
||||
accelerator=accelerator,
|
||||
pipeline_args=pipeline_args,
|
||||
epoch=epoch,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
if not args.train_text_encoder:
|
||||
del text_encoder_one, text_encoder_two, text_encoder_three
|
||||
@@ -1775,6 +1777,7 @@ def main(args):
|
||||
pipeline_args=pipeline_args,
|
||||
epoch=epoch,
|
||||
is_final_validation=True,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
|
||||
if args.push_to_hub:
|
||||
|
||||
Reference in New Issue
Block a user