mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
Update for multi-gpu trianing.
This commit is contained in:
@@ -43,7 +43,7 @@ export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
python train_dreambooth.py \
|
||||
accelerate launch train_dreambooth.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME --use_auth_token \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
@@ -64,7 +64,7 @@ export INSTANCE_DIR="path-to-instance-images"
|
||||
export CLASS_DIR="path-to-class-images"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
python train_dreambooth.py \
|
||||
accelerate launch train_dreambooth.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME --use_auth_token \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--class_data_dir=$CLASS_DIR \
|
||||
|
||||
@@ -345,6 +345,7 @@ def main():
|
||||
sd_model = StableDiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path, use_auth_token=args.use_auth_token
|
||||
)
|
||||
sd_model.set_progress_bar_config(disable=True)
|
||||
num_new_images = args.num_class_images - cur_class_images
|
||||
logger.info(f"Number of class images to sample: {num_new_images}.")
|
||||
|
||||
@@ -441,10 +442,14 @@ def main():
|
||||
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
||||
)
|
||||
|
||||
text_encoder, vae, unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
text_encoder, vae, unet, optimizer, train_dataloader, lr_scheduler
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
# Move text_encode and vae to gpu
|
||||
text_encoder.to(accelerator.device)
|
||||
vae.to(accelerator.device)
|
||||
|
||||
# Keep text_encoder and vae in eval model as we don't train it
|
||||
text_encoder.eval()
|
||||
vae.eval()
|
||||
@@ -536,8 +541,8 @@ def main():
|
||||
# Create the pipeline using using the trained modules and save it.
|
||||
if accelerator.is_main_process:
|
||||
pipeline = StableDiffusionPipeline(
|
||||
text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
vae=accelerator.unwrap_model(vae),
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
unet=accelerator.unwrap_model(unet),
|
||||
tokenizer=tokenizer,
|
||||
scheduler=PNDMScheduler(
|
||||
|
||||
Reference in New Issue
Block a user