mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Enable dreambooth lora finetune example on other devices (#10602)
* enable dreambooth_lora on other devices Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * enable xpu Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * check cuda device before empty cache Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix comment Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * import free_memory Signed-off-by: jiqing-feng <jiqing.feng@intel.com> --------- Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
This commit is contained in:
@@ -54,7 +54,11 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.loaders import StableDiffusionLoraLoaderMixin
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params
|
||||
from diffusers.training_utils import (
|
||||
_set_state_dict_into_text_encoder,
|
||||
cast_training_params,
|
||||
free_memory,
|
||||
)
|
||||
from diffusers.utils import (
|
||||
check_min_version,
|
||||
convert_state_dict_to_diffusers,
|
||||
@@ -151,14 +155,14 @@ def log_validation(
|
||||
if args.validation_images is None:
|
||||
images = []
|
||||
for _ in range(args.num_validation_images):
|
||||
with torch.cuda.amp.autocast():
|
||||
with torch.amp.autocast(accelerator.device.type):
|
||||
image = pipeline(**pipeline_args, generator=generator).images[0]
|
||||
images.append(image)
|
||||
else:
|
||||
images = []
|
||||
for image in args.validation_images:
|
||||
image = Image.open(image)
|
||||
with torch.cuda.amp.autocast():
|
||||
with torch.amp.autocast(accelerator.device.type):
|
||||
image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
|
||||
images.append(image)
|
||||
|
||||
@@ -177,7 +181,7 @@ def log_validation(
|
||||
)
|
||||
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
free_memory()
|
||||
|
||||
return images
|
||||
|
||||
@@ -793,7 +797,7 @@ def main(args):
|
||||
cur_class_images = len(list(class_images_dir.iterdir()))
|
||||
|
||||
if cur_class_images < args.num_class_images:
|
||||
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
|
||||
torch_dtype = torch.float16 if accelerator.device.type in ("cuda", "xpu") else torch.float32
|
||||
if args.prior_generation_precision == "fp32":
|
||||
torch_dtype = torch.float32
|
||||
elif args.prior_generation_precision == "fp16":
|
||||
@@ -829,8 +833,7 @@ def main(args):
|
||||
image.save(image_filename)
|
||||
|
||||
del pipeline
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
free_memory()
|
||||
|
||||
# Handle the repository creation
|
||||
if accelerator.is_main_process:
|
||||
@@ -1085,7 +1088,7 @@ def main(args):
|
||||
tokenizer = None
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
free_memory()
|
||||
else:
|
||||
pre_computed_encoder_hidden_states = None
|
||||
validation_prompt_encoder_hidden_states = None
|
||||
|
||||
@@ -299,6 +299,8 @@ def free_memory():
|
||||
torch.mps.empty_cache()
|
||||
elif is_torch_npu_available():
|
||||
torch_npu.npu.empty_cache()
|
||||
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
torch.xpu.empty_cache()
|
||||
|
||||
|
||||
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
|
||||
|
||||
Reference in New Issue
Block a user