mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Improve the performance and suitable for NPU computing (#9642)
* Improve the performance and suitable for NPU * Improve the performance and suitable for NPU computing * Improve the performance and suitable for NPU * Improve the performance and suitable for NPU * Improve the performance and suitable for NPU * Improve the performance and suitable for NPU --------- Co-authored-by: 蒋硕 <jiangshuo9@h-partners.com> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -59,6 +59,8 @@ check_min_version("0.31.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
import torch_npu
|
||||
|
||||
torch.npu.config.allow_internal_format = False
|
||||
|
||||
DATASET_NAME_MAPPING = {
|
||||
@@ -540,6 +542,9 @@ def compute_vae_encodings(batch, vae):
|
||||
with torch.no_grad():
|
||||
model_input = vae.encode(pixel_values).latent_dist.sample()
|
||||
model_input = model_input * vae.config.scaling_factor
|
||||
|
||||
# There might be a slight performance improvement
|
||||
# by changing model_input.cpu() to accelerator.gather(model_input)
|
||||
return {"model_input": model_input.cpu()}
|
||||
|
||||
|
||||
@@ -935,7 +940,10 @@ def main(args):
|
||||
del compute_vae_encodings_fn, compute_embeddings_fn, text_encoder_one, text_encoder_two
|
||||
del text_encoders, tokenizers, vae
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
if is_torch_npu_available():
|
||||
torch_npu.npu.empty_cache()
|
||||
elif torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def collate_fn(examples):
|
||||
model_input = torch.stack([torch.tensor(example["model_input"]) for example in examples])
|
||||
@@ -1091,8 +1099,7 @@ def main(args):
|
||||
# Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
|
||||
target_size = (args.resolution, args.resolution)
|
||||
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
add_time_ids = torch.tensor([add_time_ids])
|
||||
add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
|
||||
add_time_ids = torch.tensor([add_time_ids], device=accelerator.device, dtype=weight_dtype)
|
||||
return add_time_ids
|
||||
|
||||
add_time_ids = torch.cat(
|
||||
@@ -1261,7 +1268,10 @@ def main(args):
|
||||
)
|
||||
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
if is_torch_npu_available():
|
||||
torch_npu.npu.empty_cache()
|
||||
elif torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if args.use_ema:
|
||||
# Switch back to the original UNet parameters.
|
||||
|
||||
Reference in New Issue
Block a user