1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Merge branch 'main' into cogvideox-lora-and-training

This commit is contained in:
Aryan
2024-09-14 23:44:05 +02:00
6 changed files with 24 additions and 6 deletions

View File

@@ -154,13 +154,14 @@ def log_validation(
accelerator,
pipeline_args,
epoch,
torch_dtype,
is_final_validation=False,
):
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
pipeline = pipeline.to(accelerator.device)
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
pipeline.set_progress_bar_config(disable=True)
# run inference
@@ -1717,6 +1718,7 @@ def main(args):
accelerator=accelerator,
pipeline_args=pipeline_args,
epoch=epoch,
torch_dtype=weight_dtype,
)
if not args.train_text_encoder:
del text_encoder_one, text_encoder_two
@@ -1761,6 +1763,7 @@ def main(args):
pipeline_args=pipeline_args,
epoch=epoch,
is_final_validation=True,
torch_dtype=weight_dtype,
)
if args.push_to_hub:

View File

@@ -122,6 +122,7 @@ def log_validation(
accelerator,
pipeline_args,
epoch,
torch_dtype,
is_final_validation=False,
):
logger.info(
@@ -141,7 +142,7 @@ def log_validation(
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
pipeline = pipeline.to(accelerator.device)
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
pipeline.set_progress_bar_config(disable=True)
# run inference
@@ -1360,6 +1361,7 @@ def main(args):
accelerator,
pipeline_args,
epoch,
torch_dtype=weight_dtype,
)
# Save the lora layers
@@ -1402,6 +1404,7 @@ def main(args):
pipeline_args,
epoch,
is_final_validation=True,
torch_dtype=weight_dtype,
)
if args.push_to_hub:

View File

@@ -170,13 +170,14 @@ def log_validation(
accelerator,
pipeline_args,
epoch,
torch_dtype,
is_final_validation=False,
):
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
pipeline = pipeline.to(accelerator.device)
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
pipeline.set_progress_bar_config(disable=True)
# run inference
@@ -1785,6 +1786,7 @@ def main(args):
accelerator=accelerator,
pipeline_args=pipeline_args,
epoch=epoch,
torch_dtype=weight_dtype,
)
if not args.train_text_encoder:
del text_encoder_one, text_encoder_two
@@ -1832,6 +1834,7 @@ def main(args):
pipeline_args=pipeline_args,
epoch=epoch,
is_final_validation=True,
torch_dtype=weight_dtype,
)
if args.push_to_hub:

View File

@@ -179,13 +179,14 @@ def log_validation(
accelerator,
pipeline_args,
epoch,
torch_dtype,
is_final_validation=False,
):
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
pipeline = pipeline.to(accelerator.device)
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
pipeline.set_progress_bar_config(disable=True)
# run inference
@@ -1788,6 +1789,7 @@ def main(args):
accelerator=accelerator,
pipeline_args=pipeline_args,
epoch=epoch,
torch_dtype=weight_dtype,
)
objs = []
if not args.train_text_encoder:
@@ -1840,6 +1842,7 @@ def main(args):
pipeline_args=pipeline_args,
epoch=epoch,
is_final_validation=True,
torch_dtype=weight_dtype,
)
if args.push_to_hub:

View File

@@ -180,6 +180,7 @@ def log_validation(
accelerator,
pipeline_args,
epoch,
torch_dtype,
is_final_validation=False,
):
logger.info(
@@ -201,7 +202,7 @@ def log_validation(
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
pipeline = pipeline.to(accelerator.device)
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
pipeline.set_progress_bar_config(disable=True)
# run inference
@@ -1890,6 +1891,7 @@ def main(args):
accelerator,
pipeline_args,
epoch,
torch_dtype=weight_dtype,
)
# Save the lora layers
@@ -1955,6 +1957,7 @@ def main(args):
pipeline_args,
epoch,
is_final_validation=True,
torch_dtype=weight_dtype,
)
if args.push_to_hub:

View File

@@ -157,13 +157,14 @@ def log_validation(
accelerator,
pipeline_args,
epoch,
torch_dtype,
is_final_validation=False,
):
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
pipeline = pipeline.to(accelerator.device)
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
pipeline.set_progress_bar_config(disable=True)
# run inference
@@ -1725,6 +1726,7 @@ def main(args):
accelerator=accelerator,
pipeline_args=pipeline_args,
epoch=epoch,
torch_dtype=weight_dtype,
)
if not args.train_text_encoder:
del text_encoder_one, text_encoder_two, text_encoder_three
@@ -1775,6 +1777,7 @@ def main(args):
pipeline_args=pipeline_args,
epoch=epoch,
is_final_validation=True,
torch_dtype=weight_dtype,
)
if args.push_to_hub: