mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Modularize train_text_to_image_lora SD inferencing during and after training in example (#8283)
* Modularized the train_lora file * Modularized the train_lora file * Modularized the train_lora file * Modularized the train_lora file * Modularized the train_lora file --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -52,6 +52,9 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.29.0.dev0")
|
||||
|
||||
@@ -99,6 +102,48 @@ These are LoRA adaption weights for {base_model}. The weights were fine-tuned on
|
||||
model_card.save(os.path.join(repo_folder, "README.md"))
|
||||
|
||||
|
||||
def log_validation(
|
||||
pipeline,
|
||||
args,
|
||||
accelerator,
|
||||
epoch,
|
||||
is_final_validation=False,
|
||||
):
|
||||
logger.info(
|
||||
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
||||
f" {args.validation_prompt}."
|
||||
)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
generator = torch.Generator(device=accelerator.device)
|
||||
if args.seed is not None:
|
||||
generator = generator.manual_seed(args.seed)
|
||||
images = []
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
for _ in range(args.num_validation_images):
|
||||
images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
phase_name = "test" if is_final_validation else "validation"
|
||||
if tracker.name == "tensorboard":
|
||||
np_images = np.stack([np.asarray(img) for img in images])
|
||||
tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
|
||||
if tracker.name == "wandb":
|
||||
tracker.log(
|
||||
{
|
||||
phase_name: [
|
||||
wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
|
||||
]
|
||||
}
|
||||
)
|
||||
return images
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
@@ -414,11 +459,6 @@ def main():
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
import wandb
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
@@ -864,10 +904,6 @@ def main():
|
||||
|
||||
if accelerator.is_main_process:
|
||||
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
|
||||
logger.info(
|
||||
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
||||
f" {args.validation_prompt}."
|
||||
)
|
||||
# create pipeline
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
@@ -876,38 +912,7 @@ def main():
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device)
|
||||
if args.seed is not None:
|
||||
generator = generator.manual_seed(args.seed)
|
||||
images = []
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
for _ in range(args.num_validation_images):
|
||||
images.append(
|
||||
pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
|
||||
)
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
if tracker.name == "tensorboard":
|
||||
np_images = np.stack([np.asarray(img) for img in images])
|
||||
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
|
||||
if tracker.name == "wandb":
|
||||
tracker.log(
|
||||
{
|
||||
"validation": [
|
||||
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
|
||||
for i, image in enumerate(images)
|
||||
]
|
||||
}
|
||||
)
|
||||
images = log_validation(pipeline, args, accelerator, epoch)
|
||||
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
@@ -925,6 +930,22 @@ def main():
|
||||
safe_serialization=True,
|
||||
)
|
||||
|
||||
# Final inference
|
||||
# Load previous pipeline
|
||||
if args.validation_prompt is not None:
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
|
||||
# load attention processors
|
||||
pipeline.load_lora_weights(args.output_dir)
|
||||
|
||||
# run inference
|
||||
images = log_validation(pipeline, args, accelerator, epoch, is_final_validation=True)
|
||||
|
||||
if args.push_to_hub:
|
||||
save_model_card(
|
||||
repo_id,
|
||||
@@ -940,51 +961,6 @@ def main():
|
||||
ignore_patterns=["step_*", "epoch_*"],
|
||||
)
|
||||
|
||||
# Final inference
|
||||
# Load previous pipeline
|
||||
if args.validation_prompt is not None:
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
|
||||
# load attention processors
|
||||
pipeline.load_lora_weights(args.output_dir)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device)
|
||||
if args.seed is not None:
|
||||
generator = generator.manual_seed(args.seed)
|
||||
images = []
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
for _ in range(args.num_validation_images):
|
||||
images.append(
|
||||
pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
|
||||
)
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
if len(images) != 0:
|
||||
if tracker.name == "tensorboard":
|
||||
np_images = np.stack([np.asarray(img) for img in images])
|
||||
tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
|
||||
if tracker.name == "wandb":
|
||||
tracker.log(
|
||||
{
|
||||
"test": [
|
||||
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
|
||||
for i, image in enumerate(images)
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user