
Pipeline to device (#210)

* Implement `pipeline.to(device)`

* DiffusionPipeline.to() picks the best available device when passed None.

* Breaking change: torch_device removed from __call__

`pipeline.to()` now has PyTorch semantics.

* Use kwargs and deprecation notice

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Apply torch_device compatibility to all pipelines.

* style

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: anton-l <anton@huggingface.co>
Author: Pedro Cuenca
Date: 2022-08-19 18:39:08 +02:00
Committed by: GitHub
Parent: 89e9521048
Commit: 71ba8aec55
9 changed files with 146 additions and 59 deletions
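
In user code the migration looks like this (a minimal sketch; the model id and the dict-style `["sample"]` return match the diffusers API of this era but are assumptions, not part of the commit). Note the breaking change: `__call__` no longer silently moves the pipeline to GPU; device placement is now explicit and sticky, as in PyTorch:

import torch
from diffusers import DDPMPipeline

pipe = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")  # assumed checkpoint

# Deprecated (warns, removed in v0.3.0):
#   images = pipe(batch_size=1, torch_device="cuda")["sample"]

# New PyTorch-style semantics: place once, then call without a device argument.
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
images = pipe(batch_size=1)["sample"]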

Changes to DiffusionPipeline:

@@ -19,6 +19,8 @@ import inspect
 import os
 from typing import Optional, Union
 
+import torch
+
 from huggingface_hub import snapshot_download
 from PIL import Image
@@ -113,6 +115,26 @@ class DiffusionPipeline(ConfigMixin):
save_method = getattr(sub_model, save_method_name)
save_method(os.path.join(save_directory, pipeline_component_name))
def to(self, torch_device: Optional[Union[str, torch.device]] = None):
if torch_device is None:
return self
module_names, _ = self.extract_init_dict(dict(self.config))
for name in module_names.keys():
module = getattr(self, name)
if isinstance(module, torch.nn.Module):
module.to(torch_device)
return self
@property
def device(self) -> torch.device:
module_names, _ = self.extract_init_dict(dict(self.config))
for name in module_names.keys():
module = getattr(self, name)
if isinstance(module, torch.nn.Module):
return module.device
return torch.device("cpu")
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
r"""

Changes to DDIMPipeline:

@@ -14,6 +14,8 @@
 # limitations under the License.
 
+import warnings
+
 import torch
 
 from tqdm.auto import tqdm
@@ -28,21 +30,28 @@ class DDIMPipeline(DiffusionPipeline):
         self.register_modules(unet=unet, scheduler=scheduler)
 
     @torch.no_grad()
-    def __call__(
-        self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="pil"
-    ):
-        # eta corresponds to η in paper and should be between [0, 1]
-
-        if torch_device is None:
-            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
-
-        self.unet.to(torch_device)
+    def __call__(self, batch_size=1, generator=None, eta=0.0, num_inference_steps=50, output_type="pil", **kwargs):
+        if "torch_device" in kwargs:
+            device = kwargs.pop("torch_device")
+            warnings.warn(
+                "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+                " Consider using `pipe.to(torch_device)` instead."
+            )
+
+            # Set device as before (to be removed in 0.3.0)
+            if device is None:
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+            self.to(device)
+
+        # eta corresponds to η in paper and should be between [0, 1]
 
         # Sample gaussian noise to begin loop
         image = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
             generator=generator,
         )
-        image = image.to(torch_device)
+        image = image.to(self.device)
 
         # set step values
         self.scheduler.set_timesteps(num_inference_steps)
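
The same shim recurs in every pipeline below, so it is worth isolating. A self-contained sketch of the pattern (a toy function, not the real `__call__`) demonstrates that the old keyword still works but now warns and routes through `to()`:

import warnings
import torch

def toy_call(batch_size=1, **kwargs):
    # Mirror of the shim in the diff: pop the deprecated kwarg, warn, then
    # reproduce the old auto-device behavior until it is removed in v0.3.0.
    if "torch_device" in kwargs:
        device = kwargs.pop("torch_device")
        warnings.warn(
            "`torch_device` is deprecated as an input argument to `__call__`."
            " Consider using `pipe.to(torch_device)` instead."
        )
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"pipeline.to({device!r})")  # the real code calls self.to(device)
    return batch_size

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    toy_call(batch_size=1, torch_device="cpu")
assert any("deprecated" in str(w.message) for w in caught)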

Changes to DDPMPipeline:

@@ -14,6 +14,8 @@
 # limitations under the License.
 
+import warnings
+
 import torch
 
 from tqdm.auto import tqdm
@@ -28,18 +30,25 @@ class DDPMPipeline(DiffusionPipeline):
         self.register_modules(unet=unet, scheduler=scheduler)
 
     @torch.no_grad()
-    def __call__(self, batch_size=1, generator=None, torch_device=None, output_type="pil"):
-        if torch_device is None:
-            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
-
-        self.unet.to(torch_device)
+    def __call__(self, batch_size=1, generator=None, output_type="pil", **kwargs):
+        if "torch_device" in kwargs:
+            device = kwargs.pop("torch_device")
+            warnings.warn(
+                "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+                " Consider using `pipe.to(torch_device)` instead."
+            )
+
+            # Set device as before (to be removed in 0.3.0)
+            if device is None:
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+            self.to(device)
 
         # Sample gaussian noise to begin loop
         image = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
             generator=generator,
         )
-        image = image.to(torch_device)
+        image = image.to(self.device)
 
         # set step values
         self.scheduler.set_timesteps(1000)

Changes to LDMTextToImagePipeline:

@@ -1,4 +1,5 @@
 import inspect
+import warnings
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -31,13 +32,22 @@ class LDMTextToImagePipeline(DiffusionPipeline):
         guidance_scale: Optional[float] = 1.0,
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
-        torch_device: Optional[Union[str, torch.device]] = None,
         output_type: Optional[str] = "pil",
+        **kwargs,
     ):
         # eta corresponds to η in paper and should be between [0, 1]
-
-        if torch_device is None:
-            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+        if "torch_device" in kwargs:
+            device = kwargs.pop("torch_device")
+            warnings.warn(
+                "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+                " Consider using `pipe.to(torch_device)` instead."
+            )
+
+            # Set device as before (to be removed in 0.3.0)
+            if device is None:
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+            self.to(device)
 
         if isinstance(prompt, str):
             batch_size = 1
@@ -49,24 +59,20 @@ class LDMTextToImagePipeline(DiffusionPipeline):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
-        self.unet.to(torch_device)
-        self.vqvae.to(torch_device)
-        self.bert.to(torch_device)
-
         # get unconditional embeddings for classifier free guidance
         if guidance_scale != 1.0:
             uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
-            uncond_embeddings = self.bert(uncond_input.input_ids.to(torch_device))[0]
+            uncond_embeddings = self.bert(uncond_input.input_ids.to(self.device))[0]
 
         # get prompt text embeddings
         text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
-        text_embeddings = self.bert(text_input.input_ids.to(torch_device))[0]
+        text_embeddings = self.bert(text_input.input_ids.to(self.device))[0]
 
         latents = torch.randn(
             (batch_size, self.unet.in_channels, height // 8, width // 8),
             generator=generator,
         )
-        latents = latents.to(torch_device)
+        latents = latents.to(self.device)
 
         self.scheduler.set_timesteps(num_inference_steps)

Changes to LDMPipeline:

@@ -1,4 +1,5 @@
 import inspect
+import warnings
 
 import torch
@@ -14,22 +15,26 @@ class LDMPipeline(DiffusionPipeline):
         self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)
 
     @torch.no_grad()
-    def __call__(
-        self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="pil"
-    ):
+    def __call__(self, batch_size=1, generator=None, eta=0.0, num_inference_steps=50, output_type="pil", **kwargs):
         # eta corresponds to η in paper and should be between [0, 1]
-
-        if torch_device is None:
-            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
-
-        self.unet.to(torch_device)
-        self.vqvae.to(torch_device)
+        if "torch_device" in kwargs:
+            device = kwargs.pop("torch_device")
+            warnings.warn(
+                "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+                " Consider using `pipe.to(torch_device)` instead."
+            )
+
+            # Set device as before (to be removed in 0.3.0)
+            if device is None:
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+            self.to(device)
 
         latents = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
             generator=generator,
         )
-        latents = latents.to(torch_device)
+        latents = latents.to(self.device)
 
         self.scheduler.set_timesteps(num_inference_steps)

Changes to PNDMPipeline:

@@ -14,6 +14,8 @@
 # limitations under the License.
 
+import warnings
+
 import torch
 
 from tqdm.auto import tqdm
@@ -28,20 +30,28 @@ class PNDMPipeline(DiffusionPipeline):
         self.register_modules(unet=unet, scheduler=scheduler)
 
     @torch.no_grad()
-    def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50, output_type="pil"):
+    def __call__(self, batch_size=1, generator=None, num_inference_steps=50, output_type="pil", **kwargs):
         # For more information on the sampling method you can take a look at Algorithm 2 of
         # the official paper: https://arxiv.org/pdf/2202.09778.pdf
-        if torch_device is None:
-            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
-
-        self.unet.to(torch_device)
+        if "torch_device" in kwargs:
+            device = kwargs.pop("torch_device")
+            warnings.warn(
+                "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+                " Consider using `pipe.to(torch_device)` instead."
+            )
+
+            # Set device as before (to be removed in 0.3.0)
+            if device is None:
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+            self.to(device)
 
         # Sample gaussian noise to begin loop
         image = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
             generator=generator,
         )
-        image = image.to(torch_device)
+        image = image.to(self.device)
 
         self.scheduler.set_timesteps(num_inference_steps)
 
         for t in tqdm(self.scheduler.timesteps):

Changes to ScoreSdeVePipeline:

@@ -1,4 +1,6 @@
 #!/usr/bin/env python3
+import warnings
+
 import torch
 
 from diffusers import DiffusionPipeline
@@ -11,24 +13,32 @@ class ScoreSdeVePipeline(DiffusionPipeline):
         self.register_modules(unet=unet, scheduler=scheduler)
 
     @torch.no_grad()
-    def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, torch_device=None, output_type="pil"):
-        if torch_device is None:
-            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+    def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, output_type="pil", **kwargs):
+        if "torch_device" in kwargs:
+            device = kwargs.pop("torch_device")
+            warnings.warn(
+                "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+                " Consider using `pipe.to(torch_device)` instead."
+            )
+
+            # Set device as before (to be removed in 0.3.0)
+            if device is None:
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+            self.to(device)
 
         img_size = self.unet.config.sample_size
         shape = (batch_size, 3, img_size, img_size)
 
-        model = self.unet.to(torch_device)
+        model = self.unet
 
         sample = torch.randn(*shape) * self.scheduler.config.sigma_max
-        sample = sample.to(torch_device)
+        sample = sample.to(self.device)
 
         self.scheduler.set_timesteps(num_inference_steps)
         self.scheduler.set_sigmas(num_inference_steps)
 
         for i, t in tqdm(enumerate(self.scheduler.timesteps)):
-            sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=torch_device)
+            sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device)
 
             # correction step
             for _ in range(self.scheduler.correct_steps):

Changes to StableDiffusionPipeline:

@@ -1,4 +1,5 @@
 import inspect
+import warnings
 from typing import List, Optional, Union
 
 import torch
@@ -45,11 +46,20 @@ class StableDiffusionPipeline(DiffusionPipeline):
         guidance_scale: Optional[float] = 7.5,
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
-        torch_device: Optional[Union[str, torch.device]] = None,
         output_type: Optional[str] = "pil",
+        **kwargs,
     ):
-        if torch_device is None:
-            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+        if "torch_device" in kwargs:
+            device = kwargs.pop("torch_device")
+            warnings.warn(
+                "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+                " Consider using `pipe.to(torch_device)` instead."
+            )
+
+            # Set device as before (to be removed in 0.3.0)
+            if device is None:
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+            self.to(device)
 
         if isinstance(prompt, str):
             batch_size = 1
@@ -61,11 +71,6 @@ class StableDiffusionPipeline(DiffusionPipeline):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
-        self.unet.to(torch_device)
-        self.vae.to(torch_device)
-        self.text_encoder.to(torch_device)
-        self.safety_checker.to(torch_device)
-
         # get prompt text embeddings
         text_input = self.tokenizer(
             prompt,
@@ -74,7 +79,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
             truncation=True,
             return_tensors="pt",
         )
-        text_embeddings = self.text_encoder(text_input.input_ids.to(torch_device))[0]
+        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -86,7 +91,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
             uncond_input = self.tokenizer(
                 [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
             )
-            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(torch_device))[0]
+            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
 
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
@@ -97,7 +102,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
         latents = torch.randn(
             (batch_size, self.unet.in_channels, height // 8, width // 8),
             generator=generator,
-            device=torch_device,
+            device=self.device,
         )
 
         # set timesteps
@@ -150,7 +155,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
         image = image.cpu().permute(0, 2, 3, 1).numpy()
 
         # run safety checker
-        safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(torch_device)
+        safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
         image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
 
         if output_type == "pil":
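
For Stable Diffusion the migration is identical; a usage sketch in the style of the README of this era (the checkpoint id, `use_auth_token` flag, and `["sample"]` return are assumptions, not part of the diff):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", use_auth_token=True  # assumed checkpoint/flag
)
pipe = pipe.to("cuda")  # replaces the removed torch_device="cuda" call argument

with torch.autocast("cuda"):
    image = pipe("a photograph of an astronaut riding a horse")["sample"][0]
image.save("astronaut.png")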

Changes to KarrasVePipeline:

@@ -1,4 +1,6 @@
 #!/usr/bin/env python3
+import warnings
+
 import torch
 
 from tqdm.auto import tqdm
@@ -27,18 +29,27 @@ class KarrasVePipeline(DiffusionPipeline):
         self.register_modules(unet=unet, scheduler=scheduler)
 
     @torch.no_grad()
-    def __call__(self, batch_size=1, num_inference_steps=50, generator=None, torch_device=None, output_type="pil"):
-        if torch_device is None:
-            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+    def __call__(self, batch_size=1, num_inference_steps=50, generator=None, output_type="pil", **kwargs):
+        if "torch_device" in kwargs:
+            device = kwargs.pop("torch_device")
+            warnings.warn(
+                "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+                " Consider using `pipe.to(torch_device)` instead."
+            )
+
+            # Set device as before (to be removed in 0.3.0)
+            if device is None:
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+            self.to(device)
 
         img_size = self.unet.config.sample_size
         shape = (batch_size, 3, img_size, img_size)
 
-        model = self.unet.to(torch_device)
+        model = self.unet
 
         # sample x_0 ~ N(0, sigma_0^2 * I)
         sample = torch.randn(*shape) * self.scheduler.config.sigma_max
-        sample = sample.to(torch_device)
+        sample = sample.to(self.device)
 
         self.scheduler.set_timesteps(num_inference_steps)