mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Pipeline to device (#210)
* Implement `pipeline.to(device)` * DiffusionPipeline.to() decides best device on None. * Breaking change: torch_device removed from __call__ `pipeline.to()` now has PyTorch semantics. * Use kwargs and deprecation notice Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Apply torch_device compatibility to all pipelines. * style Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: anton-l <anton@huggingface.co>
This commit is contained in:
@@ -19,6 +19,8 @@ import inspect
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from PIL import Image
|
||||
|
||||
@@ -113,6 +115,26 @@ class DiffusionPipeline(ConfigMixin):
|
||||
save_method = getattr(sub_model, save_method_name)
|
||||
save_method(os.path.join(save_directory, pipeline_component_name))
|
||||
|
||||
def to(self, torch_device: Optional[Union[str, torch.device]] = None):
|
||||
if torch_device is None:
|
||||
return self
|
||||
|
||||
module_names, _ = self.extract_init_dict(dict(self.config))
|
||||
for name in module_names.keys():
|
||||
module = getattr(self, name)
|
||||
if isinstance(module, torch.nn.Module):
|
||||
module.to(torch_device)
|
||||
return self
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
module_names, _ = self.extract_init_dict(dict(self.config))
|
||||
for name in module_names.keys():
|
||||
module = getattr(self, name)
|
||||
if isinstance(module, torch.nn.Module):
|
||||
return module.device
|
||||
return torch.device("cpu")
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
||||
r"""
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
@@ -28,21 +30,28 @@ class DDIMPipeline(DiffusionPipeline):
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="pil"
|
||||
):
|
||||
# eta corresponds to η in paper and should be between [0, 1]
|
||||
if torch_device is None:
|
||||
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
def __call__(self, batch_size=1, generator=None, eta=0.0, num_inference_steps=50, output_type="pil", **kwargs):
|
||||
|
||||
self.unet.to(torch_device)
|
||||
if "torch_device" in kwargs:
|
||||
device = kwargs.pop("torch_device")
|
||||
warnings.warn(
|
||||
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
|
||||
" Consider using `pipe.to(torch_device)` instead."
|
||||
)
|
||||
|
||||
# Set device as before (to be removed in 0.3.0)
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.to(device)
|
||||
|
||||
# eta corresponds to η in paper and should be between [0, 1]
|
||||
|
||||
# Sample gaussian noise to begin loop
|
||||
image = torch.randn(
|
||||
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
|
||||
generator=generator,
|
||||
)
|
||||
image = image.to(torch_device)
|
||||
image = image.to(self.device)
|
||||
|
||||
# set step values
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
@@ -28,18 +30,25 @@ class DDPMPipeline(DiffusionPipeline):
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, batch_size=1, generator=None, torch_device=None, output_type="pil"):
|
||||
if torch_device is None:
|
||||
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
def __call__(self, batch_size=1, generator=None, output_type="pil", **kwargs):
|
||||
if "torch_device" in kwargs:
|
||||
device = kwargs.pop("torch_device")
|
||||
warnings.warn(
|
||||
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
|
||||
" Consider using `pipe.to(torch_device)` instead."
|
||||
)
|
||||
|
||||
self.unet.to(torch_device)
|
||||
# Set device as before (to be removed in 0.3.0)
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.to(device)
|
||||
|
||||
# Sample gaussian noise to begin loop
|
||||
image = torch.randn(
|
||||
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
|
||||
generator=generator,
|
||||
)
|
||||
image = image.to(torch_device)
|
||||
image = image.to(self.device)
|
||||
|
||||
# set step values
|
||||
self.scheduler.set_timesteps(1000)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import inspect
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@@ -31,13 +32,22 @@ class LDMTextToImagePipeline(DiffusionPipeline):
|
||||
guidance_scale: Optional[float] = 1.0,
|
||||
eta: Optional[float] = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
torch_device: Optional[Union[str, torch.device]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
**kwargs,
|
||||
):
|
||||
# eta corresponds to η in paper and should be between [0, 1]
|
||||
|
||||
if torch_device is None:
|
||||
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
if "torch_device" in kwargs:
|
||||
device = kwargs.pop("torch_device")
|
||||
warnings.warn(
|
||||
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
|
||||
" Consider using `pipe.to(torch_device)` instead."
|
||||
)
|
||||
|
||||
# Set device as before (to be removed in 0.3.0)
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.to(device)
|
||||
|
||||
if isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
@@ -49,24 +59,20 @@ class LDMTextToImagePipeline(DiffusionPipeline):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
self.unet.to(torch_device)
|
||||
self.vqvae.to(torch_device)
|
||||
self.bert.to(torch_device)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if guidance_scale != 1.0:
|
||||
uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
|
||||
uncond_embeddings = self.bert(uncond_input.input_ids.to(torch_device))[0]
|
||||
uncond_embeddings = self.bert(uncond_input.input_ids.to(self.device))[0]
|
||||
|
||||
# get prompt text embeddings
|
||||
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
|
||||
text_embeddings = self.bert(text_input.input_ids.to(torch_device))[0]
|
||||
text_embeddings = self.bert(text_input.input_ids.to(self.device))[0]
|
||||
|
||||
latents = torch.randn(
|
||||
(batch_size, self.unet.in_channels, height // 8, width // 8),
|
||||
generator=generator,
|
||||
)
|
||||
latents = latents.to(torch_device)
|
||||
latents = latents.to(self.device)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import inspect
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
|
||||
@@ -14,22 +15,26 @@ class LDMPipeline(DiffusionPipeline):
|
||||
self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="pil"
|
||||
):
|
||||
def __call__(self, batch_size=1, generator=None, eta=0.0, num_inference_steps=50, output_type="pil", **kwargs):
|
||||
# eta corresponds to η in paper and should be between [0, 1]
|
||||
|
||||
if torch_device is None:
|
||||
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
if "torch_device" in kwargs:
|
||||
device = kwargs.pop("torch_device")
|
||||
warnings.warn(
|
||||
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
|
||||
" Consider using `pipe.to(torch_device)` instead."
|
||||
)
|
||||
|
||||
self.unet.to(torch_device)
|
||||
self.vqvae.to(torch_device)
|
||||
# Set device as before (to be removed in 0.3.0)
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.to(device)
|
||||
|
||||
latents = torch.randn(
|
||||
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
|
||||
generator=generator,
|
||||
)
|
||||
latents = latents.to(torch_device)
|
||||
latents = latents.to(self.device)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
@@ -28,20 +30,28 @@ class PNDMPipeline(DiffusionPipeline):
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50, output_type="pil"):
|
||||
def __call__(self, batch_size=1, generator=None, num_inference_steps=50, output_type="pil", **kwargs):
|
||||
# For more information on the sampling method you can take a look at Algorithm 2 of
|
||||
# the official paper: https://arxiv.org/pdf/2202.09778.pdf
|
||||
if torch_device is None:
|
||||
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
self.unet.to(torch_device)
|
||||
if "torch_device" in kwargs:
|
||||
device = kwargs.pop("torch_device")
|
||||
warnings.warn(
|
||||
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
|
||||
" Consider using `pipe.to(torch_device)` instead."
|
||||
)
|
||||
|
||||
# Set device as before (to be removed in 0.3.0)
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.to(device)
|
||||
|
||||
# Sample gaussian noise to begin loop
|
||||
image = torch.randn(
|
||||
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
|
||||
generator=generator,
|
||||
)
|
||||
image = image.to(torch_device)
|
||||
image = image.to(self.device)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
for t in tqdm(self.scheduler.timesteps):
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import DiffusionPipeline
|
||||
@@ -11,24 +13,32 @@ class ScoreSdeVePipeline(DiffusionPipeline):
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, torch_device=None, output_type="pil"):
|
||||
def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, output_type="pil", **kwargs):
|
||||
if "torch_device" in kwargs:
|
||||
device = kwargs.pop("torch_device")
|
||||
warnings.warn(
|
||||
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
|
||||
" Consider using `pipe.to(torch_device)` instead."
|
||||
)
|
||||
|
||||
if torch_device is None:
|
||||
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
# Set device as before (to be removed in 0.3.0)
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.to(device)
|
||||
|
||||
img_size = self.unet.config.sample_size
|
||||
shape = (batch_size, 3, img_size, img_size)
|
||||
|
||||
model = self.unet.to(torch_device)
|
||||
model = self.unet
|
||||
|
||||
sample = torch.randn(*shape) * self.scheduler.config.sigma_max
|
||||
sample = sample.to(torch_device)
|
||||
sample = sample.to(self.device)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
self.scheduler.set_sigmas(num_inference_steps)
|
||||
|
||||
for i, t in tqdm(enumerate(self.scheduler.timesteps)):
|
||||
sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=torch_device)
|
||||
sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device)
|
||||
|
||||
# correction step
|
||||
for _ in range(self.scheduler.correct_steps):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import inspect
|
||||
import warnings
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
@@ -45,11 +46,20 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
guidance_scale: Optional[float] = 7.5,
|
||||
eta: Optional[float] = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
torch_device: Optional[Union[str, torch.device]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
**kwargs,
|
||||
):
|
||||
if torch_device is None:
|
||||
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
if "torch_device" in kwargs:
|
||||
device = kwargs.pop("torch_device")
|
||||
warnings.warn(
|
||||
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
|
||||
" Consider using `pipe.to(torch_device)` instead."
|
||||
)
|
||||
|
||||
# Set device as before (to be removed in 0.3.0)
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.to(device)
|
||||
|
||||
if isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
@@ -61,11 +71,6 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
self.unet.to(torch_device)
|
||||
self.vae.to(torch_device)
|
||||
self.text_encoder.to(torch_device)
|
||||
self.safety_checker.to(torch_device)
|
||||
|
||||
# get prompt text embeddings
|
||||
text_input = self.tokenizer(
|
||||
prompt,
|
||||
@@ -74,7 +79,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_embeddings = self.text_encoder(text_input.input_ids.to(torch_device))[0]
|
||||
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
@@ -86,7 +91,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
uncond_input = self.tokenizer(
|
||||
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
|
||||
)
|
||||
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(torch_device))[0]
|
||||
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
@@ -97,7 +102,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
latents = torch.randn(
|
||||
(batch_size, self.unet.in_channels, height // 8, width // 8),
|
||||
generator=generator,
|
||||
device=torch_device,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
# set timesteps
|
||||
@@ -150,7 +155,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||
|
||||
# run safety checker
|
||||
safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(torch_device)
|
||||
safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
|
||||
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
|
||||
|
||||
if output_type == "pil":
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
@@ -27,18 +29,27 @@ class KarrasVePipeline(DiffusionPipeline):
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, batch_size=1, num_inference_steps=50, generator=None, torch_device=None, output_type="pil"):
|
||||
if torch_device is None:
|
||||
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
def __call__(self, batch_size=1, num_inference_steps=50, generator=None, output_type="pil", **kwargs):
|
||||
if "torch_device" in kwargs:
|
||||
device = kwargs.pop("torch_device")
|
||||
warnings.warn(
|
||||
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
|
||||
" Consider using `pipe.to(torch_device)` instead."
|
||||
)
|
||||
|
||||
# Set device as before (to be removed in 0.3.0)
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.to(device)
|
||||
|
||||
img_size = self.unet.config.sample_size
|
||||
shape = (batch_size, 3, img_size, img_size)
|
||||
|
||||
model = self.unet.to(torch_device)
|
||||
model = self.unet
|
||||
|
||||
# sample x_0 ~ N(0, sigma_0^2 * I)
|
||||
sample = torch.randn(*shape) * self.scheduler.config.sigma_max
|
||||
sample = sample.to(torch_device)
|
||||
sample = sample.to(self.device)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user