1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

chore: add a memory-cleaning utility that is useful during training. (#9240)

This commit is contained in:
Sayak Paul
2024-09-03 15:00:17 +05:30
committed by GitHub
parent 9d49b45b19
commit 8ba90aa706
2 changed files with 26 additions and 15 deletions

View File

@@ -15,7 +15,6 @@
import argparse
import copy
import gc
import itertools
import logging
import math
@@ -56,6 +55,7 @@ from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
_set_state_dict_into_text_encoder,
cast_training_params,
clear_objs_and_retain_memory,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
)
@@ -210,9 +210,7 @@ def log_validation(
}
)
del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
clear_objs_and_retain_memory(objs=[pipeline])
return images
@@ -1107,9 +1105,7 @@ def main(args):
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
image.save(image_filename)
del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
clear_objs_and_retain_memory(objs=[pipeline])
# Handle the repository creation
if accelerator.is_main_process:
@@ -1455,12 +1451,10 @@ def main(args):
# Clear the memory here
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
del tokenizers, text_encoders
# Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
del text_encoder_one, text_encoder_two, text_encoder_three
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
clear_objs_and_retain_memory(
objs=[tokenizers, text_encoders, text_encoder_one, text_encoder_two, text_encoder_three]
)
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
@@ -1795,11 +1789,11 @@ def main(args):
pipeline_args=pipeline_args,
epoch=epoch,
)
objs = []
if not args.train_text_encoder:
del text_encoder_one, text_encoder_two, text_encoder_three
objs.extend([text_encoder_one, text_encoder_two, text_encoder_three])
torch.cuda.empty_cache()
gc.collect()
clear_objs_and_retain_memory(objs=objs)
# Save the lora layers
accelerator.wait_for_everyone()

View File

@@ -1,5 +1,6 @@
import contextlib
import copy
import gc
import math
import random
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
@@ -259,6 +260,22 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
return weighting
def clear_objs_and_retain_memory(objs: List[Any]):
    """Forces a garbage-collection pass and empties the cache of whichever
    accelerator backend is available (CUDA, MPS, or NPU).

    Args:
        objs (`List[Any]`):
            Objects the caller wants reclaimed. Note that this function cannot
            delete the *caller's* references to these objects — the previous
            `for obj in objs: del obj` loop only unbound the local loop
            variable, so it had no effect on the caller's namespace. Callers
            must `del` (or rebind) their own names; this utility then forces
            collection and returns the freed accelerator memory to the pool.
    """
    # NOTE(review): `objs` itself keeps its elements alive until this frame
    # exits. Clearing the list here would mutate the caller's list (a behavior
    # change), so we deliberately leave it untouched.
    gc.collect()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available():
        torch.mps.empty_cache()
    elif is_torch_npu_available():
        # `is_torch_npu_available` / `torch_npu` are provided elsewhere in
        # diffusers.training_utils — TODO confirm they are imported at module top.
        torch_npu.empty_cache()
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel:
"""