Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
Fix config prints and save, load of pipelines (#2849)
* [Config] Fix config prints and save, load
* Only use potential nn.Modules for dtype and device
* Correct vae image processor
* make sure in_channels is not accessed directly
* make sure in channels is only accessed via config
* Make sure schedulers only access config attributes
* Make sure to access config in SAG
* Fix vae processor and make style
* add tests
* uP
* make style
* Fix more naming issues
* Final fix with vae config
* change more
Committed via GitHub. Parent commit: 8369196703. This commit: 8b451eb63b
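The pattern repeated throughout the hunks below is a single access-style change: model and scheduler hyperparameters are read from the frozen `config` object instead of as direct attributes, with deprecation shims keeping the old attributes alive until 1.0.0. A minimal before/after sketch (the objects are small locally constructed instances, used only for illustration):

```python
from diffusers import DDPMScheduler, UNet2DModel

# Small throwaway instances, for illustration only.
unet = UNet2DModel(sample_size=32, in_channels=3, out_channels=3)
scheduler = DDPMScheduler(num_train_timesteps=1000)

# Old style (still works for now, but emits a deprecation warning after this PR):
#   unet.in_channels, scheduler.num_train_timesteps
# New style used throughout this commit:
print(unet.config.in_channels)               # 3
print(scheduler.config.num_train_timesteps)  # 1000
```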
@@ -344,7 +344,7 @@ Now you can wrap all these components together in a training loop with 🤗 Acce
 ...         # Sample a random timestep for each image
 ...         timesteps = torch.randint(
-...             0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device
+...             0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device
 ...         ).long()
 ...
 ...         # Add noise to the clean images according to the noise magnitude at each timestep

@@ -62,7 +62,7 @@ class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
     def __call__(self):
         image = torch.randn(
-            (1, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+            (1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
         )
         timestep = 1

@@ -108,7 +108,7 @@ class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
     def __call__(self):
         image = torch.randn(
-            (1, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+            (1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
         )
         timestep = 1

@@ -89,7 +89,9 @@ class MyPipeline(DiffusionPipeline):
     @torch.no_grad()
     def __call__(self, batch_size: int = 1, num_inference_steps: int = 50):
         # Sample gaussian noise to begin loop
-        image = torch.randn((batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size))
+        image = torch.randn(
+            (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size)
+        )

         image = image.to(self.device)
@@ -238,7 +238,7 @@ class BitDiffusion(DiffusionPipeline):
         **kwargs,
     ) -> Union[Tuple, ImagePipelineOutput]:
         latents = torch.randn(
-            (batch_size, self.unet.in_channels, height, width),
+            (batch_size, self.unet.config.in_channels, height, width),
             generator=generator,
         )
         latents = decimal_to_bits(latents) * self.bit_scale

@@ -254,7 +254,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if latents is None:
             if self.device.type == "mps":

@@ -414,7 +414,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if latents is None:
             if self.device.type == "mps":

@@ -513,7 +513,7 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):
         timesteps = self.scheduler.timesteps

         # 5. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,

@@ -424,7 +424,7 @@ class ImagicStableDiffusionPipeline(DiffusionPipeline):
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_shape = (1, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (1, self.unet.config.in_channels, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if self.device.type == "mps":
             # randn does not exist on mps

@@ -320,7 +320,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if latents is None:
             if self.device.type == "mps":

@@ -416,7 +416,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
     def get_noise(self, seed, dtype=torch.float32, height=512, width=512):
         """Takes in random seed and returns corresponding noise vector"""
         return torch.randn(
-            (1, self.unet.in_channels, height // 8, width // 8),
+            (1, self.unet.config.in_channels, height // 8, width // 8),
             generator=torch.Generator(device=self.device).manual_seed(seed),
             device=self.device,
             dtype=dtype,
@@ -627,7 +627,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         if image is None:
             shape = (
                 batch_size,
-                self.unet.in_channels,
+                self.unet.config.in_channels,
                 height // self.vae_scale_factor,
                 width // self.vae_scale_factor,
             )

@@ -486,7 +486,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
         self.__init__additional__()

     def __init__additional__(self):
-        self.unet_in_channels = 4
+        self.unet.config.in_channels = 4
         self.vae_scale_factor = 8

     def _encode_prompt(

@@ -621,7 +621,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
         if image is None:
             shape = (
                 batch_size,
-                self.unet_in_channels,
+                self.unet.config.in_channels,
                 height // self.vae_scale_factor,
                 width // self.vae_scale_factor,
             )

@@ -93,7 +93,7 @@ class MagicMixPipeline(DiffusionPipeline):
         torch.manual_seed(seed)
         noise = torch.randn(
-            (1, self.unet.in_channels, height // 8, width // 8),
+            (1, self.unet.config.in_channels, height // 8, width // 8),
         ).to(self.device)

         latents = self.scheduler.add_noise(
@@ -355,7 +355,7 @@ class MultilingualStableDiffusion(DiffusionPipeline):
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if latents is None:
             if self.device.type == "mps":

@@ -433,7 +433,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
         sigmas = sigmas.to(text_embeddings.dtype)

         # 5. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,

@@ -262,8 +262,8 @@ class SeedResizeStableDiffusionPipeline(DiffusionPipeline):
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
-        latents_shape_reference = (batch_size * num_images_per_prompt, self.unet.in_channels, 64, 64)
+        latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
+        latents_shape_reference = (batch_size * num_images_per_prompt, self.unet.config.in_channels, 64, 64)
         latents_dtype = text_embeddings.dtype
         if latents is None:
             if self.device.type == "mps":

@@ -190,7 +190,7 @@ class SpeechToImagePipeline(DiffusionPipeline):
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if latents is None:
             if self.device.type == "mps":

@@ -337,7 +337,7 @@ class WildcardStableDiffusionPipeline(DiffusionPipeline):
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if latents is None:
             if self.device.type == "mps":
@@ -794,7 +794,7 @@ def main():
         noise = torch.randn_like(latents)
         bsz = latents.shape[0]
         # Sample a random timestep for each image
-        timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
+        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
         timesteps = timesteps.long()

         # Add noise to the latents according to the noise magnitude at each timestep

@@ -794,7 +794,7 @@ def main():
         noise = torch.randn_like(latents)
         bsz = latents.shape[0]
         # Sample a random timestep for each image
-        timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
+        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
         timesteps = timesteps.long()

         # Add noise to the latents according to the noise magnitude at each timestep

@@ -641,7 +641,7 @@ def main():
         noise = torch.randn_like(latents)
         bsz = latents.shape[0]
         # Sample a random timestep for each image
-        timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
+        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
         timesteps = timesteps.long()

         # Add noise to the latents according to the noise magnitude at each timestep

@@ -804,7 +804,7 @@ def main():
         bsz = latents.shape[0]
         # Sample a random timestep for each image
-        timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
+        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
         timesteps = timesteps.long()

         # Add noise to the latents according to the noise magnitude at each timestep

@@ -707,7 +707,7 @@ def main():
         bsz = latents.shape[0]
         # Sample a random timestep for each image
-        timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
+        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
         timesteps = timesteps.long()

         # Add noise to the latents according to the noise magnitude at each timestep
@@ -109,13 +109,6 @@ class ConfigMixin:
         # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
         # or solve in a more general way.
         kwargs.pop("kwargs", None)
-        for key, value in kwargs.items():
-            try:
-                setattr(self, key, value)
-            except AttributeError as err:
-                logger.error(f"Can't set {key} with value {value} for {self}")
-                raise err
-
         if not hasattr(self, "_internal_dict"):
             internal_dict = kwargs
         else:
@@ -99,8 +99,8 @@ class VaeImageProcessor(ConfigMixin):
         Resize a PIL image. Both height and width will be downscaled to the next integer multiple of `vae_scale_factor`
         """
         w, h = images.size
-        w, h = (x - x % self.vae_scale_factor for x in (w, h))  # resize to integer multiple of vae_scale_factor
-        images = images.resize((w, h), resample=PIL_INTERPOLATION[self.resample])
+        w, h = (x - x % self.config.vae_scale_factor for x in (w, h))  # resize to integer multiple of vae_scale_factor
+        images = images.resize((w, h), resample=PIL_INTERPOLATION[self.config.resample])
         return images

     def preprocess(
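The resize helper above rounds each dimension down to the nearest multiple of `vae_scale_factor`; a quick sketch of that arithmetic (the values are made up for illustration):

```python
vae_scale_factor = 8  # typical value for Stable Diffusion VAEs

def round_down(x: int, multiple: int) -> int:
    # same arithmetic as `x - x % self.config.vae_scale_factor` in the hunk above
    return x - x % multiple

print(round_down(513, vae_scale_factor))  # 512
print(round_down(769, vae_scale_factor))  # 768
print(round_down(512, vae_scale_factor))  # 512 (already a multiple, unchanged)
```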
@@ -119,7 +119,7 @@ class VaeImageProcessor(ConfigMixin):
             )

         if isinstance(image[0], PIL.Image.Image):
-            if self.do_resize:
+            if self.config.do_resize:
                 image = [self.resize(i) for i in image]
             image = [np.array(i).astype(np.float32) / 255.0 for i in image]
             image = np.stack(image, axis=0)  # to np

@@ -129,23 +129,27 @@ class VaeImageProcessor(ConfigMixin):
             image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
             image = self.numpy_to_pt(image)
             _, _, height, width = image.shape
-            if self.do_resize and (height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0):
+            if self.config.do_resize and (
+                height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
+            ):
                 raise ValueError(
-                    f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.vae_scale_factor}"
+                    f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.config.vae_scale_factor}"
                     f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
                 )

         elif isinstance(image[0], torch.Tensor):
             image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
             _, _, height, width = image.shape
-            if self.do_resize and (height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0):
+            if self.config.do_resize and (
+                height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
+            ):
                 raise ValueError(
-                    f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.vae_scale_factor}"
+                    f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.config.vae_scale_factor}"
                     f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
                 )

         # expected range [0,1], normalize to [-1,1]
-        do_normalize = self.do_normalize
+        do_normalize = self.config.do_normalize
         if image.min() < 0:
             warnings.warn(
                 "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
@@ -18,7 +18,7 @@ import torch
 import torch.nn as nn

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, apply_forward_hook
+from ..utils import BaseOutput, apply_forward_hook, deprecate
 from .modeling_utils import ModelMixin
 from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder

@@ -120,9 +120,19 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
             if isinstance(self.config.sample_size, (list, tuple))
             else self.config.sample_size
         )
-        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.block_out_channels) - 1)))
+        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
         self.tile_overlap_factor = 0.25

+    @property
+    def block_out_channels(self):
+        deprecate(
+            "block_out_channels",
+            "1.0.0",
+            "Accessing `block_out_channels` directly via vae.block_out_channels is deprecated. Please use `vae.config.block_out_channels instead`",
+            standard_warn=False,
+        )
+        return self.config.block_out_channels
+
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (Encoder, Decoder)):
             module.gradient_checkpointing = value
@@ -19,7 +19,7 @@ import torch
 import torch.nn as nn

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..utils import BaseOutput, deprecate
 from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block

@@ -190,6 +190,16 @@ class UNet1DModel(ModelMixin, ConfigMixin):
                 fc_dim=block_out_channels[-1] // 4,
             )

+    @property
+    def in_channels(self):
+        deprecate(
+            "in_channels",
+            "1.0.0",
+            "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead",
+            standard_warn=False,
+        )
+        return self.config.in_channels
+
     def forward(
         self,
         sample: torch.FloatTensor,

@@ -18,7 +18,7 @@ import torch
 import torch.nn as nn

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..utils import BaseOutput, deprecate
 from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block

@@ -215,6 +215,16 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         self.conv_act = nn.SiLU()
         self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)

+    @property
+    def in_channels(self):
+        deprecate(
+            "in_channels",
+            "1.0.0",
+            "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead",
+            standard_warn=False,
+        )
+        return self.config.in_channels
+
     def forward(
         self,
         sample: torch.FloatTensor,
@@ -20,7 +20,7 @@ import torch.utils.checkpoint

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..loaders import UNet2DConditionLoadersMixin
-from ..utils import BaseOutput, logging
+from ..utils import BaseOutput, deprecate, logging
 from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin

@@ -412,6 +412,16 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
         )

+    @property
+    def in_channels(self):
+        deprecate(
+            "in_channels",
+            "1.0.0",
+            "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead",
+            standard_warn=False,
+        )
+        return self.config.in_channels
+
     @property
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
         r"""
@@ -646,7 +646,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         timesteps = self.scheduler.timesteps

         # 5. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,

@@ -121,17 +121,17 @@ class AudioDiffusionPipeline(DiffusionPipeline):
         self.scheduler.set_timesteps(steps)
         step_generator = step_generator or generator
         # For backwards compatibility
-        if type(self.unet.sample_size) == int:
-            self.unet.sample_size = (self.unet.sample_size, self.unet.sample_size)
+        if type(self.unet.config.sample_size) == int:
+            self.unet.config.sample_size = (self.unet.config.sample_size, self.unet.config.sample_size)
         input_dims = self.get_input_dims()
         self.mel.set_resolution(x_res=input_dims[1], y_res=input_dims[0])
         if noise is None:
             noise = randn_tensor(
                 (
                     batch_size,
-                    self.unet.in_channels,
-                    self.unet.sample_size[0],
-                    self.unet.sample_size[1],
+                    self.unet.config.in_channels,
+                    self.unet.config.sample_size[0],
+                    self.unet.config.sample_size[1],
                 ),
                 generator=generator,
                 device=self.device,

@@ -158,7 +158,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
             images[0, 0] = self.scheduler.add_noise(input_images, noise, self.scheduler.timesteps[start_step - 1])

         pixels_per_second = (
-            self.unet.sample_size[1] * self.mel.get_sample_rate() / self.mel.x_res / self.mel.hop_length
+            self.unet.config.sample_size[1] * self.mel.get_sample_rate() / self.mel.x_res / self.mel.hop_length
         )
         mask_start = int(mask_start_secs * pixels_per_second)
         mask_end = int(mask_end_secs * pixels_per_second)

@@ -540,7 +540,7 @@ class AudioLDMPipeline(DiffusionPipeline):
         timesteps = self.scheduler.timesteps

         # 5. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_waveforms_per_prompt,
             num_channels_latents,

@@ -61,7 +61,7 @@ class DanceDiffusionPipeline(DiffusionPipeline):
                 to make generation deterministic.
             audio_length_in_s (`float`, *optional*, defaults to `self.unet.config.sample_size/self.unet.config.sample_rate`):
                 The length of the generated audio sample in seconds. Note that the output of the pipeline, *i.e.*
-                `sample_size`, will be `audio_length_in_s` * `self.unet.sample_rate`.
+                `sample_size`, will be `audio_length_in_s` * `self.unet.config.sample_rate`.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple.
@@ -73,27 +73,29 @@ class DanceDiffusionPipeline(DiffusionPipeline):
         if audio_length_in_s is None:
             audio_length_in_s = self.unet.config.sample_size / self.unet.config.sample_rate

-        sample_size = audio_length_in_s * self.unet.sample_rate
+        sample_size = audio_length_in_s * self.unet.config.sample_rate

         down_scale_factor = 2 ** len(self.unet.up_blocks)
         if sample_size < 3 * down_scale_factor:
             raise ValueError(
                 f"{audio_length_in_s} is too small. Make sure it's bigger or equal to"
-                f" {3 * down_scale_factor / self.unet.sample_rate}."
+                f" {3 * down_scale_factor / self.unet.config.sample_rate}."
             )

         original_sample_size = int(sample_size)
         if sample_size % down_scale_factor != 0:
-            sample_size = ((audio_length_in_s * self.unet.sample_rate) // down_scale_factor + 1) * down_scale_factor
+            sample_size = (
+                (audio_length_in_s * self.unet.config.sample_rate) // down_scale_factor + 1
+            ) * down_scale_factor
             logger.info(
-                f"{audio_length_in_s} is increased to {sample_size / self.unet.sample_rate} so that it can be handled"
-                f" by the model. It will be cut to {original_sample_size / self.unet.sample_rate} after the denoising"
+                f"{audio_length_in_s} is increased to {sample_size / self.unet.config.sample_rate} so that it can be handled"
+                f" by the model. It will be cut to {original_sample_size / self.unet.config.sample_rate} after the denoising"
                 " process."
             )
         sample_size = int(sample_size)

         dtype = next(iter(self.unet.parameters())).dtype
-        shape = (batch_size, self.unet.in_channels, sample_size)
+        shape = (batch_size, self.unet.config.in_channels, sample_size)
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
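The hunk above also reformats the line that rounds the requested audio length up to a multiple of the model's downsampling factor; a small sketch of that rounding, with numbers chosen only to illustrate (the sample rate and number of up blocks are assumptions, not values from the commit):

```python
sample_rate = 16000          # assumed sample rate, for illustration
down_scale_factor = 2 ** 4   # e.g. a UNet with 4 up blocks
audio_length_in_s = 1.2345

sample_size = audio_length_in_s * sample_rate  # roughly 19752 samples requested
if sample_size % down_scale_factor != 0:
    # round up to the next multiple of down_scale_factor, as the pipeline does
    sample_size = (sample_size // down_scale_factor + 1) * down_scale_factor
print(int(sample_size))  # 19760, the next multiple of 16
```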
@@ -79,10 +79,15 @@ class DDIMPipeline(DiffusionPipeline):
         """

         # Sample gaussian noise to begin loop
-        if isinstance(self.unet.sample_size, int):
-            image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        if isinstance(self.unet.config.sample_size, int):
+            image_shape = (
+                batch_size,
+                self.unet.config.in_channels,
+                self.unet.config.sample_size,
+                self.unet.config.sample_size,
+            )
         else:
-            image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size)
+            image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)

         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(

@@ -67,10 +67,15 @@ class DDPMPipeline(DiffusionPipeline):
             True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
         """
         # Sample gaussian noise to begin loop
-        if isinstance(self.unet.sample_size, int):
-            image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        if isinstance(self.unet.config.sample_size, int):
+            image_shape = (
+                batch_size,
+                self.unet.config.in_channels,
+                self.unet.config.sample_size,
+                self.unet.config.sample_size,
+            )
         else:
-            image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size)
+            image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)

         if self.device.type == "mps":
             # randn does not work reproducibly on mps

@@ -135,7 +135,7 @@ class LDMTextToImagePipeline(DiffusionPipeline):
         prompt_embeds = self.bert(text_input.input_ids.to(self.device))[0]

         # get the initial random noise unless the user supplied it
-        latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8)
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"

@@ -112,7 +112,7 @@ class LDMSuperResolutionPipeline(DiffusionPipeline):
         height, width = image.shape[-2:]

         # in_channels should be 6: 3 for latents, 3 for low resolution image
-        latents_shape = (batch_size, self.unet.in_channels // 2, height, width)
+        latents_shape = (batch_size, self.unet.config.in_channels // 2, height, width)
         latents_dtype = next(self.unet.parameters()).dtype

         latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)

@@ -73,7 +73,7 @@ class LDMPipeline(DiffusionPipeline):
         """

         latents = randn_tensor(
-            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+            (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
             generator=generator,
         )
         latents = latents.to(self.device)
@@ -506,6 +506,21 @@ class DiffusionPipeline(ConfigMixin):
             # set models
             setattr(self, name, module)

+    def __setattr__(self, name: str, value: Any):
+        if hasattr(self, name) and hasattr(self.config, name):
+            # We need to overwrite the config if name exists in config
+            if isinstance(getattr(self.config, name), (tuple, list)):
+                if self.config[name][0] is not None:
+                    class_library_tuple = (value.__module__.split(".")[0], value.__class__.__name__)
+                else:
+                    class_library_tuple = (None, None)
+
+                self.register_to_config(**{name: class_library_tuple})
+            else:
+                self.register_to_config(**{name: value})
+
+        super().__setattr__(name, value)
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
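The new `__setattr__` hook above is what makes component swaps stick in the saved pipeline config; a short usage sketch, mirroring the `test_save_pipeline_change_config` test added further down in this commit:

```python
import tempfile

from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

pipe = DiffusionPipeline.from_pretrained(
    "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
)

# Assigning a registered module now also updates pipe.config ...
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

# ... so the swap survives a save/load round trip.
with tempfile.TemporaryDirectory() as tmpdirname:
    pipe.save_pretrained(tmpdirname)
    reloaded = DiffusionPipeline.from_pretrained(tmpdirname)

assert reloaded.scheduler.__class__.__name__ == "DPMSolverMultistepScheduler"
```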
@@ -619,9 +634,11 @@ class DiffusionPipeline(ConfigMixin):
                 f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
             )

-        module_names, _, _ = self.extract_init_dict(dict(self.config))
+        module_names, _ = self._get_signature_keys(self)
+        module_names = [m for m in module_names if hasattr(self, m)]

         is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
-        for name in module_names.keys():
+        for name in module_names:
             module = getattr(self, name)
             if isinstance(module, torch.nn.Module):
                 module.to(torch_device, torch_dtype)

@@ -646,8 +663,10 @@ class DiffusionPipeline(ConfigMixin):
         Returns:
             `torch.device`: The torch device on which the pipeline is located.
         """
-        module_names, _, _ = self.extract_init_dict(dict(self.config))
-        for name in module_names.keys():
+        module_names, _ = self._get_signature_keys(self)
+        module_names = [m for m in module_names if hasattr(self, m)]
+
+        for name in module_names:
             module = getattr(self, name)
             if isinstance(module, torch.nn.Module):
                 return module.device

@@ -1420,6 +1439,8 @@ class DiffusionPipeline(ConfigMixin):
                 fn_recursive_set_mem_eff(child)

         module_names, _, _ = self.extract_init_dict(dict(self.config))
+        module_names = [m for m in module_names if hasattr(self, m)]
+
         for module_name in module_names:
             module = getattr(self, module_name)
             if isinstance(module, torch.nn.Module):

@@ -1451,6 +1472,8 @@ class DiffusionPipeline(ConfigMixin):

     def set_attention_slice(self, slice_size: Optional[int]):
         module_names, _, _ = self.extract_init_dict(dict(self.config))
+        module_names = [m for m in module_names if hasattr(self, m)]
+
         for module_name in module_names:
             module = getattr(self, module_name)
             if isinstance(module, torch.nn.Module) and hasattr(module, "set_attention_slice"):
@@ -77,7 +77,7 @@ class PNDMPipeline(DiffusionPipeline):

         # Sample gaussian noise to begin loop
         image = randn_tensor(
-            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+            (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
             generator=generator,
             device=self.device,
         )

@@ -476,7 +476,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline):
         timesteps = self.scheduler.timesteps

         # 5. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,

@@ -247,7 +247,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):

         latents_shape = (
             batch_size,
-            self.unet.in_channels,
+            self.unet.config.in_channels,
             height // self.vae_scale_factor,
             width // self.vae_scale_factor,
         )

@@ -283,7 +283,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):

         latents_shape = (
             batch_size,
-            self.unet.in_channels,
+            self.unet.config.in_channels,
             height // self.vae_scale_factor,
             width // self.vae_scale_factor,
         )

@@ -268,7 +268,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):

         latents_shape = (
             batch_size,
-            self.unet.in_channels,
+            self.unet.config.in_channels,
             height // self.vae_scale_factor,
             width // self.vae_scale_factor,
         )

@@ -649,7 +649,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         timesteps = self.scheduler.timesteps

         # 5. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,

@@ -855,7 +855,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
         timesteps = self.scheduler.timesteps

         # 5. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
@@ -910,7 +910,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
         timesteps = self.scheduler.timesteps

         # 6. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,

@@ -358,7 +358,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
         timesteps = self.scheduler.timesteps

         # 5. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,

@@ -561,7 +561,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
         sigmas = sigmas.to(prompt_embeds.dtype)

         # 6. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,

@@ -722,7 +722,7 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa
         timesteps = self.scheduler.timesteps

         # 5. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,

@@ -586,7 +586,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
         timesteps = self.scheduler.timesteps

         # 5. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,

@@ -929,7 +929,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):

         # 5. Generate the inverted noise from the input image or any other image
         # generated from the input prompt.
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,

@@ -595,7 +595,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
         timesteps = self.scheduler.timesteps

         # 5. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
@@ -701,7 +701,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
         # Same masking process as in SAG paper: https://arxiv.org/pdf/2210.00939.pdf
         bh, hw1, hw2 = attn_map.shape
         b, latent_channel, latent_h, latent_w = original_latents.shape
-        h = self.unet.attention_head_dim
+        h = self.unet.config.attention_head_dim
         if isinstance(h, list):
             h = h[-1]

@@ -877,7 +877,7 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         timesteps = self.scheduler.timesteps

         # 11. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
         latents = self.prepare_latents(
             shape=shape,

@@ -772,7 +772,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
         timesteps = self.scheduler.timesteps

         # 6. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size=batch_size,
             num_channels_latents=num_channels_latents,

@@ -623,7 +623,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
         timesteps = self.scheduler.timesteps

         # 5. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,

@@ -606,7 +606,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         timesteps = self.scheduler.timesteps

         # 5. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
@@ -12,7 +12,7 @@ from ...models.dual_transformer_2d import DualTransformer2DModel
 from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from ...models.transformer_2d import Transformer2DModel
 from ...models.unet_2d_condition import UNet2DConditionOutput
-from ...utils import logging
+from ...utils import deprecate, logging


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@@ -504,6 +504,19 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
         )

+    @property
+    def in_channels(self):
+        deprecate(
+            "in_channels",
+            "1.0.0",
+            (
+                "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use"
+                " `unet.config.in_channels` instead"
+            ),
+            standard_warn=False,
+        )
+        return self.config.in_channels
+
     @property
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
         r"""

@@ -22,7 +22,7 @@ import numpy as np
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, randn_tensor
+from ..utils import BaseOutput, deprecate, randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
@@ -167,6 +167,16 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):

         self.variance_type = variance_type

+    @property
+    def num_train_timesteps(self):
+        deprecate(
+            "num_train_timesteps",
+            "1.0.0",
+            "Accessing `num_train_timesteps` directly via scheduler.num_train_timesteps is deprecated. Please use `scheduler.config.num_train_timesteps instead`",
+            standard_warn=False,
+        )
+        return self.config.num_train_timesteps
+
     def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
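The deprecation shim above keeps the old attribute alive until 1.0.0 while pointing callers at the config; a minimal sketch of what callers see (the value is the scheduler's default):

```python
from diffusers import DDPMScheduler

scheduler = DDPMScheduler()

# New, preferred access:
print(scheduler.config.num_train_timesteps)  # 1000

# Old access still resolves through the property added above, but emits a
# deprecation warning recommending scheduler.config.num_train_timesteps.
print(scheduler.num_train_timesteps)  # 1000, with a warning
```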
@@ -183,7 +183,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
                 the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """
         timesteps = (
-            np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
+            np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
             .round()[::-1][:-1]
             .copy()
             .astype(np.int64)

@@ -193,7 +193,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
                 the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """
         timesteps = (
-            np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
+            np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
             .round()[::-1][:-1]
             .copy()
             .astype(np.int64)

@@ -190,8 +190,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
                 the number of diffusion steps used when generating samples with a pre-trained model.
         """
         steps = num_inference_steps
-        order = self.solver_order
-        if self.lower_order_final:
+        order = self.config.solver_order
+        if self.config.lower_order_final:
             if order == 3:
                 if steps % 3 == 0:
                     orders = [1, 2, 3] * (steps // 3 - 1) + [1, 2] + [1]

@@ -227,7 +227,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         """
         self.num_inference_steps = num_inference_steps
         timesteps = (
-            np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
+            np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
             .round()[::-1][:-1]
             .copy()
             .astype(np.int64)

@@ -195,7 +195,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
                 the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """
         timesteps = (
-            np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
+            np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
             .round()[::-1][:-1]
             .copy()
             .astype(np.int64)
tests/fixtures/custom_pipeline/pipeline.py (vendored)
@@ -73,7 +73,7 @@ class CustomLocalPipeline(DiffusionPipeline):

         # Sample gaussian noise to begin loop
         image = torch.randn(
-            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+            (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
             generator=generator,
         )
         image = image.to(self.device)
tests/fixtures/custom_pipeline/what_ever.py (vendored)
@@ -73,7 +73,7 @@ class CustomLocalPipeline(DiffusionPipeline):

         # Sample gaussian noise to begin loop
         image = torch.randn(
-            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+            (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
             generator=generator,
         )
         image = image.to(self.device)
@@ -116,7 +116,7 @@ class UNet1DModelTests(ModelTesterMixin, unittest.TestCase):
         if torch.cuda.is_available():
             torch.cuda.manual_seed_all(0)

-        num_features = model.in_channels
+        num_features = model.config.in_channels
         seq_len = 16
         noise = torch.randn((1, seq_len, num_features)).permute(
             0, 2, 1

@@ -264,7 +264,7 @@ class UNetRLModelTests(ModelTesterMixin, unittest.TestCase):
         if torch.cuda.is_available():
             torch.cuda.manual_seed_all(0)

-        num_features = value_function.in_channels
+        num_features = value_function.config.in_channels
         seq_len = 14
         noise = torch.randn((1, seq_len, num_features)).permute(
             0, 2, 1
@@ -675,6 +675,25 @@ class CustomPipelineTests(unittest.TestCase):
         image = pipeline("a prompt", num_inference_steps=2, output_type="np").images[0]
         assert image.shape == (512, 512, 3)

+    def test_save_pipeline_change_config(self):
+        pipe = DiffusionPipeline.from_pretrained(
+            "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
+        )
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            pipe.save_pretrained(tmpdirname)
+            pipe = DiffusionPipeline.from_pretrained(tmpdirname)
+
+            assert pipe.scheduler.__class__.__name__ == "PNDMScheduler"
+
+        # let's make sure that changing the scheduler is correctly reflected
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+            pipe.save_pretrained(tmpdirname)
+            pipe = DiffusionPipeline.from_pretrained(tmpdirname)
+
+            assert pipe.scheduler.__class__.__name__ == "DPMSolverMultistepScheduler"
+

 class PipelineFastTests(unittest.TestCase):
     def tearDown(self):