1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Pytorch] Pytorch only schedulers (#534)

* pytorch only schedulers

* fix style

* remove match_shape

* pytorch only ddpm

* remove SchedulerMixin

* remove numpy from karras_ve

* fix types

* remove numpy from lms_discrete

* remove numpy from pndm

* fix typo

* remove mixin and numpy from sde_vp and ve

* remove remaining tensor_format

* fix style

* sigmas has to be torch tensor

* removed set_format in readme

* remove set format from docs

* remove set_format from pipelines

* update tests

* fix typo

* continue to use mixin

* fix imports

* removed unsed imports

* match shape instead of assuming image shapes

* remove import typo

* update call to add_noise

* use math instead of numpy

* fix t_index

* removed commented out numpy tests

* timesteps needs to be discrete

* cast timesteps to int in flax scheduler too

* fix device mismatch issue

* small fix

* Update src/diffusers/schedulers/scheduling_pndm.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Kashif Rasul
2022-09-27 15:27:34 +02:00
committed by GitHub
parent 3b747de845
commit bd8df2da89
27 changed files with 231 additions and 464 deletions

View File

@@ -44,8 +44,7 @@ To this end, the design of schedulers is such that:
The core API for any new scheduler must follow a limited structure.
- Schedulers should provide one or more `def step(...)` functions that should be called to update the generated sample iteratively.
- Schedulers should provide a `set_timesteps(...)` method that configures the parameters of a schedule function for a specific inference task.
- Schedulers should be framework-agnostic, but provide a simple functionality to convert the scheduler into a specific framework, such as PyTorch
with a `set_format(...)` method.
- Schedulers should be framework-specific.
The base class [`SchedulerMixin`] implements low level utilities used by multiple schedulers.

View File

@@ -274,7 +274,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
# the model input needs to be scaled to match the continuous ODE formulation in K-LMS
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
# # predict the noise residual
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform classifier free guidance

View File

@@ -424,7 +424,10 @@ def main():
# TODO (patil-suraj): load scheduler using args
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt"
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
num_train_timesteps=1000,
)
train_dataset = TextualInversionDataset(

View File

@@ -59,7 +59,7 @@ def main(args):
"UpBlock2D",
),
)
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt")
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
optimizer = torch.optim.AdamW(
model.parameters(),
lr=args.learning_rate,

View File

@@ -35,7 +35,6 @@ class DDIMPipeline(DiffusionPipeline):
def __init__(self, unet, scheduler):
super().__init__()
scheduler = scheduler.set_format("pt")
self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad()

View File

@@ -35,7 +35,6 @@ class DDPMPipeline(DiffusionPipeline):
def __init__(self, unet, scheduler):
super().__init__()
scheduler = scheduler.set_format("pt")
self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad()

View File

@@ -45,7 +45,6 @@ class LDMTextToImagePipeline(DiffusionPipeline):
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
):
super().__init__()
scheduler = scheduler.set_format("pt")
self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
@torch.no_grad()

View File

@@ -23,7 +23,6 @@ class LDMPipeline(DiffusionPipeline):
def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler):
super().__init__()
scheduler = scheduler.set_format("pt")
self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)
@torch.no_grad()

View File

@@ -39,7 +39,6 @@ class PNDMPipeline(DiffusionPipeline):
def __init__(self, unet: UNet2DModel, scheduler: PNDMScheduler):
super().__init__()
scheduler = scheduler.set_format("pt")
self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad()

View File

@@ -57,7 +57,6 @@ class StableDiffusionPipeline(DiffusionPipeline):
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
scheduler = scheduler.set_format("pt")
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
warnings.warn(

View File

@@ -69,7 +69,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
scheduler = scheduler.set_format("pt")
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
warnings.warn(

View File

@@ -83,7 +83,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
scheduler = scheduler.set_format("pt")
logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
@@ -320,11 +319,11 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
if isinstance(self.scheduler, LMSDiscreteScheduler):
latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
# masking
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor(t_index))
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.LongTensor([t_index]))
else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# masking
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.LongTensor([t]))
latents = (init_latents_proper * mask) + (latents * (1 - mask))

View File

@@ -35,7 +35,6 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
scheduler = scheduler.set_format("np")
self.register_modules(
vae_decoder=vae_decoder,
text_encoder=text_encoder,

View File

@@ -29,7 +29,6 @@ class KarrasVePipeline(DiffusionPipeline):
def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler):
super().__init__()
scheduler = scheduler.set_format("pt")
self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad()

View File

@@ -8,8 +8,7 @@
- Schedulers should provide one or more `def step(...)` functions that should be called iteratively to unroll the diffusion loop during
the forward pass.
- Schedulers should be framework-agnostic, but provide a simple functionality to convert the scheduler into a specific framework, such as PyTorch
with a `set_format(...)` method.
- Schedulers should be framework specific.
## Examples

View File

