From 5b6582cf73ca71cb8a7bef0c608ecf10399916bd Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 3 Mar 2023 16:44:10 +0100 Subject: [PATCH] [Model offload] Add nice warning (#2543) * [Model offload] Add nice warning * Treat sequential and model offload differently. Sequential raises an error because the operation would fail with a cryptic warning later. * Forcibly move to cpu when offloading. * make style * one more fix * make fix-copies * up --------- Co-authored-by: Pedro Cuenca --- .../alt_diffusion/pipeline_alt_diffusion.py | 8 ++++ .../pipeline_alt_diffusion_img2img.py | 8 ++++ src/diffusers/pipelines/pipeline_utils.py | 44 ++++++++++++++++++- .../pipeline_cycle_diffusion.py | 8 ++++ .../pipeline_stable_diffusion.py | 8 ++++ ...line_stable_diffusion_attend_and_excite.py | 4 ++ .../pipeline_stable_diffusion_img2img.py | 8 ++++ .../pipeline_stable_diffusion_inpaint.py | 8 ++++ ...ipeline_stable_diffusion_inpaint_legacy.py | 8 ++++ ...eline_stable_diffusion_instruct_pix2pix.py | 8 ++++ .../pipeline_stable_diffusion_k_diffusion.py | 8 ++++ .../pipeline_stable_diffusion_panorama.py | 4 ++ .../pipeline_stable_diffusion_pix2pix_zero.py | 4 ++ .../pipeline_stable_diffusion_sag.py | 4 ++ tests/test_pipelines.py | 36 +++++++++++++++ 15 files changed, 166 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 2ae3baa74d..71e98480ed 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -214,6 +214,10 @@ class AltDiffusionPipeline(DiffusionPipeline): device = torch.device(f"cuda:{gpu_id}") + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: cpu_offload(cpu_offloaded_model, device) @@ -234,6 +238,10 @@ class AltDiffusionPipeline(DiffusionPipeline): device = torch.device(f"cuda:{gpu_id}") + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + hook = None for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index 1b30ceddca..1e7872e3b0 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -220,6 +220,10 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline): device = torch.device(f"cuda:{gpu_id}") + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: cpu_offload(cpu_offloaded_model, device) @@ -240,6 +244,10 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline): device = torch.device(f"cuda:{gpu_id}") + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + hook = None for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 6bd231853f..65b348d2e7 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -49,6 +49,7 @@ from ..utils import ( get_class_from_dynamic_module, http_user_agent, is_accelerate_available, + is_accelerate_version, is_safetensors_available, is_torch_version, is_transformers_available, @@ -66,6 +67,10 @@ if is_transformers_available(): from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME +if is_accelerate_available(): + import accelerate + + INDEX_FILE = "diffusion_pytorch_model.bin" CUSTOM_PIPELINE_FILE_NAME = "pipeline.py" DUMMY_MODULES_FOLDER = "diffusers.utils" @@ -337,15 +342,50 @@ class DiffusionPipeline(ConfigMixin): save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs) - def to(self, torch_device: Optional[Union[str, torch.device]] = None): + def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings: bool = False): if torch_device is None: return self + # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU. + def module_is_sequentially_offloaded(module): + if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"): + return False + + return hasattr(module, "_hf_hook") and not isinstance(module._hf_hook, accelerate.hooks.CpuOffload) + + def module_is_offloaded(module): + if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"): + return False + + return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload) + + # .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer + pipeline_is_sequentially_offloaded = any( + module_is_sequentially_offloaded(module) for _, module in self.components.items() + ) + if pipeline_is_sequentially_offloaded and torch.device(torch_device).type == "cuda": + raise ValueError( + "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading." + ) + + # Display a warning in this case (the operation succeeds but the benefits are lost) + pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items()) + if pipeline_is_offloaded and torch.device(torch_device).type == "cuda": + logger.warning( + f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading." + ) + module_names, _, _ = self.extract_init_dict(dict(self.config)) + is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded for name in module_names.keys(): module = getattr(self, name) if isinstance(module, torch.nn.Module): - if module.dtype == torch.float16 and str(torch_device) in ["cpu"]: + if ( + module.dtype == torch.float16 + and str(torch_device) in ["cpu"] + and not silence_dtype_warnings + and not is_offloaded + ): logger.warning( "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It" " is not recommended to move them to `cpu` as running them will fail. Please make" diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py index 385c3e09bb..e977071b9c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py @@ -237,6 +237,10 @@ class CycleDiffusionPipeline(DiffusionPipeline): device = torch.device(f"cuda:{gpu_id}") + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: cpu_offload(cpu_offloaded_model, device) @@ -258,6 +262,10 @@ class CycleDiffusionPipeline(DiffusionPipeline): device = torch.device(f"cuda:{gpu_id}") + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + hook = None for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 48d13511cd..5044797986 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -217,6 +217,10 @@ class StableDiffusionPipeline(DiffusionPipeline): device = torch.device(f"cuda:{gpu_id}") + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: cpu_offload(cpu_offloaded_model, device) @@ -237,6 +241,10 @@ class StableDiffusionPipeline(DiffusionPipeline): device = torch.device(f"cuda:{gpu_id}") + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + hook = None for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py index 2e92c7c31c..e5550f1d0d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py @@ -263,6 +263,10 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline): device = torch.device(f"cuda:{gpu_id}") + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: cpu_offload(cpu_offloaded_model, device) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index ebc252430d..172ab15a75 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -225,6 +225,10 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): device = torch.device(f"cuda:{gpu_id}") + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: cpu_offload(cpu_offloaded_model, device) @@ -246,6 +250,10 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): device = torch.device(f"cuda:{gpu_id}") + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + hook = None for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index b9b207b606..b645ba667f 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -272,6 +272,10 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): device = torch.device(f"cuda:{gpu_id}") + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: cpu_offload(cpu_offloaded_model, device) @@ -293,6 +297,10 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): device = torch.device(f"cuda:{gpu_id}") + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + hook = None for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index a89d6814ca..1e84efa416 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -216,6 +216,10 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): device = torch.device(f"cuda:{gpu_id}") + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: cpu_offload(cpu_offloaded_model, device) @@ -237,6 +241,10 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): device = torch.device(f"cuda:{gpu_id}") + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + hook = None for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py index 9820473471..9efa91e161 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py @@ -405,6 +405,10 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline): device = torch.device(f"cuda:{gpu_id}") + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: cpu_offload(cpu_offloaded_model, device) @@ -426,6 +430,10 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline): device = torch.device(f"cuda:{gpu_id}") + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + hook = None for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py index f668d152a7..f3db54caa3 100755 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py @@ -137,6 +137,10 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline): device = torch.device(f"cuda:{gpu_id}") + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: cpu_offload(cpu_offloaded_model, device) @@ -158,6 +162,10 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline): device = torch.device(f"cuda:{gpu_id}") + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + hook = None for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py index 5f0ca6f67e..b1f29fbef1 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py @@ -158,6 +158,10 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline): device = torch.device(f"cuda:{gpu_id}") + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: cpu_offload(cpu_offloaded_model, device) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 101acf0ad9..c78d327f69 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -372,6 +372,10 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): device = torch.device(f"cuda:{gpu_id}") + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: cpu_offload(cpu_offloaded_model, device) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py index 4ba3345f94..0a34499e09 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py @@ -176,6 +176,10 @@ class StableDiffusionSAGPipeline(DiffusionPipeline): device = torch.device(f"cuda:{gpu_id}") + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: cpu_offload(cpu_offloaded_model, device) diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index e909e45613..c1641b4f35 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -584,6 +584,42 @@ class PipelineFastTests(unittest.TestCase): assert image_img2img.shape == (1, 32, 32, 3) assert image_text2img.shape == (1, 64, 64, 3) + @require_torch_gpu + def test_pipe_false_offload_warn(self): + unet = self.dummy_cond_unet() + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + sd = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + + sd.enable_model_cpu_offload() + + logger = logging.get_logger("diffusers.pipelines.pipeline_utils") + with CaptureLogger(logger) as cap_logger: + sd.to("cuda") + + assert "It is strongly recommended against doing so" in str(cap_logger) + + sd = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + def test_set_scheduler(self): unet = self.dummy_cond_unet() scheduler = PNDMScheduler(skip_prk_steps=True)