mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[chore] fix: retain memory utility. (#9543)
* fix: retain memory utility.
* fix
* quality
* free_memory.
This commit is contained in:
@@ -38,10 +38,7 @@ from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler, CogVideoXPi
|
||||
from diffusers.models.embeddings import get_3d_rotary_pos_embed
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
|
||||
from diffusers.training_utils import (
|
||||
cast_training_params,
|
||||
clear_objs_and_retain_memory,
|
||||
)
|
||||
from diffusers.training_utils import cast_training_params, free_memory
|
||||
from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, export_to_video, is_wandb_available
|
||||
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
@@ -726,7 +723,8 @@ def log_validation(
|
||||
}
|
||||
)
|
||||
|
||||
clear_objs_and_retain_memory([pipe])
|
||||
del pipe
|
||||
free_memory()
|
||||
|
||||
return videos
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ from diffusers import (
|
||||
from diffusers.models.controlnet_flux import FluxControlNetModel
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline
|
||||
from diffusers.training_utils import clear_objs_and_retain_memory, compute_density_for_timestep_sampling
|
||||
from diffusers.training_utils import compute_density_for_timestep_sampling, free_memory
|
||||
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
|
||||
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
|
||||
@@ -193,7 +193,8 @@ def log_validation(
|
||||
else:
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
clear_objs_and_retain_memory([pipeline])
|
||||
del pipeline
|
||||
free_memory()
|
||||
return image_logs
|
||||
|
||||
|
||||
@@ -1103,7 +1104,8 @@ def main(args):
|
||||
compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint, batch_size=50
|
||||
)
|
||||
|
||||
clear_objs_and_retain_memory([text_encoders, tokenizers])
|
||||
del text_encoders, tokenizers, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
|
||||
free_memory()
|
||||
|
||||
# Then get the training dataset ready to be passed to the dataloader.
|
||||
train_dataset = prepare_train_dataset(train_dataset, accelerator)
|
||||
|
||||
@@ -49,11 +49,7 @@ from diffusers import (
|
||||
StableDiffusion3ControlNetPipeline,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import (
|
||||
clear_objs_and_retain_memory,
|
||||
compute_density_for_timestep_sampling,
|
||||
compute_loss_weighting_for_sd3,
|
||||
)
|
||||
from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
@@ -174,7 +170,8 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v
|
||||
else:
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
clear_objs_and_retain_memory(pipeline)
|
||||
del pipeline
|
||||
free_memory()
|
||||
|
||||
if not is_final_validation:
|
||||
controlnet.to(accelerator.device)
|
||||
@@ -1131,7 +1128,9 @@ def main(args):
|
||||
new_fingerprint = Hasher.hash(args)
|
||||
train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)
|
||||
|
||||
clear_objs_and_retain_memory(text_encoders + tokenizers)
|
||||
del text_encoder_one, text_encoder_two, text_encoder_three
|
||||
del tokenizer_one, tokenizer_two, tokenizer_three
|
||||
free_memory()
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
|
||||
@@ -55,9 +55,9 @@ from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import (
|
||||
_set_state_dict_into_text_encoder,
|
||||
cast_training_params,
|
||||
clear_objs_and_retain_memory,
|
||||
compute_density_for_timestep_sampling,
|
||||
compute_loss_weighting_for_sd3,
|
||||
free_memory,
|
||||
)
|
||||
from diffusers.utils import (
|
||||
check_min_version,
|
||||
@@ -1437,7 +1437,8 @@ def main(args):
|
||||
|
||||
# Clear the memory here
|
||||
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
|
||||
clear_objs_and_retain_memory([tokenizers, text_encoders, text_encoder_one, text_encoder_two])
|
||||
del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
|
||||
free_memory()
|
||||
|
||||
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
|
||||
# pack the statically computed variables appropriately here. This is so that we don't
|
||||
@@ -1480,7 +1481,8 @@ def main(args):
|
||||
latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
|
||||
|
||||
if args.validation_prompt is None:
|
||||
clear_objs_and_retain_memory([vae])
|
||||
del vae
|
||||
free_memory()
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
@@ -1817,7 +1819,8 @@ def main(args):
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
if not args.train_text_encoder:
|
||||
clear_objs_and_retain_memory([text_encoder_one, text_encoder_two])
|
||||
del text_encoder_one, text_encoder_two
|
||||
free_memory()
|
||||
|
||||
# Save the lora layers
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
@@ -55,9 +55,9 @@ from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import (
|
||||
_set_state_dict_into_text_encoder,
|
||||
cast_training_params,
|
||||
clear_objs_and_retain_memory,
|
||||
compute_density_for_timestep_sampling,
|
||||
compute_loss_weighting_for_sd3,
|
||||
free_memory,
|
||||
)
|
||||
from diffusers.utils import (
|
||||
check_min_version,
|
||||
@@ -211,7 +211,8 @@ def log_validation(
|
||||
}
|
||||
)
|
||||
|
||||
clear_objs_and_retain_memory(objs=[pipeline])
|
||||
del pipeline
|
||||
free_memory()
|
||||
|
||||
return images
|
||||
|
||||
@@ -1106,7 +1107,8 @@ def main(args):
|
||||
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
|
||||
image.save(image_filename)
|
||||
|
||||
clear_objs_and_retain_memory(objs=[pipeline])
|
||||
del pipeline
|
||||
free_memory()
|
||||
|
||||
# Handle the repository creation
|
||||
if accelerator.is_main_process:
|
||||
@@ -1453,9 +1455,9 @@ def main(args):
|
||||
# Clear the memory here
|
||||
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
|
||||
# Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
|
||||
clear_objs_and_retain_memory(
|
||||
objs=[tokenizers, text_encoders, text_encoder_one, text_encoder_two, text_encoder_three]
|
||||
)
|
||||
del tokenizers, text_encoders
|
||||
del text_encoder_one, text_encoder_two, text_encoder_three
|
||||
free_memory()
|
||||
|
||||
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
|
||||
# pack the statically computed variables appropriately here. This is so that we don't
|
||||
@@ -1791,11 +1793,9 @@ def main(args):
|
||||
epoch=epoch,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
objs = []
|
||||
if not args.train_text_encoder:
|
||||
objs.extend([text_encoder_one, text_encoder_two, text_encoder_three])
|
||||
|
||||
clear_objs_and_retain_memory(objs=objs)
|
||||
del text_encoder_one, text_encoder_two, text_encoder_three
|
||||
free_memory()
|
||||
|
||||
# Save the lora layers
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
@@ -260,12 +260,8 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
||||
return weighting
|
||||
|
||||
|
||||
def clear_objs_and_retain_memory(objs: List[Any]):
|
||||
"""Deletes `objs` and runs garbage collection. Then clears the cache of the available accelerator."""
|
||||
if len(objs) >= 1:
|
||||
for obj in objs:
|
||||
del obj
|
||||
|
||||
def free_memory():
|
||||
"""Runs garbage collection. Then clears the cache of the available accelerator."""
|
||||
gc.collect()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
|
||||
Reference in New Issue
Block a user