@@ -46,7 +46,7 @@ class DDIMSchedulerOutput(BaseOutput):
pred_original_sample: Optional[torch.FloatTensor] = None
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -72,7 +72,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return np.array(betas, dtype=np.float32)
return torch.tensor(betas)
class DDIMScheduler(SchedulerMixin, ConfigMixin):
@@ -106,7 +106,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
"""
@@ -121,15 +120,16 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
tensor_format: str = "pt",
):
if trained_betas is not None:
self.betas = np.asarray(trained_betas)
self.betas = torch.from_numpy(trained_betas)
if beta_schedule == "linear":
self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
)
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -137,20 +137,17 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# At every step in ddim, we are looking into the previous alphas_cumprod
# For the final step, there is no previous alphas_cumprod because we are already at 0
# `set_alpha_to_one` decides whether we set this parameter simply to one or
# whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
# setable values
self.num_inference_steps = None
self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)
self.timesteps = np.arange(0, num_train_timesteps)[::-1]
def _get_variance(self, timestep, prev_timestep):
alpha_prod_t = self.alphas_cumprod[timestep]
@@ -186,15 +183,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1]
self.timesteps += offset
self.set_format(tensor_format=self.tensor_format)
def step(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
model_output: torch.FloatTensor,
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
sample: torch.FloatTensor,
eta: float = 0.0,
use_clipped_model_output: bool = False,
generator=None,
@@ -205,9 +201,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor` or `np.ndarray`):
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
eta (`float`): weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`): TODO
@@ -251,7 +247,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# 4. Clip "predicted x_0"
if self.config.clip_sample:
pred_original_sample = self.clip(pred_original_sample, -1, 1)
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
# 5. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 α_t1)/(1 α_t)) * sqrt(1 α_t/α_t1)
@@ -273,9 +269,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
noise = torch.randn(model_output.shape, generator=generator).to(device)
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
if not torch.is_tensor(model_output):
variance = variance.numpy()
prev_sample = prev_sample + variance
if not return_dict:
@@ -285,16 +278,20 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
def add_noise(
self,
original_samples: Union[torch.FloatTensor, np.ndarray],
noise: Union[torch.FloatTensor, np.ndarray],
timesteps: Union[torch.IntTensor, np.ndarray],
) -> Union[torch.FloatTensor, np.ndarray]:
if self.tensor_format == "pt":
timesteps = timesteps.to(self.alphas_cumprod.device)
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
timesteps = timesteps.to(self.alphas_cumprod.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples

View File

@@ -70,7 +70,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return np.array(betas, dtype=np.float32)
return torch.tensor(betas, dtype=torch.float32)
class DDPMScheduler(SchedulerMixin, ConfigMixin):
@@ -99,7 +99,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability.
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
"""
@@ -113,15 +112,16 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
trained_betas: Optional[np.ndarray] = None,
variance_type: str = "fixed_small",
clip_sample: bool = True,
tensor_format: str = "pt",
):
if trained_betas is not None:
self.betas = np.asarray(trained_betas)
self.betas = torch.from_numpy(trained_betas)
elif beta_schedule == "linear":
self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
)
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -129,15 +129,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
self.one = np.array(1.0)
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.one = torch.tensor(1.0)
# setable values
self.num_inference_steps = None
self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)
self.timesteps = np.arange(0, num_train_timesteps)[::-1]
self.variance_type = variance_type
@@ -153,8 +150,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
self.num_inference_steps = num_inference_steps
self.timesteps = np.arange(
0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
)[::-1].copy()
self.set_format(tensor_format=self.tensor_format)
)[::-1]
def _get_variance(self, t, predicted_variance=None, variance_type=None):
alpha_prod_t = self.alphas_cumprod[t]
@@ -170,15 +166,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
# hacks - were probably added for training stability
if variance_type == "fixed_small":
variance = self.clip(variance, min_value=1e-20)
variance = torch.clamp(variance, min=1e-20)
# for rl-diffuser https://arxiv.org/abs/2205.09991
elif variance_type == "fixed_small_log":
variance = self.log(self.clip(variance, min_value=1e-20))
variance = torch.log(torch.clamp(variance, min=1e-20))
elif variance_type == "fixed_large":
variance = self.betas[t]
elif variance_type == "fixed_large_log":
# Glide max_log
variance = self.log(self.betas[t])
variance = torch.log(self.betas[t])
elif variance_type == "learned":
return predicted_variance
elif variance_type == "learned_range":
@@ -191,9 +187,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
model_output: torch.FloatTensor,
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
sample: torch.FloatTensor,
predict_epsilon=True,
generator=None,
return_dict: bool = True,
@@ -203,9 +199,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor` or `np.ndarray`):
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
predict_epsilon (`bool`):
optional flag to use when model predicts the samples directly instead of the noise, epsilon.
@@ -240,7 +236,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
# 3. Clip "predicted x_0"
if self.config.clip_sample:
pred_original_sample = self.clip(pred_original_sample, -1, 1)
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
@@ -254,7 +250,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
# 6. Add noise
variance = 0
if t > 0:
noise = self.randn_like(model_output, generator=generator)
noise = torch.randn(
model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator
).to(model_output.device)
variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise
pred_prev_sample = pred_prev_sample + variance
@@ -266,16 +264,21 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
def add_noise(
self,
original_samples: Union[torch.FloatTensor, np.ndarray],
noise: Union[torch.FloatTensor, np.ndarray],
timesteps: Union[torch.IntTensor, np.ndarray],
) -> Union[torch.FloatTensor, np.ndarray]:
if self.tensor_format == "pt":
timesteps = timesteps.to(self.alphas_cumprod.device)
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
timesteps = timesteps.to(self.alphas_cumprod.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples

View File

@@ -74,7 +74,6 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
A reasonable range is [0, 10].
s_max (`float`): the end value of the sigma range where we add noise.
A reasonable range is [0.2, 80].
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
"""
@@ -87,15 +86,11 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
s_churn: float = 80,
s_min: float = 0.05,
s_max: float = 50,
tensor_format: str = "pt",
):
# setable values
self.num_inference_steps = None
self.timesteps = None
self.schedule = None # sigma(t_i)
self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)
self.num_inference_steps: int = None
self.timesteps: np.ndarray = None
self.schedule: torch.FloatTensor = None # sigma(t_i)
def set_timesteps(self, num_inference_steps: int):
"""
@@ -108,20 +103,18 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
"""
self.num_inference_steps = num_inference_steps
self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
self.schedule = [
schedule = [
(
self.config.sigma_max**2
* (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1))
)
for i in self.timesteps
]
self.schedule = np.array(self.schedule, dtype=np.float32)
self.set_format(tensor_format=self.tensor_format)
self.schedule = torch.tensor(schedule, dtype=torch.float32)
def add_noise_to_input(
self, sample: Union[torch.FloatTensor, np.ndarray], sigma: float, generator: Optional[torch.Generator] = None
) -> Tuple[Union[torch.FloatTensor, np.ndarray], float]:
self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None
) -> Tuple[torch.FloatTensor, float]:
"""
Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
@@ -142,10 +135,10 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
model_output: torch.FloatTensor,
sigma_hat: float,
sigma_prev: float,
sample_hat: Union[torch.FloatTensor, np.ndarray],
sample_hat: torch.FloatTensor,
return_dict: bool = True,
) -> Union[KarrasVeOutput, Tuple]:
"""
@@ -153,10 +146,10 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
sigma_hat (`float`): TODO
sigma_prev (`float`): TODO
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
sample_hat (`torch.FloatTensor`): TODO
return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class
KarrasVeOutput: updated sample in the diffusion chain and derivative (TODO double check).
@@ -180,24 +173,24 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
def step_correct(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
model_output: torch.FloatTensor,
sigma_hat: float,
sigma_prev: float,
sample_hat: Union[torch.FloatTensor, np.ndarray],
sample_prev: Union[torch.FloatTensor, np.ndarray],
derivative: Union[torch.FloatTensor, np.ndarray],
sample_hat: torch.FloatTensor,
sample_prev: torch.FloatTensor,
derivative: torch.FloatTensor,
return_dict: bool = True,
) -> Union[KarrasVeOutput, Tuple]:
"""
Correct the predicted sample based on the output model_output of the network. TODO complete description
Args:
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
sigma_hat (`float`): TODO
sigma_prev (`float`): TODO
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO
derivative (`torch.FloatTensor` or `np.ndarray`): TODO
sample_hat (`torch.FloatTensor`): TODO
sample_prev (`torch.FloatTensor`): TODO
derivative (`torch.FloatTensor`): TODO
return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class
Returns:

View File

@@ -63,7 +63,6 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
`linear` or `scaled_linear`.
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
"""
@@ -75,31 +74,29 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[np.ndarray] = None,
tensor_format: str = "pt",
):
if trained_betas is not None:
self.betas = np.asarray(trained_betas)
self.betas = torch.from_numpy(trained_betas)
if beta_schedule == "linear":
self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
)
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
# setable values
self.num_inference_steps = None
self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
self.timesteps = np.arange(0, num_train_timesteps)[::-1] # to be consistent has to be smaller than sigmas by 1
self.derivatives = []
self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)
def get_lms_coefficient(self, order, t, current_order):
"""
Compute a linear multistep coefficient.
@@ -131,24 +128,24 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
the number of diffusion steps used when generating samples with a pre-trained model.
"""
self.num_inference_steps = num_inference_steps
self.timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)
timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)
low_idx = np.floor(self.timesteps).astype(int)
high_idx = np.ceil(self.timesteps).astype(int)
frac = np.mod(self.timesteps, 1.0)
low_idx = np.floor(timesteps).astype(int)
high_idx = np.ceil(timesteps).astype(int)
frac = np.mod(timesteps, 1.0)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
self.sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas)
self.timesteps = timesteps.astype(int)
self.derivatives = []
self.set_format(tensor_format=self.tensor_format)
def step(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
model_output: torch.FloatTensor,
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
sample: torch.FloatTensor,
order: int = 4,
return_dict: bool = True,
) -> Union[LMSDiscreteSchedulerOutput, Tuple]:
@@ -157,9 +154,9 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor` or `np.ndarray`):
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
order: coefficient for multi-step inference.
return_dict (`bool`): option for returning tuple rather than LMSDiscreteSchedulerOutput class
@@ -197,15 +194,18 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
def add_noise(
self,
original_samples: Union[torch.FloatTensor, np.ndarray],
noise: Union[torch.FloatTensor, np.ndarray],
timesteps: Union[torch.IntTensor, np.ndarray],
) -> Union[torch.FloatTensor, np.ndarray]:
if self.tensor_format == "pt":
timesteps = timesteps.to(self.sigmas.device)
sigmas = self.match_shape(self.sigmas[timesteps], noise)
noisy_samples = original_samples + noise * sigmas
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
sigmas = self.sigmas.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
sigma = sigmas[timesteps].flatten()
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)
noisy_samples = original_samples + noise * sigma
return noisy_samples
def __len__(self):

