1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Update for multi-gpu training.

This commit is contained in:
Zhenhuan Liu
2022-09-23 04:31:15 -04:00
parent 5bb534b0a5
commit faffe23627
2 changed files with 11 additions and 6 deletions

View File

@@ -43,7 +43,7 @@ export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
export OUTPUT_DIR="path-to-save-model"
python train_dreambooth.py \
accelerate launch train_dreambooth.py \
--pretrained_model_name_or_path=$MODEL_NAME --use_auth_token \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
@@ -64,7 +64,7 @@ export INSTANCE_DIR="path-to-instance-images"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"
python train_dreambooth.py \
accelerate launch train_dreambooth.py \
--pretrained_model_name_or_path=$MODEL_NAME --use_auth_token \
--instance_data_dir=$INSTANCE_DIR \
--class_data_dir=$CLASS_DIR \

View File

@@ -345,6 +345,7 @@ def main():
sd_model = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path, use_auth_token=args.use_auth_token
)
sd_model.set_progress_bar_config(disable=True)
num_new_images = args.num_class_images - cur_class_images
logger.info(f"Number of class images to sample: {num_new_images}.")
@@ -441,10 +442,14 @@ def main():
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
)
text_encoder, vae, unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder, vae, unet, optimizer, train_dataloader, lr_scheduler
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, optimizer, train_dataloader, lr_scheduler
)
# Move text_encode and vae to gpu
text_encoder.to(accelerator.device)
vae.to(accelerator.device)
# Keep text_encoder and vae in eval mode as we don't train them
text_encoder.eval()
vae.eval()
@@ -536,8 +541,8 @@ def main():
# Create the pipeline using the trained modules and save it.
if accelerator.is_main_process:
pipeline = StableDiffusionPipeline(
text_encoder=accelerator.unwrap_model(text_encoder),
vae=accelerator.unwrap_model(vae),
text_encoder=text_encoder,
vae=vae,
unet=accelerator.unwrap_model(unet),
tokenizer=tokenizer,
scheduler=PNDMScheduler(