mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Model offload] Add nice warning (#2543)
* [Model offload] Add nice warning * Treat sequential and model offload differently. Sequential raises an error because the operation would fail with a cryptic warning later. * Forcibly move to cpu when offloading. * make style * one more fix * make fix-copies * up --------- Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
4f0141a67d
commit
5b6582cf73
@@ -214,6 +214,10 @@ class AltDiffusionPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
@@ -234,6 +238,10 @@ class AltDiffusionPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
@@ -220,6 +220,10 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
@@ -240,6 +244,10 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
@@ -49,6 +49,7 @@ from ..utils import (
|
||||
get_class_from_dynamic_module,
|
||||
http_user_agent,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_safetensors_available,
|
||||
is_torch_version,
|
||||
is_transformers_available,
|
||||
@@ -66,6 +67,10 @@ if is_transformers_available():
|
||||
from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
import accelerate
|
||||
|
||||
|
||||
INDEX_FILE = "diffusion_pytorch_model.bin"
|
||||
CUSTOM_PIPELINE_FILE_NAME = "pipeline.py"
|
||||
DUMMY_MODULES_FOLDER = "diffusers.utils"
|
||||
@@ -337,15 +342,50 @@ class DiffusionPipeline(ConfigMixin):
|
||||
|
||||
save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
|
||||
|
||||
def to(self, torch_device: Optional[Union[str, torch.device]] = None):
|
||||
def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings: bool = False):
|
||||
if torch_device is None:
|
||||
return self
|
||||
|
||||
# throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
|
||||
def module_is_sequentially_offloaded(module):
|
||||
if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
|
||||
return False
|
||||
|
||||
return hasattr(module, "_hf_hook") and not isinstance(module._hf_hook, accelerate.hooks.CpuOffload)
|
||||
|
||||
def module_is_offloaded(module):
|
||||
if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"):
|
||||
return False
|
||||
|
||||
return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload)
|
||||
|
||||
# .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer
|
||||
pipeline_is_sequentially_offloaded = any(
|
||||
module_is_sequentially_offloaded(module) for _, module in self.components.items()
|
||||
)
|
||||
if pipeline_is_sequentially_offloaded and torch.device(torch_device).type == "cuda":
|
||||
raise ValueError(
|
||||
"It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
|
||||
)
|
||||
|
||||
# Display a warning in this case (the operation succeeds but the benefits are lost)
|
||||
pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
|
||||
if pipeline_is_offloaded and torch.device(torch_device).type == "cuda":
|
||||
logger.warning(
|
||||
f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
|
||||
)
|
||||
|
||||
module_names, _, _ = self.extract_init_dict(dict(self.config))
|
||||
is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
|
||||
for name in module_names.keys():
|
||||
module = getattr(self, name)
|
||||
if isinstance(module, torch.nn.Module):
|
||||
if module.dtype == torch.float16 and str(torch_device) in ["cpu"]:
|
||||
if (
|
||||
module.dtype == torch.float16
|
||||
and str(torch_device) in ["cpu"]
|
||||
and not silence_dtype_warnings
|
||||
and not is_offloaded
|
||||
):
|
||||
logger.warning(
|
||||
"Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
|
||||
" is not recommended to move them to `cpu` as running them will fail. Please make"
|
||||
|
||||
@@ -237,6 +237,10 @@ class CycleDiffusionPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
@@ -258,6 +262,10 @@ class CycleDiffusionPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
@@ -217,6 +217,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
@@ -237,6 +241,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
@@ -263,6 +263,10 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
|
||||
@@ -225,6 +225,10 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
@@ -246,6 +250,10 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
@@ -272,6 +272,10 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
@@ -293,6 +297,10 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
@@ -216,6 +216,10 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
@@ -237,6 +241,10 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
@@ -405,6 +405,10 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
@@ -426,6 +430,10 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
@@ -137,6 +137,10 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
@@ -158,6 +162,10 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
@@ -158,6 +158,10 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
|
||||
@@ -372,6 +372,10 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
|
||||
@@ -176,6 +176,10 @@ class StableDiffusionSAGPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
|
||||
@@ -584,6 +584,42 @@ class PipelineFastTests(unittest.TestCase):
|
||||
assert image_img2img.shape == (1, 32, 32, 3)
|
||||
assert image_text2img.shape == (1, 64, 64, 3)
|
||||
|
||||
@require_torch_gpu
|
||||
def test_pipe_false_offload_warn(self):
|
||||
unet = self.dummy_cond_unet()
|
||||
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
sd = StableDiffusionPipeline(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=self.dummy_extractor,
|
||||
)
|
||||
|
||||
sd.enable_model_cpu_offload()
|
||||
|
||||
logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
sd.to("cuda")
|
||||
|
||||
assert "It is strongly recommended against doing so" in str(cap_logger)
|
||||
|
||||
sd = StableDiffusionPipeline(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=self.dummy_extractor,
|
||||
)
|
||||
|
||||
def test_set_scheduler(self):
|
||||
unet = self.dummy_cond_unet()
|
||||
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||
|
||||
Reference in New Issue
Block a user