View File

@@ -132,7 +132,7 @@ class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
return state.replace(
num_inference_steps=num_inference_steps,
timesteps=timesteps,
timesteps=timesteps.astype(int),
derivatives=jnp.array([]),
sigmas=sigmas,
)

View File

@@ -51,7 +51,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return np.array(betas, dtype=np.float32)
return torch.tensor(betas, dtype=torch.float32)
class PNDMScheduler(SchedulerMixin, ConfigMixin):
@@ -86,7 +86,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays
"""
@@ -101,15 +100,16 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
skip_prk_steps: bool = False,
set_alpha_to_one: bool = False,
steps_offset: int = 0,
tensor_format: str = "pt",
):
if trained_betas is not None:
self.betas = np.asarray(trained_betas)
self.betas = torch.from_numpy(trained_betas)
if beta_schedule == "linear":
self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
)
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -117,9 +117,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
# For now we only support F-PNDM, i.e. the runge-kutta method
# For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
@@ -139,9 +139,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
self.plms_timesteps = None
self.timesteps = None
self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)
def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor:
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -189,13 +186,12 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
self.ets = []
self.counter = 0
self.set_format(tensor_format=self.tensor_format)
def step(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
model_output: torch.FloatTensor,
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
sample: torch.FloatTensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
@@ -205,9 +201,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`.
Args:
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor` or `np.ndarray`):
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
@@ -224,9 +220,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
def step_prk(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
model_output: torch.FloatTensor,
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
sample: torch.FloatTensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
@@ -234,9 +230,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
solution to the differential equation.
Args:
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor` or `np.ndarray`):
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
@@ -279,9 +275,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
def step_plms(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
model_output: torch.FloatTensor,
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
sample: torch.FloatTensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
@@ -289,9 +285,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
times to approximate the solution.
Args:
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor` or `np.ndarray`):
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
@@ -381,16 +377,27 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
def add_noise(
self,
original_samples: Union[torch.FloatTensor, np.ndarray],
noise: Union[torch.FloatTensor, np.ndarray],
timesteps: Union[torch.IntTensor, np.ndarray],
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
if self.tensor_format == "pt":
timesteps = timesteps.to(self.alphas_cumprod.device)
if self.alphas_cumprod.device != original_samples.device:
self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)
if timesteps.device != original_samples.device:
timesteps = timesteps.to(original_samples.device)
timesteps = timesteps.to(self.alphas_cumprod.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples

View File

@@ -14,11 +14,11 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
import math
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
@@ -65,7 +65,6 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
sampling_eps (`float`): the end value of sampling, where timesteps decrease progressively from 1 to
epsilon.
correct_steps (`int`): number of correction steps performed on a produced sample.
tensor_format (`str`): "np" or "pt" for the expected format of samples passed to the Scheduler.
"""
@register_to_config
@@ -77,16 +76,12 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
sigma_max: float = 1348.0,
sampling_eps: float = 1e-5,
correct_steps: int = 1,
tensor_format: str = "pt",
):
# setable values
self.timesteps = None
self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)
self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)
def set_timesteps(self, num_inference_steps: int, sampling_eps: float = None):
"""
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -98,13 +93,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
"""
sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
tensor_format = getattr(self, "tensor_format", "pt")
if tensor_format == "np":
self.timesteps = np.linspace(1, sampling_eps, num_inference_steps)
elif tensor_format == "pt":
self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps)
else:
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps)
def set_sigmas(
self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None, sampling_eps: float = None
@@ -129,28 +119,16 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
if self.timesteps is None:
self.set_timesteps(num_inference_steps, sampling_eps)
tensor_format = getattr(self, "tensor_format", "pt")
if tensor_format == "np":
self.discrete_sigmas = np.exp(np.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps))
self.sigmas = np.array([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])
elif tensor_format == "pt":
self.discrete_sigmas = torch.exp(torch.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps))
self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])
else:
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
self.sigmas = sigma_min * (sigma_max / sigma_min) ** (self.timesteps / sampling_eps)
self.discrete_sigmas = torch.exp(torch.linspace(math.log(sigma_min), math.log(sigma_max), num_inference_steps))
self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])
def get_adjacent_sigma(self, timesteps, t):
tensor_format = getattr(self, "tensor_format", "pt")
if tensor_format == "np":
return np.where(timesteps == 0, np.zeros_like(t), self.discrete_sigmas[timesteps - 1])
elif tensor_format == "pt":
return torch.where(
timesteps == 0,
torch.zeros_like(t.to(timesteps.device)),
self.discrete_sigmas[timesteps - 1].to(timesteps.device),
)
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
return torch.where(
timesteps == 0,
torch.zeros_like(t.to(timesteps.device)),
self.discrete_sigmas[timesteps - 1].to(timesteps.device),
)
def set_seed(self, seed):
warnings.warn(
@@ -158,19 +136,13 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
" generator instead.",
DeprecationWarning,
)
tensor_format = getattr(self, "tensor_format", "pt")
if tensor_format == "np":
np.random.seed(seed)
elif tensor_format == "pt":
torch.manual_seed(seed)
else:
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
torch.manual_seed(seed)
def step_pred(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
model_output: torch.FloatTensor,
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
sample: torch.FloatTensor,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
**kwargs,
@@ -180,9 +152,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor` or `np.ndarray`):
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
generator: random number generator.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
@@ -210,18 +182,21 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
sigma = self.discrete_sigmas[timesteps].to(sample.device)
adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device)
drift = self.zeros_like(sample)
drift = torch.zeros_like(sample)
diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5
# equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
# also equation 47 shows the analog from SDE models to ancestral sampling methods
drift = drift - diffusion[:, None, None, None] ** 2 * model_output
diffusion = diffusion.flatten()
while len(diffusion.shape) < len(sample.shape):
diffusion = diffusion.unsqueeze(-1)
drift = drift - diffusion**2 * model_output
# equation 6: sample noise for the diffusion term of
noise = self.randn_like(sample, generator=generator)
noise = torch.randn(sample.shape, layout=sample.layout, generator=generator).to(sample.device)
prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep
# TODO is the variable diffusion the correct scaling term for the noise?
prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g
prev_sample = prev_sample_mean + diffusion * noise # add impact of diffusion field g
if not return_dict:
return (prev_sample, prev_sample_mean)
@@ -230,8 +205,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
def step_correct(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
sample: Union[torch.FloatTensor, np.ndarray],
model_output: torch.FloatTensor,
sample: torch.FloatTensor,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
**kwargs,
@@ -241,8 +216,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
after making the prediction for the previous timestep.
Args:
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
sample (`torch.FloatTensor` or `np.ndarray`):
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
generator: random number generator.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
@@ -262,18 +237,21 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
# For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
# sample noise for correction
noise = self.randn_like(sample, generator=generator)
noise = torch.randn(sample.shape, layout=sample.layout, generator=generator).to(sample.device)
# compute step size from the model_output, the noise, and the snr
grad_norm = self.norm(model_output)
noise_norm = self.norm(noise)
grad_norm = torch.norm(model_output.reshape(model_output.shape[0], -1), dim=-1).mean()
noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
step_size = step_size * torch.ones(sample.shape[0]).to(sample.device)
# self.repeat_scalar(step_size, sample.shape[0])
# compute corrected sample: model_output term and noise term
prev_sample_mean = sample + step_size[:, None, None, None] * model_output
prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise
step_size = step_size.flatten()
while len(step_size.shape) < len(sample.shape):
step_size = step_size.unsqueeze(-1)
prev_sample_mean = sample + step_size * model_output
prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise
if not return_dict:
return (prev_sample,)

View File

@@ -16,7 +16,8 @@
# TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean-up a bit
import numpy as np
import math
import torch
from ..configuration_utils import ConfigMixin, register_to_config
@@ -39,7 +40,7 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
"""
@register_to_config
def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"):
def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3):
self.sigmas = None
self.discrete_sigmas = None
self.timesteps = None
@@ -47,7 +48,7 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
def set_timesteps(self, num_inference_steps):
self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)
def step_pred(self, score, x, t):
def step_pred(self, score, x, t, generator=None):
if self.timesteps is None:
raise ValueError(
"`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
@@ -59,20 +60,27 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
-0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min
)
std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
score = -score / std[:, None, None, None]
std = std.flatten()
while len(std.shape) < len(score.shape):
std = std.unsqueeze(-1)
score = -score / std
# compute
dt = -1.0 / len(self.timesteps)
beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min)
drift = -0.5 * beta_t[:, None, None, None] * x
beta_t = beta_t.flatten()
while len(beta_t.shape) < len(x.shape):
beta_t = beta_t.unsqueeze(-1)
drift = -0.5 * beta_t * x
diffusion = torch.sqrt(beta_t)
drift = drift - diffusion[:, None, None, None] ** 2 * score
drift = drift - diffusion**2 * score
x_mean = x + drift * dt
# add noise
noise = torch.randn_like(x)
x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * noise
noise = torch.randn(x.shape, layout=x.layout, generator=generator).to(x.device)
x = x_mean + diffusion * math.sqrt(-dt) * noise
return x, x_mean

