From 71ba8aec55b52a7ba5a1ff1db1265ffdd3c65ea2 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Fri, 19 Aug 2022 18:39:08 +0200 Subject: [PATCH] Pipeline to device (#210) * Implement `pipeline.to(device)` * DiffusionPipeline.to() decides best device on None. * Breaking change: torch_device removed from __call__ `pipeline.to()` now has PyTorch semantics. * Use kwargs and deprecation notice Co-authored-by: Patrick von Platen * Apply torch_device compatibility to all pipelines. * style Co-authored-by: Patrick von Platen Co-authored-by: anton-l --- src/diffusers/pipeline_utils.py | 22 ++++++++++++++ src/diffusers/pipelines/ddim/pipeline_ddim.py | 25 +++++++++++----- src/diffusers/pipelines/ddpm/pipeline_ddpm.py | 19 ++++++++---- .../pipeline_latent_diffusion.py | 26 ++++++++++------- .../pipeline_latent_diffusion_uncond.py | 21 +++++++++----- src/diffusers/pipelines/pndm/pipeline_pndm.py | 20 +++++++++---- .../score_sde_ve/pipeline_score_sde_ve.py | 22 ++++++++++---- .../pipeline_stable_diffusion.py | 29 +++++++++++-------- .../pipeline_stochastic_karras_ve.py | 21 ++++++++++---- 9 files changed, 146 insertions(+), 59 deletions(-) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 5b781f0e09..6cb98d7c9b 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -19,6 +19,8 @@ import inspect import os from typing import Optional, Union +import torch + from huggingface_hub import snapshot_download from PIL import Image @@ -113,6 +115,26 @@ class DiffusionPipeline(ConfigMixin): save_method = getattr(sub_model, save_method_name) save_method(os.path.join(save_directory, pipeline_component_name)) + def to(self, torch_device: Optional[Union[str, torch.device]] = None): + if torch_device is None: + return self + + module_names, _ = self.extract_init_dict(dict(self.config)) + for name in module_names.keys(): + module = getattr(self, name) + if isinstance(module, torch.nn.Module): + module.to(torch_device) + return self + + @property + def device(self) -> torch.device: + module_names, _ = self.extract_init_dict(dict(self.config)) + for name in module_names.keys(): + module = getattr(self, name) + if isinstance(module, torch.nn.Module): + return module.device + return torch.device("cpu") + @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): r""" diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py index a1000ae2ef..700e2b9ca3 100644 --- a/src/diffusers/pipelines/ddim/pipeline_ddim.py +++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py @@ -14,6 +14,8 @@ # limitations under the License. +import warnings + import torch from tqdm.auto import tqdm @@ -28,21 +30,28 @@ class DDIMPipeline(DiffusionPipeline): self.register_modules(unet=unet, scheduler=scheduler) @torch.no_grad() - def __call__( - self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="pil" - ): - # eta corresponds to η in paper and should be between [0, 1] - if torch_device is None: - torch_device = "cuda" if torch.cuda.is_available() else "cpu" + def __call__(self, batch_size=1, generator=None, eta=0.0, num_inference_steps=50, output_type="pil", **kwargs): - self.unet.to(torch_device) + if "torch_device" in kwargs: + device = kwargs.pop("torch_device") + warnings.warn( + "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." + " Consider using `pipe.to(torch_device)` instead." 
+ ) + + # Set device as before (to be removed in 0.3.0) + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.to(device) + + # eta corresponds to η in paper and should be between [0, 1] # Sample gaussian noise to begin loop image = torch.randn( (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), generator=generator, ) - image = image.to(torch_device) + image = image.to(self.device) # set step values self.scheduler.set_timesteps(num_inference_steps) diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py index bab1c245f3..099add5daa 100644 --- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py +++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -14,6 +14,8 @@ # limitations under the License. +import warnings + import torch from tqdm.auto import tqdm @@ -28,18 +30,25 @@ class DDPMPipeline(DiffusionPipeline): self.register_modules(unet=unet, scheduler=scheduler) @torch.no_grad() - def __call__(self, batch_size=1, generator=None, torch_device=None, output_type="pil"): - if torch_device is None: - torch_device = "cuda" if torch.cuda.is_available() else "cpu" + def __call__(self, batch_size=1, generator=None, output_type="pil", **kwargs): + if "torch_device" in kwargs: + device = kwargs.pop("torch_device") + warnings.warn( + "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." + " Consider using `pipe.to(torch_device)` instead." + ) - self.unet.to(torch_device) + # Set device as before (to be removed in 0.3.0) + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.to(device) # Sample gaussian noise to begin loop image = torch.randn( (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), generator=generator, ) - image = image.to(torch_device) + image = image.to(self.device) # set step values self.scheduler.set_timesteps(1000) diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py index 1edcdadb22..17a15aca18 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -1,4 +1,5 @@ import inspect +import warnings from typing import List, Optional, Tuple, Union import torch @@ -31,13 +32,22 @@ class LDMTextToImagePipeline(DiffusionPipeline): guidance_scale: Optional[float] = 1.0, eta: Optional[float] = 0.0, generator: Optional[torch.Generator] = None, - torch_device: Optional[Union[str, torch.device]] = None, output_type: Optional[str] = "pil", + **kwargs, ): # eta corresponds to η in paper and should be between [0, 1] - if torch_device is None: - torch_device = "cuda" if torch.cuda.is_available() else "cpu" + if "torch_device" in kwargs: + device = kwargs.pop("torch_device") + warnings.warn( + "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." + " Consider using `pipe.to(torch_device)` instead." 
+ ) + + # Set device as before (to be removed in 0.3.0) + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.to(device) if isinstance(prompt, str): batch_size = 1 @@ -49,24 +59,20 @@ class LDMTextToImagePipeline(DiffusionPipeline): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - self.unet.to(torch_device) - self.vqvae.to(torch_device) - self.bert.to(torch_device) - # get unconditional embeddings for classifier free guidance if guidance_scale != 1.0: uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt") - uncond_embeddings = self.bert(uncond_input.input_ids.to(torch_device))[0] + uncond_embeddings = self.bert(uncond_input.input_ids.to(self.device))[0] # get prompt text embeddings text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt") - text_embeddings = self.bert(text_input.input_ids.to(torch_device))[0] + text_embeddings = self.bert(text_input.input_ids.to(self.device))[0] latents = torch.randn( (batch_size, self.unet.in_channels, height // 8, width // 8), generator=generator, ) - latents = latents.to(torch_device) + latents = latents.to(self.device) self.scheduler.set_timesteps(num_inference_steps) diff --git a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py index c4f3337a0b..bdff4fc948 100644 --- a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +++ b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py @@ -1,4 +1,5 @@ import inspect +import warnings import torch @@ -14,22 +15,26 @@ class LDMPipeline(DiffusionPipeline): self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler) @torch.no_grad() - def __call__( - self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="pil" - ): + def __call__(self, batch_size=1, generator=None, eta=0.0, num_inference_steps=50, output_type="pil", **kwargs): # eta corresponds to η in paper and should be between [0, 1] - if torch_device is None: - torch_device = "cuda" if torch.cuda.is_available() else "cpu" + if "torch_device" in kwargs: + device = kwargs.pop("torch_device") + warnings.warn( + "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." + " Consider using `pipe.to(torch_device)` instead." + ) - self.unet.to(torch_device) - self.vqvae.to(torch_device) + # Set device as before (to be removed in 0.3.0) + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.to(device) latents = torch.randn( (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), generator=generator, ) - latents = latents.to(torch_device) + latents = latents.to(self.device) self.scheduler.set_timesteps(num_inference_steps) diff --git a/src/diffusers/pipelines/pndm/pipeline_pndm.py b/src/diffusers/pipelines/pndm/pipeline_pndm.py index 7b8946529c..bc0f75648a 100644 --- a/src/diffusers/pipelines/pndm/pipeline_pndm.py +++ b/src/diffusers/pipelines/pndm/pipeline_pndm.py @@ -14,6 +14,8 @@ # limitations under the License. 
+import warnings + import torch from tqdm.auto import tqdm @@ -28,20 +30,28 @@ class PNDMPipeline(DiffusionPipeline): self.register_modules(unet=unet, scheduler=scheduler) @torch.no_grad() - def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50, output_type="pil"): + def __call__(self, batch_size=1, generator=None, num_inference_steps=50, output_type="pil", **kwargs): # For more information on the sampling method you can take a look at Algorithm 2 of # the official paper: https://arxiv.org/pdf/2202.09778.pdf - if torch_device is None: - torch_device = "cuda" if torch.cuda.is_available() else "cpu" - self.unet.to(torch_device) + if "torch_device" in kwargs: + device = kwargs.pop("torch_device") + warnings.warn( + "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." + " Consider using `pipe.to(torch_device)` instead." + ) + + # Set device as before (to be removed in 0.3.0) + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.to(device) # Sample gaussian noise to begin loop image = torch.randn( (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), generator=generator, ) - image = image.to(torch_device) + image = image.to(self.device) self.scheduler.set_timesteps(num_inference_steps) for t in tqdm(self.scheduler.timesteps): diff --git a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py index 939ac9ec27..884f1894f7 100644 --- a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +++ b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +import warnings + import torch from diffusers import DiffusionPipeline @@ -11,24 +13,32 @@ class ScoreSdeVePipeline(DiffusionPipeline): self.register_modules(unet=unet, scheduler=scheduler) @torch.no_grad() - def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, torch_device=None, output_type="pil"): + def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, output_type="pil", **kwargs): + if "torch_device" in kwargs: + device = kwargs.pop("torch_device") + warnings.warn( + "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." + " Consider using `pipe.to(torch_device)` instead." 
+ ) - if torch_device is None: - torch_device = "cuda" if torch.cuda.is_available() else "cpu" + # Set device as before (to be removed in 0.3.0) + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.to(device) img_size = self.unet.config.sample_size shape = (batch_size, 3, img_size, img_size) - model = self.unet.to(torch_device) + model = self.unet sample = torch.randn(*shape) * self.scheduler.config.sigma_max - sample = sample.to(torch_device) + sample = sample.to(self.device) self.scheduler.set_timesteps(num_inference_steps) self.scheduler.set_sigmas(num_inference_steps) for i, t in tqdm(enumerate(self.scheduler.timesteps)): - sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=torch_device) + sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device) # correction step for _ in range(self.scheduler.correct_steps): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index baff1db970..550513b5c9 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -1,4 +1,5 @@ import inspect +import warnings from typing import List, Optional, Union import torch @@ -45,11 +46,20 @@ class StableDiffusionPipeline(DiffusionPipeline): guidance_scale: Optional[float] = 7.5, eta: Optional[float] = 0.0, generator: Optional[torch.Generator] = None, - torch_device: Optional[Union[str, torch.device]] = None, output_type: Optional[str] = "pil", + **kwargs, ): - if torch_device is None: - torch_device = "cuda" if torch.cuda.is_available() else "cpu" + if "torch_device" in kwargs: + device = kwargs.pop("torch_device") + warnings.warn( + "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." + " Consider using `pipe.to(torch_device)` instead." + ) + + # Set device as before (to be removed in 0.3.0) + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.to(device) if isinstance(prompt, str): batch_size = 1 @@ -61,11 +71,6 @@ class StableDiffusionPipeline(DiffusionPipeline): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - self.unet.to(torch_device) - self.vae.to(torch_device) - self.text_encoder.to(torch_device) - self.safety_checker.to(torch_device) - # get prompt text embeddings text_input = self.tokenizer( prompt, @@ -74,7 +79,7 @@ class StableDiffusionPipeline(DiffusionPipeline): truncation=True, return_tensors="pt", ) - text_embeddings = self.text_encoder(text_input.input_ids.to(torch_device))[0] + text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` @@ -86,7 +91,7 @@ class StableDiffusionPipeline(DiffusionPipeline): uncond_input = self.tokenizer( [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" ) - uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(torch_device))[0] + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch @@ -97,7 +102,7 @@ class StableDiffusionPipeline(DiffusionPipeline): latents = torch.randn( (batch_size, self.unet.in_channels, height // 8, width // 8), generator=generator, - device=torch_device, + device=self.device, ) # set timesteps @@ -150,7 +155,7 @@ class StableDiffusionPipeline(DiffusionPipeline): image = image.cpu().permute(0, 2, 3, 1).numpy() # run safety checker - safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(torch_device) + safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values) if output_type == "pil": diff --git a/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py b/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py index 25d85126fd..ebf95e6663 100644 --- a/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py +++ b/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +import warnings + import torch from tqdm.auto import tqdm @@ -27,18 +29,27 @@ class KarrasVePipeline(DiffusionPipeline): self.register_modules(unet=unet, scheduler=scheduler) @torch.no_grad() - def __call__(self, batch_size=1, num_inference_steps=50, generator=None, torch_device=None, output_type="pil"): - if torch_device is None: - torch_device = "cuda" if torch.cuda.is_available() else "cpu" + def __call__(self, batch_size=1, num_inference_steps=50, generator=None, output_type="pil", **kwargs): + if "torch_device" in kwargs: + device = kwargs.pop("torch_device") + warnings.warn( + "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." + " Consider using `pipe.to(torch_device)` instead." + ) + + # Set device as before (to be removed in 0.3.0) + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.to(device) img_size = self.unet.config.sample_size shape = (batch_size, 3, img_size, img_size) - model = self.unet.to(torch_device) + model = self.unet # sample x_0 ~ N(0, sigma_0^2 * I) sample = torch.randn(*shape) * self.scheduler.config.sigma_max - sample = sample.to(torch_device) + sample = sample.to(self.device) self.scheduler.set_timesteps(num_inference_steps)
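
---

Usage sketch (editor's note appended after the patch; not part of the diff): the snippet below exercises the API surface this change introduces — `DiffusionPipeline.to()`, the new `device` property, and the deprecated `torch_device` call argument. The model id `CompVis/ldm-text2im-large-256` and the dict-style `["sample"]` return value are assumptions drawn from the diffusers API of this era, not something this patch guarantees.

    import torch
    from diffusers import DiffusionPipeline

    # Example checkpoint (assumed); any registered pipeline behaves the same way.
    pipe = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")

    # New PyTorch-style placement: to() walks the modules registered in the
    # pipeline config and moves every torch.nn.Module it finds (unet, vqvae,
    # bert, ...). It returns self, so assignment-style chaining works, and
    # to(None) is a no-op.
    pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

    # The device property reports the device of the first registered
    # torch.nn.Module, falling back to CPU for module-free pipelines.
    print(pipe.device)

    # Deprecated path: still accepted during the transition, but emits a
    # warning and is slated for removal in v0.3.0:
    #   images = pipe(prompt="...", torch_device="cuda")["sample"]

    # Preferred path: place the pipeline once, then call it with no device
    # argument; intermediate tensors are created on pipe.device internally.
    images = pipe(prompt="a painting of a squirrel eating a burger")["sample"]

Giving `to()` PyTorch semantics (explicit placement, `None` as a no-op) is what makes the `torch_device` argument redundant in `__call__`; each pipeline keeps the old argument behind `**kwargs` only to preserve backward compatibility until v0.3.0.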