mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
chore: add a cleaning utility to be useful during training. (#9240)
This commit is contained in:
@@ -15,7 +15,6 @@
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import gc
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
@@ -56,6 +55,7 @@ from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import (
|
||||
_set_state_dict_into_text_encoder,
|
||||
cast_training_params,
|
||||
clear_objs_and_retain_memory,
|
||||
compute_density_for_timestep_sampling,
|
||||
compute_loss_weighting_for_sd3,
|
||||
)
|
||||
@@ -210,9 +210,7 @@ def log_validation(
|
||||
}
|
||||
)
|
||||
|
||||
del pipeline
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
clear_objs_and_retain_memory(objs=[pipeline])
|
||||
|
||||
return images
|
||||
|
||||
@@ -1107,9 +1105,7 @@ def main(args):
|
||||
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
|
||||
image.save(image_filename)
|
||||
|
||||
del pipeline
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
clear_objs_and_retain_memory(objs=[pipeline])
|
||||
|
||||
# Handle the repository creation
|
||||
if accelerator.is_main_process:
|
||||
@@ -1455,12 +1451,10 @@ def main(args):
|
||||
|
||||
# Clear the memory here
|
||||
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
|
||||
del tokenizers, text_encoders
|
||||
# Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
|
||||
del text_encoder_one, text_encoder_two, text_encoder_three
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
clear_objs_and_retain_memory(
|
||||
objs=[tokenizers, text_encoders, text_encoder_one, text_encoder_two, text_encoder_three]
|
||||
)
|
||||
|
||||
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
|
||||
# pack the statically computed variables appropriately here. This is so that we don't
|
||||
@@ -1795,11 +1789,11 @@ def main(args):
|
||||
pipeline_args=pipeline_args,
|
||||
epoch=epoch,
|
||||
)
|
||||
objs = []
|
||||
if not args.train_text_encoder:
|
||||
del text_encoder_one, text_encoder_two, text_encoder_three
|
||||
objs.extend([text_encoder_one, text_encoder_two, text_encoder_three])
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
clear_objs_and_retain_memory(objs=objs)
|
||||
|
||||
# Save the lora layers
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import contextlib
|
||||
import copy
|
||||
import gc
|
||||
import math
|
||||
import random
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
@@ -259,6 +260,22 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
||||
return weighting
|
||||
|
||||
|
||||
def clear_objs_and_retain_memory(objs: List[Any]):
    """Run garbage collection and clear the cache of the available accelerator.

    Args:
        objs: Objects the caller wants freed. May be empty.

    Note:
        ``del obj`` below only unbinds the *local* loop variable; it cannot
        release the caller's own references to the objects in ``objs``. For the
        memory to actually become collectable, callers must also drop their own
        references (e.g. ``del pipeline``) before invoking this helper.
    """
    # Best-effort unbinding of our local references. Effective only once no
    # other references to each object remain (see Note above).
    if objs:
        for obj in objs:
            del obj

    gc.collect()

    # Clear the cache of whichever accelerator backend is present, checked in
    # priority order: CUDA, then Apple MPS, then Ascend NPU.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available():
        torch.mps.empty_cache()
    elif is_torch_npu_available():
        torch_npu.empty_cache()
|
||||
|
||||
|
||||
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
|
||||
class EMAModel:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user