View File

@@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Union
import numpy as np
import torch
from ..utils import BaseOutput
@@ -43,83 +41,3 @@ class SchedulerMixin:
"""
config_name = SCHEDULER_CONFIG_NAME
ignore_for_config = ["tensor_format"]
def set_format(self, tensor_format="pt"):
self.tensor_format = tensor_format
if tensor_format == "pt":
for key, value in vars(self).items():
if isinstance(value, np.ndarray):
setattr(self, key, torch.from_numpy(value))
return self
def clip(self, tensor, min_value=None, max_value=None):
tensor_format = getattr(self, "tensor_format", "pt")
if tensor_format == "np":
return np.clip(tensor, min_value, max_value)
elif tensor_format == "pt":
return torch.clamp(tensor, min_value, max_value)
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
def log(self, tensor):
tensor_format = getattr(self, "tensor_format", "pt")
if tensor_format == "np":
return np.log(tensor)
elif tensor_format == "pt":
return torch.log(tensor)
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
def match_shape(self, values: Union[np.ndarray, torch.Tensor], broadcast_array: Union[np.ndarray, torch.Tensor]):
"""
Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims.
Args:
values: an array or tensor of values to extract.
broadcast_array: an array with a larger shape of K dimensions with the batch
dimension equal to the length of timesteps.
Returns:
a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
tensor_format = getattr(self, "tensor_format", "pt")
values = values.flatten()
while len(values.shape) < len(broadcast_array.shape):
values = values[..., None]
if tensor_format == "pt":
values = values.to(broadcast_array.device)
return values
def norm(self, tensor):
tensor_format = getattr(self, "tensor_format", "pt")
if tensor_format == "np":
return np.linalg.norm(tensor)
elif tensor_format == "pt":
return torch.norm(tensor.reshape(tensor.shape[0], -1), dim=-1).mean()
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
def randn_like(self, tensor, generator=None):
tensor_format = getattr(self, "tensor_format", "pt")
if tensor_format == "np":
return np.random.randn(*np.shape(tensor))
elif tensor_format == "pt":
# return torch.randn_like(tensor)
return torch.randn(tensor.shape, layout=tensor.layout, generator=generator).to(tensor.device)
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
def zeros_like(self, tensor):
tensor_format = getattr(self, "tensor_format", "pt")
if tensor_format == "np":
return np.zeros_like(tensor)
elif tensor_format == "pt":
return torch.zeros_like(tensor)
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")

