
Fix config prints and save, load of pipelines (#2849)

* [Config] Fix config prints and save, load

* Only use potential nn.Modules for dtype and device

* Correct vae image processor

* make sure in_channels is not accessed directly

* make sure in channels is only accessed via config

* Make sure schedulers only access config attributes

* Make sure to access config in SAG

* Fix vae processor and make style

* add tests

* uP

* make style

* Fix more naming issues

* Final fix with vae config

* change more
Patrick von Platen
2023-04-11 13:35:42 +02:00
committed by GitHub
parent 8369196703
commit 8b451eb63b
66 changed files with 221 additions and 105 deletions
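
The pattern enforced across all 66 files: hyperparameters registered through `register_to_config` are read from the `.config` namespace, while direct attribute access stays alive behind a deprecation warning. A minimal sketch of the before/after (the model choice is illustrative, not part of the diff):

```python
from diffusers import UNet2DModel

unet = UNet2DModel()  # any ConfigMixin-backed model behaves the same way

# Deprecated after this commit: reads the hyperparameter straight off the module
channels = unet.in_channels  # still works, but emits a deprecation warning

# Preferred: hyperparameters live on the model's config
channels = unet.config.in_channels
```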

View File

@@ -344,7 +344,7 @@ Now you can wrap all these components together in a training loop with 🤗 Acce
... # Sample a random timestep for each image
... timesteps = torch.randint(
-... 0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device
+... 0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device
... ).long()
... # Add noise to the clean images according to the noise magnitude at each timestep

View File

@@ -62,7 +62,7 @@ class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
def __call__(self):
image = torch.randn(
-(1, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+(1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
)
timestep = 1
@@ -108,7 +108,7 @@ class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
def __call__(self):
image = torch.randn(
-(1, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+(1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
)
timestep = 1

View File

@@ -89,7 +89,9 @@ class MyPipeline(DiffusionPipeline):
@torch.no_grad()
def __call__(self, batch_size: int = 1, num_inference_steps: int = 50):
# Sample gaussian noise to begin loop
-image = torch.randn((batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size))
+image = torch.randn(
+    (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size)
+)
image = image.to(self.device)

View File

@@ -238,7 +238,7 @@ class BitDiffusion(DiffusionPipeline):
**kwargs,
) -> Union[Tuple, ImagePipelineOutput]:
latents = torch.randn(
-(batch_size, self.unet.in_channels, height, width),
+(batch_size, self.unet.config.in_channels, height, width),
generator=generator,
)
latents = decimal_to_bits(latents) * self.bit_scale

View File

@@ -254,7 +254,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
# Unlike in other pipelines, latents need to be generated in the target device
# for 1-to-1 results reproducibility with the CompVis implementation.
# However this currently doesn't work in `mps`.
-latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
latents_dtype = text_embeddings.dtype
if latents is None:
if self.device.type == "mps":

View File

@@ -414,7 +414,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
# Unlike in other pipelines, latents need to be generated in the target device
# for 1-to-1 results reproducibility with the CompVis implementation.
# However this currently doesn't work in `mps`.
-latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
latents_dtype = text_embeddings.dtype
if latents is None:
if self.device.type == "mps":

View File

@@ -513,7 +513,7 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
-num_channels_latents = self.unet.in_channels
+num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,

View File

@@ -424,7 +424,7 @@ class ImagicStableDiffusionPipeline(DiffusionPipeline):
# Unlike in other pipelines, latents need to be generated in the target device
# for 1-to-1 results reproducibility with the CompVis implementation.
# However this currently doesn't work in `mps`.
-latents_shape = (1, self.unet.in_channels, height // 8, width // 8)
+latents_shape = (1, self.unet.config.in_channels, height // 8, width // 8)
latents_dtype = text_embeddings.dtype
if self.device.type == "mps":
# randn does not exist on mps

View File

@@ -320,7 +320,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
# Unlike in other pipelines, latents need to be generated in the target device
# for 1-to-1 results reproducibility with the CompVis implementation.
# However this currently doesn't work in `mps`.
-latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
latents_dtype = text_embeddings.dtype
if latents is None:
if self.device.type == "mps":
@@ -416,7 +416,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
def get_noise(self, seed, dtype=torch.float32, height=512, width=512):
"""Takes in random seed and returns corresponding noise vector"""
return torch.randn(
-(1, self.unet.in_channels, height // 8, width // 8),
+(1, self.unet.config.in_channels, height // 8, width // 8),
generator=torch.Generator(device=self.device).manual_seed(seed),
device=self.device,
dtype=dtype,

View File

@@ -627,7 +627,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
if image is None:
shape = (
batch_size,
-self.unet.in_channels,
+self.unet.config.in_channels,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)

View File

@@ -486,7 +486,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
self.__init__additional__()
def __init__additional__(self):
-self.unet_in_channels = 4
+self.unet.config.in_channels = 4
self.vae_scale_factor = 8
def _encode_prompt(
@@ -621,7 +621,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
if image is None:
shape = (
batch_size,
-self.unet_in_channels,
+self.unet.config.in_channels,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)

View File

@@ -93,7 +93,7 @@ class MagicMixPipeline(DiffusionPipeline):
torch.manual_seed(seed)
noise = torch.randn(
-(1, self.unet.in_channels, height // 8, width // 8),
+(1, self.unet.config.in_channels, height // 8, width // 8),
).to(self.device)
latents = self.scheduler.add_noise(

View File

@@ -355,7 +355,7 @@ class MultilingualStableDiffusion(DiffusionPipeline):
# Unlike in other pipelines, latents need to be generated in the target device
# for 1-to-1 results reproducibility with the CompVis implementation.
# However this currently doesn't work in `mps`.
-latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
latents_dtype = text_embeddings.dtype
if latents is None:
if self.device.type == "mps":

View File

@@ -433,7 +433,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
sigmas = sigmas.to(text_embeddings.dtype)
# 5. Prepare latent variables
-num_channels_latents = self.unet.in_channels
+num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,

View File

@@ -262,8 +262,8 @@ class SeedResizeStableDiffusionPipeline(DiffusionPipeline):
# Unlike in other pipelines, latents need to be generated in the target device
# for 1-to-1 results reproducibility with the CompVis implementation.
# However this currently doesn't work in `mps`.
-latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
-latents_shape_reference = (batch_size * num_images_per_prompt, self.unet.in_channels, 64, 64)
+latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
+latents_shape_reference = (batch_size * num_images_per_prompt, self.unet.config.in_channels, 64, 64)
latents_dtype = text_embeddings.dtype
if latents is None:
if self.device.type == "mps":

View File

@@ -190,7 +190,7 @@ class SpeechToImagePipeline(DiffusionPipeline):
# Unlike in other pipelines, latents need to be generated in the target device
# for 1-to-1 results reproducibility with the CompVis implementation.
# However this currently doesn't work in `mps`.
-latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
latents_dtype = text_embeddings.dtype
if latents is None:
if self.device.type == "mps":

View File

@@ -337,7 +337,7 @@ class WildcardStableDiffusionPipeline(DiffusionPipeline):
# Unlike in other pipelines, latents need to be generated in the target device
# for 1-to-1 results reproducibility with the CompVis implementation.
# However this currently doesn't work in `mps`.
-latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
latents_dtype = text_embeddings.dtype
if latents is None:
if self.device.type == "mps":

View File

@@ -794,7 +794,7 @@ def main():
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
-timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
+timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep

View File

@@ -794,7 +794,7 @@ def main():
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
-timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
+timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep

View File

@@ -641,7 +641,7 @@ def main():
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
-timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
+timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep

View File

@@ -804,7 +804,7 @@ def main():
bsz = latents.shape[0]
# Sample a random timestep for each image
-timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
+timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep

View File

@@ -707,7 +707,7 @@ def main():
bsz = latents.shape[0]
# Sample a random timestep for each image
-timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
+timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep

View File

@@ -109,13 +109,6 @@ class ConfigMixin:
# TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
# or solve in a more general way.
kwargs.pop("kwargs", None)
-for key, value in kwargs.items():
-    try:
-        setattr(self, key, value)
-    except AttributeError as err:
-        logger.error(f"Can't set {key} with value {value} for {self}")
-        raise err
if not hasattr(self, "_internal_dict"):
internal_dict = kwargs
else:
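
With the `setattr` loop removed, `register_to_config` kwargs are stored only in the internal config dict, so reads must go through `.config`. A minimal sketch of the resulting behavior (the class and default value are hypothetical):

```python
from diffusers.configuration_utils import ConfigMixin, register_to_config

class MyScheduler(ConfigMixin):
    config_name = "scheduler_config.json"

    @register_to_config
    def __init__(self, num_train_timesteps: int = 1000):
        pass

s = MyScheduler()
print(s.config.num_train_timesteps)  # 1000 -- the supported access path
# `s.num_train_timesteps` only worked through the setattr loop deleted above
```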

View File

@@ -99,8 +99,8 @@ class VaeImageProcessor(ConfigMixin):
Resize a PIL image. Both height and width will be downscaled to the next integer multiple of `vae_scale_factor`
"""
w, h = images.size
-w, h = (x - x % self.vae_scale_factor for x in (w, h)) # resize to integer multiple of vae_scale_factor
-images = images.resize((w, h), resample=PIL_INTERPOLATION[self.resample])
+w, h = (x - x % self.config.vae_scale_factor for x in (w, h)) # resize to integer multiple of vae_scale_factor
+images = images.resize((w, h), resample=PIL_INTERPOLATION[self.config.resample])
return images
def preprocess(
@@ -119,7 +119,7 @@ class VaeImageProcessor(ConfigMixin):
)
if isinstance(image[0], PIL.Image.Image):
-if self.do_resize:
+if self.config.do_resize:
image = [self.resize(i) for i in image]
image = [np.array(i).astype(np.float32) / 255.0 for i in image]
image = np.stack(image, axis=0) # to np
@@ -129,23 +129,27 @@ class VaeImageProcessor(ConfigMixin):
image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
image = self.numpy_to_pt(image)
_, _, height, width = image.shape
-if self.do_resize and (height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0):
+if self.config.do_resize and (
+    height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
+):
raise ValueError(
f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.vae_scale_factor}"
f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.config.vae_scale_factor}"
f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
_, _, height, width = image.shape
-if self.do_resize and (height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0):
+if self.config.do_resize and (
+    height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
+):
raise ValueError(
f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.vae_scale_factor}"
f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.config.vae_scale_factor}"
f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
)
# expected range [0,1], normalize to [-1,1]
-do_normalize = self.do_normalize
+do_normalize = self.config.do_normalize
if image.min() < 0:
warnings.warn(
"Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "

View File

@@ -18,7 +18,7 @@ import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, apply_forward_hook
+from ..utils import BaseOutput, apply_forward_hook, deprecate
from .modeling_utils import ModelMixin
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
@@ -120,9 +120,19 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
if isinstance(self.config.sample_size, (list, tuple))
else self.config.sample_size
)
-self.tile_latent_min_size = int(sample_size / (2 ** (len(self.block_out_channels) - 1)))
+self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25
+@property
+def block_out_channels(self):
+    deprecate(
+        "block_out_channels",
+        "1.0.0",
+        "Accessing `block_out_channels` directly via vae.block_out_channels is deprecated. Please use `vae.config.block_out_channels` instead",
+        standard_warn=False,
+    )
+    return self.config.block_out_channels
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (Encoder, Decoder)):
module.gradient_checkpointing = value
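
The same property-plus-`deprecate` pattern recurs for every model touched by this commit: the old attribute keeps working but routes through the config and warns. A stripped-down sketch of the mechanism outside diffusers (the `Model` class here is hypothetical):

```python
from diffusers.utils import deprecate

class Model:
    def __init__(self):
        # stand-in for the real FrozenDict-backed config
        self.config = type("Config", (), {"in_channels": 4})()

    @property
    def in_channels(self):
        # warn on the legacy access path, then fall through to the config value
        deprecate(
            "in_channels",
            "1.0.0",
            "Accessing `in_channels` directly is deprecated. Use `config.in_channels` instead",
            standard_warn=False,
        )
        return self.config.in_channels

m = Model()
assert m.in_channels == m.config.in_channels  # old path still works, with a warning
```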

View File

@@ -19,7 +19,7 @@ import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..utils import BaseOutput, deprecate
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
@@ -190,6 +190,16 @@ class UNet1DModel(ModelMixin, ConfigMixin):
fc_dim=block_out_channels[-1] // 4,
)
+@property
+def in_channels(self):
+    deprecate(
+        "in_channels",
+        "1.0.0",
+        "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead",
+        standard_warn=False,
+    )
+    return self.config.in_channels
def forward(
self,
sample: torch.FloatTensor,

View File

@@ -18,7 +18,7 @@ import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..utils import BaseOutput, deprecate
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
@@ -215,6 +215,16 @@ class UNet2DModel(ModelMixin, ConfigMixin):
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
+@property
+def in_channels(self):
+    deprecate(
+        "in_channels",
+        "1.0.0",
+        "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead",
+        standard_warn=False,
+    )
+    return self.config.in_channels
def forward(
self,
sample: torch.FloatTensor,

View File

@@ -20,7 +20,7 @@ import torch.utils.checkpoint
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin
-from ..utils import BaseOutput, logging
+from ..utils import BaseOutput, deprecate, logging
from .attention_processor import AttentionProcessor, AttnProcessor
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
@@ -412,6 +412,16 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
)
+@property
+def in_channels(self):
+    deprecate(
+        "in_channels",
+        "1.0.0",
+        "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead",
+        standard_warn=False,
+    )
+    return self.config.in_channels
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""

View File

@@ -646,7 +646,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
-num_channels_latents = self.unet.in_channels
+num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,

View File

@@ -121,17 +121,17 @@ class AudioDiffusionPipeline(DiffusionPipeline):
self.scheduler.set_timesteps(steps)
step_generator = step_generator or generator
# For backwards compatibility
-if type(self.unet.sample_size) == int:
-    self.unet.sample_size = (self.unet.sample_size, self.unet.sample_size)
+if type(self.unet.config.sample_size) == int:
+    self.unet.config.sample_size = (self.unet.config.sample_size, self.unet.config.sample_size)
input_dims = self.get_input_dims()
self.mel.set_resolution(x_res=input_dims[1], y_res=input_dims[0])
if noise is None:
noise = randn_tensor(
(
batch_size,
-self.unet.in_channels,
-self.unet.sample_size[0],
-self.unet.sample_size[1],
+self.unet.config.in_channels,
+self.unet.config.sample_size[0],
+self.unet.config.sample_size[1],
),
generator=generator,
device=self.device,
@@ -158,7 +158,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
images[0, 0] = self.scheduler.add_noise(input_images, noise, self.scheduler.timesteps[start_step - 1])
pixels_per_second = (
-self.unet.sample_size[1] * self.mel.get_sample_rate() / self.mel.x_res / self.mel.hop_length
+self.unet.config.sample_size[1] * self.mel.get_sample_rate() / self.mel.x_res / self.mel.hop_length
)
mask_start = int(mask_start_secs * pixels_per_second)
mask_end = int(mask_end_secs * pixels_per_second)

View File

@@ -540,7 +540,7 @@ class AudioLDMPipeline(DiffusionPipeline):
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
-num_channels_latents = self.unet.in_channels
+num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_waveforms_per_prompt,
num_channels_latents,

View File

@@ -61,7 +61,7 @@ class DanceDiffusionPipeline(DiffusionPipeline):
to make generation deterministic.
audio_length_in_s (`float`, *optional*, defaults to `self.unet.config.sample_size/self.unet.config.sample_rate`):
The length of the generated audio sample in seconds. Note that the output of the pipeline, *i.e.*
-`sample_size`, will be `audio_length_in_s` * `self.unet.sample_rate`.
+`sample_size`, will be `audio_length_in_s` * `self.unet.config.sample_rate`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple.
@@ -73,27 +73,29 @@ class DanceDiffusionPipeline(DiffusionPipeline):
if audio_length_in_s is None:
audio_length_in_s = self.unet.config.sample_size / self.unet.config.sample_rate
-sample_size = audio_length_in_s * self.unet.sample_rate
+sample_size = audio_length_in_s * self.unet.config.sample_rate
down_scale_factor = 2 ** len(self.unet.up_blocks)
if sample_size < 3 * down_scale_factor:
raise ValueError(
f"{audio_length_in_s} is too small. Make sure it's bigger or equal to"
f" {3 * down_scale_factor / self.unet.sample_rate}."
f" {3 * down_scale_factor / self.unet.config.sample_rate}."
)
original_sample_size = int(sample_size)
if sample_size % down_scale_factor != 0:
-sample_size = ((audio_length_in_s * self.unet.sample_rate) // down_scale_factor + 1) * down_scale_factor
+sample_size = (
+    (audio_length_in_s * self.unet.config.sample_rate) // down_scale_factor + 1
+) * down_scale_factor
logger.info(
f"{audio_length_in_s} is increased to {sample_size / self.unet.sample_rate} so that it can be handled"
f" by the model. It will be cut to {original_sample_size / self.unet.sample_rate} after the denoising"
f"{audio_length_in_s} is increased to {sample_size / self.unet.config.sample_rate} so that it can be handled"
f" by the model. It will be cut to {original_sample_size / self.unet.config.sample_rate} after the denoising"
" process."
)
sample_size = int(sample_size)
dtype = next(iter(self.unet.parameters())).dtype
-shape = (batch_size, self.unet.in_channels, sample_size)
+shape = (batch_size, self.unet.config.in_channels, sample_size)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"

View File

@@ -79,10 +79,15 @@ class DDIMPipeline(DiffusionPipeline):
"""
# Sample gaussian noise to begin loop
-if isinstance(self.unet.sample_size, int):
-    image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+if isinstance(self.unet.config.sample_size, int):
+    image_shape = (
+        batch_size,
+        self.unet.config.in_channels,
+        self.unet.config.sample_size,
+        self.unet.config.sample_size,
+    )
else:
-    image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size)
+    image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(

View File

@@ -67,10 +67,15 @@ class DDPMPipeline(DiffusionPipeline):
True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
"""
# Sample gaussian noise to begin loop
-if isinstance(self.unet.sample_size, int):
-    image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+if isinstance(self.unet.config.sample_size, int):
+    image_shape = (
+        batch_size,
+        self.unet.config.in_channels,
+        self.unet.config.sample_size,
+        self.unet.config.sample_size,
+    )
else:
-    image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size)
+    image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)
if self.device.type == "mps":
# randn does not work reproducibly on mps

View File

@@ -135,7 +135,7 @@ class LDMTextToImagePipeline(DiffusionPipeline):
prompt_embeds = self.bert(text_input.input_ids.to(self.device))[0]
# get the initial random noise unless the user supplied it
-latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"

View File

@@ -112,7 +112,7 @@ class LDMSuperResolutionPipeline(DiffusionPipeline):
height, width = image.shape[-2:]
# in_channels should be 6: 3 for latents, 3 for low resolution image
-latents_shape = (batch_size, self.unet.in_channels // 2, height, width)
+latents_shape = (batch_size, self.unet.config.in_channels // 2, height, width)
latents_dtype = next(self.unet.parameters()).dtype
latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)

View File

@@ -73,7 +73,7 @@ class LDMPipeline(DiffusionPipeline):
"""
latents = randn_tensor(
-(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+(batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
generator=generator,
)
latents = latents.to(self.device)

View File

@@ -506,6 +506,21 @@ class DiffusionPipeline(ConfigMixin):
# set models
setattr(self, name, module)
+def __setattr__(self, name: str, value: Any):
+    if hasattr(self, name) and hasattr(self.config, name):
+        # We need to overwrite the config if name exists in config
+        if isinstance(getattr(self.config, name), (tuple, list)):
+            if self.config[name][0] is not None:
+                class_library_tuple = (value.__module__.split(".")[0], value.__class__.__name__)
+            else:
+                class_library_tuple = (None, None)
+            self.register_to_config(**{name: class_library_tuple})
+        else:
+            self.register_to_config(**{name: value})
+    super().__setattr__(name, value)
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
@@ -619,9 +634,11 @@ class DiffusionPipeline(ConfigMixin):
f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
)
-module_names, _, _ = self.extract_init_dict(dict(self.config))
+module_names, _ = self._get_signature_keys(self)
+module_names = [m for m in module_names if hasattr(self, m)]
is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
-for name in module_names.keys():
+for name in module_names:
module = getattr(self, name)
if isinstance(module, torch.nn.Module):
module.to(torch_device, torch_dtype)
@@ -646,8 +663,10 @@ class DiffusionPipeline(ConfigMixin):
Returns:
`torch.device`: The torch device on which the pipeline is located.
"""
-module_names, _, _ = self.extract_init_dict(dict(self.config))
-for name in module_names.keys():
+module_names, _ = self._get_signature_keys(self)
+module_names = [m for m in module_names if hasattr(self, m)]
+for name in module_names:
module = getattr(self, name)
if isinstance(module, torch.nn.Module):
return module.device
@@ -1420,6 +1439,8 @@ class DiffusionPipeline(ConfigMixin):
fn_recursive_set_mem_eff(child)
module_names, _, _ = self.extract_init_dict(dict(self.config))
+module_names = [m for m in module_names if hasattr(self, m)]
for module_name in module_names:
module = getattr(self, module_name)
if isinstance(module, torch.nn.Module):
@@ -1451,6 +1472,8 @@ class DiffusionPipeline(ConfigMixin):
def set_attention_slice(self, slice_size: Optional[int]):
module_names, _, _ = self.extract_init_dict(dict(self.config))
+module_names = [m for m in module_names if hasattr(self, m)]
for module_name in module_names:
module = getattr(self, module_name)
if isinstance(module, torch.nn.Module) and hasattr(module, "set_attention_slice"):
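
The new `__setattr__` hook means a component swap such as `pipe.scheduler = ...` is mirrored into the pipeline config, so a subsequent `save_pretrained` records the new class in `model_index.json`. A short usage sketch (the tiny checkpoint is the one the tests below use):

```python
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

# the (library, class) tuple in the config now reflects the swap
print(pipe.config.scheduler)  # ('diffusers', 'DPMSolverMultistepScheduler')
```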

View File

@@ -77,7 +77,7 @@ class PNDMPipeline(DiffusionPipeline):
# Sample gaussian noise to begin loop
image = randn_tensor(
-(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+(batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
generator=generator,
device=self.device,
)

View File

@@ -476,7 +476,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline):
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
-num_channels_latents = self.unet.in_channels
+num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,

View File

@@ -247,7 +247,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
latents_shape = (
batch_size,
-self.unet.in_channels,
+self.unet.config.in_channels,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)

View File

@@ -283,7 +283,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
latents_shape = (
batch_size,
-self.unet.in_channels,
+self.unet.config.in_channels,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)

View File

@@ -268,7 +268,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
latents_shape = (
batch_size,
-self.unet.in_channels,
+self.unet.config.in_channels,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)

View File

@@ -649,7 +649,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
-num_channels_latents = self.unet.in_channels
+num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,

View File

@@ -855,7 +855,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
-num_channels_latents = self.unet.in_channels
+num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,

View File

@@ -910,7 +910,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
timesteps = self.scheduler.timesteps
# 6. Prepare latent variables
-num_channels_latents = self.unet.in_channels
+num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,

View File

@@ -358,7 +358,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
-num_channels_latents = self.unet.in_channels
+num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,

View File

@@ -561,7 +561,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
sigmas = sigmas.to(prompt_embeds.dtype)
# 6. Prepare latent variables
-num_channels_latents = self.unet.in_channels
+num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,

View File

@@ -722,7 +722,7 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
-num_channels_latents = self.unet.in_channels
+num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,

View File

@@ -586,7 +586,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
-num_channels_latents = self.unet.in_channels
+num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,

View File

@@ -929,7 +929,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
# 5. Generate the inverted noise from the input image or any other image
# generated from the input prompt.
-num_channels_latents = self.unet.in_channels
+num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,

View File

@@ -595,7 +595,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
-num_channels_latents = self.unet.in_channels
+num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
@@ -701,7 +701,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
# Same masking process as in SAG paper: https://arxiv.org/pdf/2210.00939.pdf
bh, hw1, hw2 = attn_map.shape
b, latent_channel, latent_h, latent_w = original_latents.shape
-h = self.unet.attention_head_dim
+h = self.unet.config.attention_head_dim
if isinstance(h, list):
h = h[-1]

View File

@@ -877,7 +877,7 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
timesteps = self.scheduler.timesteps
# 11. Prepare latent variables
-num_channels_latents = self.unet.in_channels
+num_channels_latents = self.unet.config.in_channels
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
latents = self.prepare_latents(
shape=shape,

View File

@@ -772,7 +772,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
timesteps = self.scheduler.timesteps
# 6. Prepare latent variables
-num_channels_latents = self.unet.in_channels
+num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size=batch_size,
num_channels_latents=num_channels_latents,

View File

@@ -623,7 +623,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
-num_channels_latents = self.unet.in_channels
+num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,

View File

@@ -606,7 +606,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
-num_channels_latents = self.unet.in_channels
+num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,

View File

@@ -12,7 +12,7 @@ from ...models.dual_transformer_2d import DualTransformer2DModel
from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from ...models.transformer_2d import Transformer2DModel
from ...models.unet_2d_condition import UNet2DConditionOutput
-from ...utils import logging
+from ...utils import deprecate, logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -504,6 +504,19 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
)
+@property
+def in_channels(self):
+    deprecate(
+        "in_channels",
+        "1.0.0",
+        (
+            "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use"
+            " `unet.config.in_channels` instead"
+        ),
+        standard_warn=False,
+    )
+    return self.config.in_channels
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""

View File

@@ -22,7 +22,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, randn_tensor
+from ..utils import BaseOutput, deprecate, randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
@@ -167,6 +167,16 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
self.variance_type = variance_type
+@property
+def num_train_timesteps(self):
+    deprecate(
+        "num_train_timesteps",
+        "1.0.0",
+        "Accessing `num_train_timesteps` directly via scheduler.num_train_timesteps is deprecated. Please use `scheduler.config.num_train_timesteps` instead",
+        standard_warn=False,
+    )
+    return self.config.num_train_timesteps
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
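
Training scripts hit the same deprecation through `noise_scheduler.num_train_timesteps`; the fixed form samples timesteps from the config value, as in this minimal sketch:

```python
import torch
from diffusers import DDPMScheduler

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
bsz = 4  # hypothetical batch size

# sample a random timestep for each image, reading the bound from the config
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,)).long()
```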

View File

@@ -183,7 +183,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
timesteps = (
-np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
+np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
.round()[::-1][:-1]
.copy()
.astype(np.int64)

View File

@@ -193,7 +193,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
timesteps = (
-np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
+np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
.round()[::-1][:-1]
.copy()
.astype(np.int64)

View File

@@ -190,8 +190,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
the number of diffusion steps used when generating samples with a pre-trained model.
"""
steps = num_inference_steps
-order = self.solver_order
-if self.lower_order_final:
+order = self.config.solver_order
+if self.config.lower_order_final:
if order == 3:
if steps % 3 == 0:
orders = [1, 2, 3] * (steps // 3 - 1) + [1, 2] + [1]
@@ -227,7 +227,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
"""
self.num_inference_steps = num_inference_steps
timesteps = (
-np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
+np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
.round()[::-1][:-1]
.copy()
.astype(np.int64)

View File

@@ -195,7 +195,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
timesteps = (
-np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
+np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
.round()[::-1][:-1]
.copy()
.astype(np.int64)

View File

@@ -73,7 +73,7 @@ class CustomLocalPipeline(DiffusionPipeline):
# Sample gaussian noise to begin loop
image = torch.randn(
-(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+(batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
generator=generator,
)
image = image.to(self.device)

View File

@@ -73,7 +73,7 @@ class CustomLocalPipeline(DiffusionPipeline):
# Sample gaussian noise to begin loop
image = torch.randn(
-(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+(batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
generator=generator,
)
image = image.to(self.device)

View File

@@ -116,7 +116,7 @@ class UNet1DModelTests(ModelTesterMixin, unittest.TestCase):
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
-num_features = model.in_channels
+num_features = model.config.in_channels
seq_len = 16
noise = torch.randn((1, seq_len, num_features)).permute(
0, 2, 1
@@ -264,7 +264,7 @@ class UNetRLModelTests(ModelTesterMixin, unittest.TestCase):
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
-num_features = value_function.in_channels
+num_features = value_function.config.in_channels
seq_len = 14
noise = torch.randn((1, seq_len, num_features)).permute(
0, 2, 1

View File

@@ -675,6 +675,25 @@ class CustomPipelineTests(unittest.TestCase):
image = pipeline("a prompt", num_inference_steps=2, output_type="np").images[0]
assert image.shape == (512, 512, 3)
+def test_save_pipeline_change_config(self):
+    pipe = DiffusionPipeline.from_pretrained(
+        "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
+    )
+
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        pipe.save_pretrained(tmpdirname)
+        pipe = DiffusionPipeline.from_pretrained(tmpdirname)
+
+        assert pipe.scheduler.__class__.__name__ == "PNDMScheduler"
+
+    # let's make sure that changing the scheduler is correctly reflected
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+        pipe.save_pretrained(tmpdirname)
+        pipe = DiffusionPipeline.from_pretrained(tmpdirname)
+
+        assert pipe.scheduler.__class__.__name__ == "DPMSolverMultistepScheduler"
class PipelineFastTests(unittest.TestCase):
def tearDown(self):