
Add self.image_encoder and self.text_decoder to the list of models offloaded to CPU in the enable_sequential_cpu_offload(...) and enable_model_cpu_offload(...) methods, so that test_cpu_offload_forward_pass passes.
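For context, a minimal sketch of the offload pattern these methods rely on (assuming `accelerate` is installed; the helper name `offload_all_submodules` is illustrative, not part of the pipeline). Any submodule left out of the list is never wrapped with offload hooks, so it stays on its current device and fails device-placement checks such as test_cpu_offload_forward_pass:

```python
import torch
from accelerate import cpu_offload

def offload_all_submodules(pipe, gpu_id=0):
    # Illustrative helper: mirrors what enable_sequential_cpu_offload does.
    device = torch.device(f"cuda:{gpu_id}")
    # Every submodule must appear here; before this commit, image_encoder
    # and text_decoder were missing and therefore never offloaded.
    for model in [pipe.unet, pipe.text_encoder, pipe.vae, pipe.image_encoder, pipe.text_decoder]:
        # cpu_offload keeps the module on CPU and moves it to `device`
        # only for the duration of each forward pass.
        cpu_offload(model, device)
```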

Daniel Gu
2023-05-05 12:12:42 -07:00
parent 1cb726a005
commit e62b32aadc


@@ -162,7 +162,8 @@ class UniDiffuserPipeline(DiffusionPipeline):
         # TODO: handle safety checking?
         self.safety_checker = None
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
+    # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
+    # Add self.image_encoder, self.text_decoder to cpu_offloaded_models list
     def enable_sequential_cpu_offload(self, gpu_id=0):
         r"""
         Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
@@ -182,13 +183,14 @@ class UniDiffuserPipeline(DiffusionPipeline):
         self.to("cpu", silence_dtype_warnings=True)
         torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
 
-        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
+        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.image_encoder, self.text_decoder]:
             cpu_offload(cpu_offloaded_model, device)
 
         if self.safety_checker is not None:
             cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
+    # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
+    # Add self.image_encoder, self.text_decoder to cpu_offloaded_models list
     def enable_model_cpu_offload(self, gpu_id=0):
         r"""
         Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
@@ -208,7 +210,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
         torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
 
         hook = None
-        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
+        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae, self.image_encoder, self.text_decoder]:
             _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
 
         if self.safety_checker is not None:
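A hedged usage sketch of the updated offload path (the checkpoint id, prompt, and dtype below are illustrative and not taken from this commit):

```python
import torch
from diffusers import UniDiffuserPipeline

# Load a UniDiffuser checkpoint; "thu-ml/unidiffuser-v1" is one public example.
pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1", torch_dtype=torch.float16)

# After this commit, image_encoder and text_decoder are also managed by the hooks,
# so every submodule is moved between CPU and GPU consistently.
pipe.enable_model_cpu_offload()

# Text-to-image: each submodule is moved to the GPU on demand and back afterwards.
image = pipe(prompt="an astronaut riding a horse").images[0]
```

The practical difference between the two methods: enable_sequential_cpu_offload (via accelerate's cpu_offload) offloads at the finest granularity for maximum memory savings, while enable_model_cpu_offload (via cpu_offload_with_hook) keeps one whole submodule on the GPU at a time, trading a little memory for much lower latency.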