View File

@@ -191,7 +191,7 @@ class PipelineFastTests(unittest.TestCase):
def test_ddim(self):
unet = self.dummy_uncond_unet
scheduler = DDIMScheduler(tensor_format="pt")
scheduler = DDIMScheduler()
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
ddpm.to(torch_device)
@@ -220,7 +220,7 @@ class PipelineFastTests(unittest.TestCase):
def test_pndm_cifar10(self):
unet = self.dummy_uncond_unet
scheduler = PNDMScheduler(tensor_format="pt")
scheduler = PNDMScheduler()
pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
pndm.to(torch_device)
@@ -242,7 +242,7 @@ class PipelineFastTests(unittest.TestCase):
def test_ldm_text2img(self):
unet = self.dummy_cond_unet
scheduler = DDIMScheduler(tensor_format="pt")
scheduler = DDIMScheduler()
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
@@ -339,7 +339,7 @@ class PipelineFastTests(unittest.TestCase):
def test_stable_diffusion_pndm(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(tensor_format="pt", skip_prk_steps=True)
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
@@ -460,7 +460,7 @@ class PipelineFastTests(unittest.TestCase):
def test_score_sde_ve_pipeline(self):
unet = self.dummy_uncond_unet
scheduler = ScoreSdeVeScheduler(tensor_format="pt")
scheduler = ScoreSdeVeScheduler()
sde_ve = ScoreSdeVePipeline(unet=unet, scheduler=scheduler)
sde_ve.to(torch_device)
@@ -484,7 +484,7 @@ class PipelineFastTests(unittest.TestCase):
def test_ldm_uncond(self):
unet = self.dummy_uncond_unet
scheduler = DDIMScheduler(tensor_format="pt")
scheduler = DDIMScheduler()
vae = self.dummy_vq_model
ldm = LDMPipeline(unet=unet, vqvae=vae, scheduler=scheduler)
@@ -512,7 +512,7 @@ class PipelineFastTests(unittest.TestCase):
def test_karras_ve_pipeline(self):
unet = self.dummy_uncond_unet
scheduler = KarrasVeScheduler(tensor_format="pt")
scheduler = KarrasVeScheduler()
pipe = KarrasVePipeline(unet=unet, scheduler=scheduler)
pipe.to(torch_device)
@@ -535,7 +535,7 @@ class PipelineFastTests(unittest.TestCase):
def test_stable_diffusion_img2img(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(tensor_format="pt", skip_prk_steps=True)
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
@@ -646,7 +646,7 @@ class PipelineFastTests(unittest.TestCase):
def test_stable_diffusion_inpaint(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(tensor_format="pt", skip_prk_steps=True)
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
@@ -842,7 +842,6 @@ class PipelineTesterMixin(unittest.TestCase):
unet = UNet2DModel.from_pretrained(model_id)
scheduler = DDPMScheduler.from_config(model_id)
scheduler = scheduler.set_format("pt")
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
ddpm.to(torch_device)
@@ -882,7 +881,7 @@ class PipelineTesterMixin(unittest.TestCase):
model_id = "google/ddpm-cifar10-32"
unet = UNet2DModel.from_pretrained(model_id)
scheduler = DDIMScheduler(tensor_format="pt")
scheduler = DDIMScheduler()
ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
ddim.to(torch_device)
@@ -902,7 +901,7 @@ class PipelineTesterMixin(unittest.TestCase):
model_id = "google/ddpm-cifar10-32"
unet = UNet2DModel.from_pretrained(model_id)
scheduler = PNDMScheduler(tensor_format="pt")
scheduler = PNDMScheduler()
pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
pndm.to(torch_device)
@@ -1043,8 +1042,8 @@ class PipelineTesterMixin(unittest.TestCase):
model_id = "google/ddpm-cifar10-32"
unet = UNet2DModel.from_pretrained(model_id)
ddpm_scheduler = DDPMScheduler(tensor_format="pt")
ddim_scheduler = DDIMScheduler(tensor_format="pt")
ddpm_scheduler = DDPMScheduler()
ddim_scheduler = DDIMScheduler()
ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
ddpm.to(torch_device)
@@ -1067,8 +1066,8 @@ class PipelineTesterMixin(unittest.TestCase):
model_id = "google/ddpm-cifar10-32"
unet = UNet2DModel.from_pretrained(model_id)
ddpm_scheduler = DDPMScheduler(tensor_format="pt")
ddim_scheduler = DDIMScheduler(tensor_format="pt")
ddpm_scheduler = DDPMScheduler()
ddim_scheduler = DDIMScheduler()
ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
ddpm.to(torch_device)
@@ -1093,7 +1092,7 @@ class PipelineTesterMixin(unittest.TestCase):
def test_karras_ve_pipeline(self):
model_id = "google/ncsnpp-celebahq-256"
model = UNet2DModel.from_pretrained(model_id)
scheduler = KarrasVeScheduler(tensor_format="pt")
scheduler = KarrasVeScheduler()
pipe = KarrasVePipeline(unet=model, scheduler=scheduler)
pipe.to(torch_device)

View File

@@ -173,34 +173,6 @@ class SchedulerCommonTest(unittest.TestCase):
self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)
def test_pytorch_equal_numpy(self):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
sample_pt = self.dummy_sample
residual_pt = 0.1 * sample_pt
sample = sample_pt.numpy()
residual = 0.1 * sample
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(tensor_format="np", **scheduler_config)
scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
scheduler_pt.set_timesteps(num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
output = scheduler.step(residual, 1, sample, **kwargs).prev_sample
output_pt = scheduler_pt.step(residual_pt, 1, sample_pt, **kwargs).prev_sample
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
def test_scheduler_outputs_equivalence(self):
def set_nan_tensor_to_zero(t):
t[t != t] = 0
@@ -266,7 +238,6 @@ class DDPMSchedulerTest(SchedulerCommonTest):
"beta_schedule": "linear",
"variance_type": "fixed_small",
"clip_sample": True,
"tensor_format": "pt",
}
config.update(**kwargs)
@@ -305,10 +276,6 @@ class DDPMSchedulerTest(SchedulerCommonTest):
assert torch.sum(torch.abs(scheduler._get_variance(487) - 0.00979)) < 1e-5
assert torch.sum(torch.abs(scheduler._get_variance(999) - 0.02)) < 1e-5
# TODO Make DDPM Numpy compatible
def test_pytorch_equal_numpy(self):
pass
def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
@@ -387,7 +354,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
scheduler_config = self.get_scheduler_config(steps_offset=1)
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(5)
assert torch.equal(scheduler.timesteps, torch.tensor([801, 601, 401, 201, 1]))
assert np.equal(scheduler.timesteps, np.array([801, 601, 401, 201, 1])).all()
def test_betas(self):
for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
@@ -556,72 +523,6 @@ class PNDMSchedulerTest(SchedulerCommonTest):
return sample
def test_pytorch_equal_numpy(self):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
sample_pt = self.dummy_sample
residual_pt = 0.1 * sample_pt
dummy_past_residuals_pt = [residual_pt + 0.2, residual_pt + 0.15, residual_pt + 0.1, residual_pt + 0.05]
sample = sample_pt.numpy()
residual = 0.1 * sample
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(tensor_format="np", **scheduler_config)
scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
scheduler_pt.set_timesteps(num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
# copy over dummy past residuals (must be done after set_timesteps)
scheduler.ets = dummy_past_residuals[:]
scheduler_pt.ets = dummy_past_residuals_pt[:]
output = scheduler.step_prk(residual, 1, sample, **kwargs).prev_sample
output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, **kwargs).prev_sample
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
output = scheduler.step_plms(residual, 1, sample, **kwargs).prev_sample
output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, **kwargs).prev_sample
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
def test_set_format(self):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(tensor_format="np", **scheduler_config)
scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
scheduler_pt.set_timesteps(num_inference_steps)
for key, value in vars(scheduler).items():
# we only allow `ets` attr to be a list
assert not isinstance(value, list) or key in [
"ets"
], f"Scheduler is not correctly set to np format, the attribute {key} is {type(value)}"
# check if `scheduler.set_format` does convert correctly attrs to pt format
for key, value in vars(scheduler_pt).items():
# we only allow `ets` attr to be a list
assert not isinstance(value, list) or key in [
"ets"
], f"Scheduler is not correctly set to pt format, the attribute {key} is {type(value)}"
assert not isinstance(
value, np.ndarray
), f"Scheduler is not correctly set to pt format, the attribute {key} is {type(value)}"
def test_step_shape(self):
kwargs = dict(self.forward_default_kwargs)
@@ -667,12 +568,10 @@ class PNDMSchedulerTest(SchedulerCommonTest):
scheduler_config = self.get_scheduler_config(steps_offset=1)
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(10)
assert torch.equal(
assert np.equal(
scheduler.timesteps,
torch.tensor(
[901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]
),
)
np.array([901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]),
).all()
def test_betas(self):
for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]):
@@ -786,7 +685,6 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
"sigma_min": 0.01,
"sigma_max": 1348,
"sampling_eps": 1e-5,
"tensor_format": "pt", # TODO add test for tensor formats
}
config.update(**kwargs)
@@ -936,7 +834,6 @@ class LMSDiscreteSchedulerTest(SchedulerCommonTest):
"beta_end": 0.02,
"beta_schedule": "linear",
"trained_betas": None,
"tensor_format": "pt",
}
config.update(**kwargs)
@@ -958,28 +855,6 @@ class LMSDiscreteSchedulerTest(SchedulerCommonTest):
for t in [0, 500, 800]:
self.check_over_forward(time_step=t)
def test_pytorch_equal_numpy(self):
for scheduler_class in self.scheduler_classes:
sample_pt = self.dummy_sample
residual_pt = 0.1 * sample_pt
sample = sample_pt.numpy()
residual = 0.1 * sample
scheduler_config = self.get_scheduler_config()
scheduler_config["tensor_format"] = "np"
scheduler = scheduler_class(**scheduler_config)
scheduler_config["tensor_format"] = "pt"
scheduler_pt = scheduler_class(**scheduler_config)
scheduler.set_timesteps(self.num_inference_steps)
scheduler_pt.set_timesteps(self.num_inference_steps)
output = scheduler.step(residual, 1, sample).prev_sample
output_pt = scheduler_pt.step(residual_pt, 1, sample_pt).prev_sample
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
@@ -1001,5 +876,5 @@ class LMSDiscreteSchedulerTest(SchedulerCommonTest):
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 1006.388) < 1e-2
assert abs(result_sum.item() - 1006.370) < 1e-2
assert abs(result_mean.item() - 1.31) < 1e-3

View File

@@ -41,7 +41,6 @@ class TrainingTests(unittest.TestCase):
beta_end=0.02,
beta_schedule="linear",
clip_sample=True,
tensor_format="pt",
)
ddim_scheduler = DDIMScheduler(
num_train_timesteps=1000,
@@ -49,7 +48,6 @@ class TrainingTests(unittest.TestCase):
beta_end=0.02,
beta_schedule="linear",
clip_sample=True,
tensor_format="pt",
)
assert ddpm_scheduler.config.num_train_timesteps == ddim_scheduler.config.num_train_timesteps