1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Improve the performance and suitable for NPU computing (#9642)

* Improve the performance and suitable for NPU

* Improve the performance and suitable for NPU computing

* Improve the performance and suitable for NPU

* Improve the performance and suitable for NPU

* Improve the performance and suitable for NPU

* Improve the performance and suitable for NPU

---------

Co-authored-by: 蒋硕 <jiangshuo9@h-partners.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
Leo Jiang
2024-10-14 10:09:33 -06:00
committed by GitHub
parent 8d81564b27
commit 5956b68a69

View File

@@ -59,6 +59,8 @@ check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
import torch_npu
torch.npu.config.allow_internal_format = False
DATASET_NAME_MAPPING = {
@@ -540,6 +542,9 @@ def compute_vae_encodings(batch, vae):
with torch.no_grad():
model_input = vae.encode(pixel_values).latent_dist.sample()
model_input = model_input * vae.config.scaling_factor
# There might have slightly performance improvement
# by changing model_input.cpu() to accelerator.gather(model_input)
return {"model_input": model_input.cpu()}
@@ -935,7 +940,10 @@ def main(args):
del compute_vae_encodings_fn, compute_embeddings_fn, text_encoder_one, text_encoder_two
del text_encoders, tokenizers, vae
gc.collect()
torch.cuda.empty_cache()
if is_torch_npu_available():
torch_npu.npu.empty_cache()
elif torch.cuda.is_available():
torch.cuda.empty_cache()
def collate_fn(examples):
model_input = torch.stack([torch.tensor(example["model_input"]) for example in examples])
@@ -1091,8 +1099,7 @@ def main(args):
# Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
target_size = (args.resolution, args.resolution)
add_time_ids = list(original_size + crops_coords_top_left + target_size)
add_time_ids = torch.tensor([add_time_ids])
add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
add_time_ids = torch.tensor([add_time_ids], device=accelerator.device, dtype=weight_dtype)
return add_time_ids
add_time_ids = torch.cat(
@@ -1261,7 +1268,10 @@ def main(args):
)
del pipeline
torch.cuda.empty_cache()
if is_torch_npu_available():
torch_npu.npu.empty_cache()
elif torch.cuda.is_available():
torch.cuda.empty_cache()
if args.use_ema:
# Switch back to the original UNet parameters.