mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Pytorch] Pytorch only schedulers (#534)
* pytorch only schedulers * fix style * remove match_shape * pytorch only ddpm * remove SchedulerMixin * remove numpy from karras_ve * fix types * remove numpy from lms_discrete * remove numpy from pndm * fix typo * remove mixin and numpy from sde_vp and ve * remove remaining tensor_format * fix style * sigmas has to be torch tensor * removed set_format in readme * remove set format from docs * remove set_format from pipelines * update tests * fix typo * continue to use mixin * fix imports * removed unused imports * match shape instead of assuming image shapes * remove import typo * update call to add_noise * use math instead of numpy * fix t_index * removed commented out numpy tests * timesteps needs to be discrete * cast timesteps to int in flax scheduler too * fix device mismatch issue * small fix * Update src/diffusers/schedulers/scheduling_pndm.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -44,8 +44,7 @@ To this end, the design of schedulers is such that:
|
||||
The core API for any new scheduler must follow a limited structure.
|
||||
- Schedulers should provide one or more `def step(...)` functions that should be called to update the generated sample iteratively.
|
||||
- Schedulers should provide a `set_timesteps(...)` method that configures the parameters of a schedule function for a specific inference task.
|
||||
- Schedulers should be framework-agnostic, but provide a simple functionality to convert the scheduler into a specific framework, such as PyTorch
|
||||
with a `set_format(...)` method.
|
||||
- Schedulers should be framework-specific.
|
||||
|
||||
The base class [`SchedulerMixin`] implements low level utilities used by multiple schedulers.
|
||||
|
||||
|
||||
@@ -274,7 +274,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
|
||||
# the model input needs to be scaled to match the continuous ODE formulation in K-LMS
|
||||
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
|
||||
|
||||
# # predict the noise residual
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||
|
||||
# perform classifier free guidance
|
||||
|
||||
@@ -424,7 +424,10 @@ def main():
|
||||
|
||||
# TODO (patil-suraj): load scheduler using args
|
||||
noise_scheduler = DDPMScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt"
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
num_train_timesteps=1000,
|
||||
)
|
||||
|
||||
train_dataset = TextualInversionDataset(
|
||||
|
||||
@@ -59,7 +59,7 @@ def main(args):
|
||||
"UpBlock2D",
|
||||
),
|
||||
)
|
||||
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt")
|
||||
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
|
||||
optimizer = torch.optim.AdamW(
|
||||
model.parameters(),
|
||||
lr=args.learning_rate,
|
||||
|
||||
@@ -35,7 +35,6 @@ class DDIMPipeline(DiffusionPipeline):
|
||||
|
||||
def __init__(self, unet, scheduler):
|
||||
super().__init__()
|
||||
scheduler = scheduler.set_format("pt")
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@@ -35,7 +35,6 @@ class DDPMPipeline(DiffusionPipeline):
|
||||
|
||||
def __init__(self, unet, scheduler):
|
||||
super().__init__()
|
||||
scheduler = scheduler.set_format("pt")
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@@ -45,7 +45,6 @@ class LDMTextToImagePipeline(DiffusionPipeline):
|
||||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
||||
):
|
||||
super().__init__()
|
||||
scheduler = scheduler.set_format("pt")
|
||||
self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@@ -23,7 +23,6 @@ class LDMPipeline(DiffusionPipeline):
|
||||
|
||||
def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler):
|
||||
super().__init__()
|
||||
scheduler = scheduler.set_format("pt")
|
||||
self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@@ -39,7 +39,6 @@ class PNDMPipeline(DiffusionPipeline):
|
||||
|
||||
def __init__(self, unet: UNet2DModel, scheduler: PNDMScheduler):
|
||||
super().__init__()
|
||||
scheduler = scheduler.set_format("pt")
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@@ -57,7 +57,6 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
):
|
||||
super().__init__()
|
||||
scheduler = scheduler.set_format("pt")
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
warnings.warn(
|
||||
|
||||
@@ -69,7 +69,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
):
|
||||
super().__init__()
|
||||
scheduler = scheduler.set_format("pt")
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
warnings.warn(
|
||||
|
||||
@@ -83,7 +83,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
):
|
||||
super().__init__()
|
||||
scheduler = scheduler.set_format("pt")
|
||||
logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
@@ -320,11 +319,11 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
if isinstance(self.scheduler, LMSDiscreteScheduler):
|
||||
latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
|
||||
# masking
|
||||
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor(t_index))
|
||||
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.LongTensor([t_index]))
|
||||
else:
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
# masking
|
||||
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
|
||||
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.LongTensor([t]))
|
||||
|
||||
latents = (init_latents_proper * mask) + (latents * (1 - mask))
|
||||
|
||||
|
||||
@@ -35,7 +35,6 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
):
|
||||
super().__init__()
|
||||
scheduler = scheduler.set_format("np")
|
||||
self.register_modules(
|
||||
vae_decoder=vae_decoder,
|
||||
text_encoder=text_encoder,
|
||||
|
||||
@@ -29,7 +29,6 @@ class KarrasVePipeline(DiffusionPipeline):
|
||||
|
||||
def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler):
|
||||
super().__init__()
|
||||
scheduler = scheduler.set_format("pt")
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@@ -8,8 +8,7 @@
|
||||
|
||||
- Schedulers should provide one or more `def step(...)` functions that should be called iteratively to unroll the diffusion loop during
|
||||
the forward pass.
|
||||
- Schedulers should be framework-agnostic, but provide a simple functionality to convert the scheduler into a specific framework, such as PyTorch
|
||||
with a `set_format(...)` method.
|
||||
- Schedulers should be framework-specific.
|
||||
|
||||
## Examples
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ class DDIMSchedulerOutput(BaseOutput):
|
||||
pred_original_sample: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
||||
(1-beta) over time from t = [0,1].
|
||||
@@ -72,7 +72,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
return np.array(betas, dtype=np.float32)
|
||||
return torch.tensor(betas)
|
||||
|
||||
|
||||
class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
@@ -106,7 +106,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
an offset added to the inference steps. You can use a combination of `offset=1` and
|
||||
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
|
||||
stable diffusion.
|
||||
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
|
||||
|
||||
"""
|
||||
|
||||
@@ -121,15 +120,16 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
clip_sample: bool = True,
|
||||
set_alpha_to_one: bool = True,
|
||||
steps_offset: int = 0,
|
||||
tensor_format: str = "pt",
|
||||
):
|
||||
if trained_betas is not None:
|
||||
self.betas = np.asarray(trained_betas)
|
||||
self.betas = torch.from_numpy(trained_betas)
|
||||
if beta_schedule == "linear":
|
||||
self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
# this schedule is very specific to the latent diffusion model.
|
||||
self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
|
||||
self.betas = (
|
||||
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
||||
)
|
||||
elif beta_schedule == "squaredcos_cap_v2":
|
||||
# Glide cosine schedule
|
||||
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
||||
@@ -137,20 +137,17 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
||||
|
||||
self.alphas = 1.0 - self.betas
|
||||
self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
|
||||
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
||||
|
||||
# At every step in ddim, we are looking into the previous alphas_cumprod
|
||||
# For the final step, there is no previous alphas_cumprod because we are already at 0
|
||||
# `set_alpha_to_one` decides whether we set this parameter simply to one or
|
||||
# whether we use the final alpha of the "non-previous" one.
|
||||
self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
|
||||
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
|
||||
|
||||
# settable values
|
||||
self.num_inference_steps = None
|
||||
self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
|
||||
|
||||
self.tensor_format = tensor_format
|
||||
self.set_format(tensor_format=tensor_format)
|
||||
self.timesteps = np.arange(0, num_train_timesteps)[::-1]
|
||||
|
||||
def _get_variance(self, timestep, prev_timestep):
|
||||
alpha_prod_t = self.alphas_cumprod[timestep]
|
||||
@@ -186,15 +183,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
|
||||
# creates integer timesteps by multiplying by ratio
|
||||
# casting to int to avoid issues when num_inference_steps is a power of 3
|
||||
self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
|
||||
self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1]
|
||||
self.timesteps += offset
|
||||
self.set_format(tensor_format=self.tensor_format)
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: Union[torch.FloatTensor, np.ndarray],
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: int,
|
||||
sample: Union[torch.FloatTensor, np.ndarray],
|
||||
sample: torch.FloatTensor,
|
||||
eta: float = 0.0,
|
||||
use_clipped_model_output: bool = False,
|
||||
generator=None,
|
||||
@@ -205,9 +201,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
process from the learned model outputs (most often the predicted noise).
|
||||
|
||||
Args:
|
||||
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
|
||||
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
|
||||
timestep (`int`): current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor` or `np.ndarray`):
|
||||
sample (`torch.FloatTensor`):
|
||||
current instance of sample being created by diffusion process.
|
||||
eta (`float`): weight of noise for added noise in diffusion step.
|
||||
use_clipped_model_output (`bool`): TODO
|
||||
@@ -251,7 +247,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# 4. Clip "predicted x_0"
|
||||
if self.config.clip_sample:
|
||||
pred_original_sample = self.clip(pred_original_sample, -1, 1)
|
||||
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
|
||||
|
||||
# 5. compute variance: "sigma_t(η)" -> see formula (16)
|
||||
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
|
||||
@@ -273,9 +269,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
noise = torch.randn(model_output.shape, generator=generator).to(device)
|
||||
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
|
||||
|
||||
if not torch.is_tensor(model_output):
|
||||
variance = variance.numpy()
|
||||
|
||||
prev_sample = prev_sample + variance
|
||||
|
||||
if not return_dict:
|
||||
@@ -285,16 +278,20 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: Union[torch.FloatTensor, np.ndarray],
|
||||
noise: Union[torch.FloatTensor, np.ndarray],
|
||||
timesteps: Union[torch.IntTensor, np.ndarray],
|
||||
) -> Union[torch.FloatTensor, np.ndarray]:
|
||||
if self.tensor_format == "pt":
|
||||
timesteps = timesteps.to(self.alphas_cumprod.device)
|
||||
original_samples: torch.FloatTensor,
|
||||
noise: torch.FloatTensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.FloatTensor:
|
||||
timesteps = timesteps.to(self.alphas_cumprod.device)
|
||||
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
||||
|
||||
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
|
||||
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
||||
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_samples
|
||||
|
||||
@@ -70,7 +70,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
return np.array(betas, dtype=np.float32)
|
||||
return torch.tensor(betas, dtype=torch.float32)
|
||||
|
||||
|
||||
class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
@@ -99,7 +99,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
|
||||
clip_sample (`bool`, default `True`):
|
||||
option to clip predicted sample between -1 and 1 for numerical stability.
|
||||
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
|
||||
|
||||
"""
|
||||
|
||||
@@ -113,15 +112,16 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
trained_betas: Optional[np.ndarray] = None,
|
||||
variance_type: str = "fixed_small",
|
||||
clip_sample: bool = True,
|
||||
tensor_format: str = "pt",
|
||||
):
|
||||
if trained_betas is not None:
|
||||
self.betas = np.asarray(trained_betas)
|
||||
self.betas = torch.from_numpy(trained_betas)
|
||||
elif beta_schedule == "linear":
|
||||
self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
# this schedule is very specific to the latent diffusion model.
|
||||
self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
|
||||
self.betas = (
|
||||
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
||||
)
|
||||
elif beta_schedule == "squaredcos_cap_v2":
|
||||
# Glide cosine schedule
|
||||
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
||||
@@ -129,15 +129,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
||||
|
||||
self.alphas = 1.0 - self.betas
|
||||
self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
|
||||
self.one = np.array(1.0)
|
||||
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
||||
self.one = torch.tensor(1.0)
|
||||
|
||||
# settable values
|
||||
self.num_inference_steps = None
|
||||
self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
|
||||
|
||||
self.tensor_format = tensor_format
|
||||
self.set_format(tensor_format=tensor_format)
|
||||
self.timesteps = np.arange(0, num_train_timesteps)[::-1]
|
||||
|
||||
self.variance_type = variance_type
|
||||
|
||||
@@ -153,8 +150,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.num_inference_steps = num_inference_steps
|
||||
self.timesteps = np.arange(
|
||||
0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
|
||||
)[::-1].copy()
|
||||
self.set_format(tensor_format=self.tensor_format)
|
||||
)[::-1]
|
||||
|
||||
def _get_variance(self, t, predicted_variance=None, variance_type=None):
|
||||
alpha_prod_t = self.alphas_cumprod[t]
|
||||
@@ -170,15 +166,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# hacks - were probably added for training stability
|
||||
if variance_type == "fixed_small":
|
||||
variance = self.clip(variance, min_value=1e-20)
|
||||
variance = torch.clamp(variance, min=1e-20)
|
||||
# for rl-diffuser https://arxiv.org/abs/2205.09991
|
||||
elif variance_type == "fixed_small_log":
|
||||
variance = self.log(self.clip(variance, min_value=1e-20))
|
||||
variance = torch.log(torch.clamp(variance, min=1e-20))
|
||||
elif variance_type == "fixed_large":
|
||||
variance = self.betas[t]
|
||||
elif variance_type == "fixed_large_log":
|
||||
# Glide max_log
|
||||
variance = self.log(self.betas[t])
|
||||
variance = torch.log(self.betas[t])
|
||||
elif variance_type == "learned":
|
||||
return predicted_variance
|
||||
elif variance_type == "learned_range":
|
||||
@@ -191,9 +187,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: Union[torch.FloatTensor, np.ndarray],
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: int,
|
||||
sample: Union[torch.FloatTensor, np.ndarray],
|
||||
sample: torch.FloatTensor,
|
||||
predict_epsilon=True,
|
||||
generator=None,
|
||||
return_dict: bool = True,
|
||||
@@ -203,9 +199,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
process from the learned model outputs (most often the predicted noise).
|
||||
|
||||
Args:
|
||||
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
|
||||
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
|
||||
timestep (`int`): current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor` or `np.ndarray`):
|
||||
sample (`torch.FloatTensor`):
|
||||
current instance of sample being created by diffusion process.
|
||||
predict_epsilon (`bool`):
|
||||
optional flag to use when model predicts the samples directly instead of the noise, epsilon.
|
||||
@@ -240,7 +236,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# 3. Clip "predicted x_0"
|
||||
if self.config.clip_sample:
|
||||
pred_original_sample = self.clip(pred_original_sample, -1, 1)
|
||||
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
|
||||
|
||||
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
|
||||
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
|
||||
@@ -254,7 +250,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
# 6. Add noise
|
||||
variance = 0
|
||||
if t > 0:
|
||||
noise = self.randn_like(model_output, generator=generator)
|
||||
noise = torch.randn(
|
||||
model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator
|
||||
).to(model_output.device)
|
||||
variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise
|
||||
|
||||
pred_prev_sample = pred_prev_sample + variance
|
||||
@@ -266,16 +264,21 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: Union[torch.FloatTensor, np.ndarray],
|
||||
noise: Union[torch.FloatTensor, np.ndarray],
|
||||
timesteps: Union[torch.IntTensor, np.ndarray],
|
||||
) -> Union[torch.FloatTensor, np.ndarray]:
|
||||
if self.tensor_format == "pt":
|
||||
timesteps = timesteps.to(self.alphas_cumprod.device)
|
||||
original_samples: torch.FloatTensor,
|
||||
noise: torch.FloatTensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.FloatTensor:
|
||||
timesteps = timesteps.to(self.alphas_cumprod.device)
|
||||
|
||||
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
||||
|
||||
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
|
||||
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
|
||||
sqrt_one_minus_alpha_prod.flatten()
|
||||
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
||||
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_samples
|
||||
|
||||
@@ -74,7 +74,6 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
A reasonable range is [0, 10].
|
||||
s_max (`float`): the end value of the sigma range where we add noise.
|
||||
A reasonable range is [0.2, 80].
|
||||
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
|
||||
|
||||
"""
|
||||
|
||||
@@ -87,15 +86,11 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
s_churn: float = 80,
|
||||
s_min: float = 0.05,
|
||||
s_max: float = 50,
|
||||
tensor_format: str = "pt",
|
||||
):
|
||||
# settable values
|
||||
self.num_inference_steps = None
|
||||
self.timesteps = None
|
||||
self.schedule = None # sigma(t_i)
|
||||
|
||||
self.tensor_format = tensor_format
|
||||
self.set_format(tensor_format=tensor_format)
|
||||
self.num_inference_steps: int = None
|
||||
self.timesteps: np.ndarray = None
|
||||
self.schedule: torch.FloatTensor = None # sigma(t_i)
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int):
|
||||
"""
|
||||
@@ -108,20 +103,18 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
self.num_inference_steps = num_inference_steps
|
||||
self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
|
||||
self.schedule = [
|
||||
schedule = [
|
||||
(
|
||||
self.config.sigma_max**2
|
||||
* (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1))
|
||||
)
|
||||
for i in self.timesteps
|
||||
]
|
||||
self.schedule = np.array(self.schedule, dtype=np.float32)
|
||||
|
||||
self.set_format(tensor_format=self.tensor_format)
|
||||
self.schedule = torch.tensor(schedule, dtype=torch.float32)
|
||||
|
||||
def add_noise_to_input(
|
||||
self, sample: Union[torch.FloatTensor, np.ndarray], sigma: float, generator: Optional[torch.Generator] = None
|
||||
) -> Tuple[Union[torch.FloatTensor, np.ndarray], float]:
|
||||
self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None
|
||||
) -> Tuple[torch.FloatTensor, float]:
|
||||
"""
|
||||
Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
|
||||
higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
|
||||
@@ -142,10 +135,10 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: Union[torch.FloatTensor, np.ndarray],
|
||||
model_output: torch.FloatTensor,
|
||||
sigma_hat: float,
|
||||
sigma_prev: float,
|
||||
sample_hat: Union[torch.FloatTensor, np.ndarray],
|
||||
sample_hat: torch.FloatTensor,
|
||||
return_dict: bool = True,
|
||||
) -> Union[KarrasVeOutput, Tuple]:
|
||||
"""
|
||||
@@ -153,10 +146,10 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
process from the learned model outputs (most often the predicted noise).
|
||||
|
||||
Args:
|
||||
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
|
||||
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
|
||||
sigma_hat (`float`): TODO
|
||||
sigma_prev (`float`): TODO
|
||||
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
|
||||
sample_hat (`torch.FloatTensor`): TODO
|
||||
return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class
|
||||
|
||||
KarrasVeOutput: updated sample in the diffusion chain and derivative (TODO double check).
|
||||
@@ -180,24 +173,24 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
def step_correct(
|
||||
self,
|
||||
model_output: Union[torch.FloatTensor, np.ndarray],
|
||||
model_output: torch.FloatTensor,
|
||||
sigma_hat: float,
|
||||
sigma_prev: float,
|
||||
sample_hat: Union[torch.FloatTensor, np.ndarray],
|
||||
sample_prev: Union[torch.FloatTensor, np.ndarray],
|
||||
derivative: Union[torch.FloatTensor, np.ndarray],
|
||||
sample_hat: torch.FloatTensor,
|
||||
sample_prev: torch.FloatTensor,
|
||||
derivative: torch.FloatTensor,
|
||||
return_dict: bool = True,
|
||||
) -> Union[KarrasVeOutput, Tuple]:
|
||||
"""
|
||||
Correct the predicted sample based on the output model_output of the network. TODO complete description
|
||||
|
||||
Args:
|
||||
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
|
||||
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
|
||||
sigma_hat (`float`): TODO
|
||||
sigma_prev (`float`): TODO
|
||||
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
|
||||
sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO
|
||||
derivative (`torch.FloatTensor` or `np.ndarray`): TODO
|
||||
sample_hat (`torch.FloatTensor`): TODO
|
||||
sample_prev (`torch.FloatTensor`): TODO
|
||||
derivative (`torch.FloatTensor`): TODO
|
||||
return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -63,7 +63,6 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
`linear` or `scaled_linear`.
|
||||
trained_betas (`np.ndarray`, optional):
|
||||
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
|
||||
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
|
||||
|
||||
"""
|
||||
|
||||
@@ -75,31 +74,29 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[np.ndarray] = None,
|
||||
tensor_format: str = "pt",
|
||||
):
|
||||
if trained_betas is not None:
|
||||
self.betas = np.asarray(trained_betas)
|
||||
self.betas = torch.from_numpy(trained_betas)
|
||||
if beta_schedule == "linear":
|
||||
self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
# this schedule is very specific to the latent diffusion model.
|
||||
self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
|
||||
self.betas = (
|
||||
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
||||
|
||||
self.alphas = 1.0 - self.betas
|
||||
self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
|
||||
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
||||
|
||||
self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
|
||||
|
||||
# settable values
|
||||
self.num_inference_steps = None
|
||||
self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
|
||||
self.timesteps = np.arange(0, num_train_timesteps)[::-1] # to be consistent has to be smaller than sigmas by 1
|
||||
self.derivatives = []
|
||||
|
||||
self.tensor_format = tensor_format
|
||||
self.set_format(tensor_format=tensor_format)
|
||||
|
||||
def get_lms_coefficient(self, order, t, current_order):
|
||||
"""
|
||||
Compute a linear multistep coefficient.
|
||||
@@ -131,24 +128,24 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
the number of diffusion steps used when generating samples with a pre-trained model.
|
||||
"""
|
||||
self.num_inference_steps = num_inference_steps
|
||||
self.timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)
|
||||
timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)
|
||||
|
||||
low_idx = np.floor(self.timesteps).astype(int)
|
||||
high_idx = np.ceil(self.timesteps).astype(int)
|
||||
frac = np.mod(self.timesteps, 1.0)
|
||||
low_idx = np.floor(timesteps).astype(int)
|
||||
high_idx = np.ceil(timesteps).astype(int)
|
||||
frac = np.mod(timesteps, 1.0)
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
|
||||
self.sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
|
||||
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
|
||||
self.sigmas = torch.from_numpy(sigmas)
|
||||
|
||||
self.timesteps = timesteps.astype(int)
|
||||
self.derivatives = []
|
||||
|
||||
self.set_format(tensor_format=self.tensor_format)
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: Union[torch.FloatTensor, np.ndarray],
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: int,
|
||||
sample: Union[torch.FloatTensor, np.ndarray],
|
||||
sample: torch.FloatTensor,
|
||||
order: int = 4,
|
||||
return_dict: bool = True,
|
||||
) -> Union[LMSDiscreteSchedulerOutput, Tuple]:
|
||||
@@ -157,9 +154,9 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
process from the learned model outputs (most often the predicted noise).
|
||||
|
||||
Args:
|
||||
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
|
||||
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
|
||||
timestep (`int`): current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor` or `np.ndarray`):
|
||||
sample (`torch.FloatTensor`):
|
||||
current instance of sample being created by diffusion process.
|
||||
order: coefficient for multi-step inference.
|
||||
return_dict (`bool`): option for returning tuple rather than LMSDiscreteSchedulerOutput class
|
||||
@@ -197,15 +194,18 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: Union[torch.FloatTensor, np.ndarray],
|
||||
noise: Union[torch.FloatTensor, np.ndarray],
|
||||
timesteps: Union[torch.IntTensor, np.ndarray],
|
||||
) -> Union[torch.FloatTensor, np.ndarray]:
|
||||
if self.tensor_format == "pt":
|
||||
timesteps = timesteps.to(self.sigmas.device)
|
||||
sigmas = self.match_shape(self.sigmas[timesteps], noise)
|
||||
noisy_samples = original_samples + noise * sigmas
|
||||
original_samples: torch.FloatTensor,
|
||||
noise: torch.FloatTensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.FloatTensor:
|
||||
sigmas = self.sigmas.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
sigma = sigmas[timesteps].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
|
||||
noisy_samples = original_samples + noise * sigma
|
||||
return noisy_samples
|
||||
|
||||
def __len__(self):
|
||||
|
||||
@@ -132,7 +132,7 @@ class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return state.replace(
|
||||
num_inference_steps=num_inference_steps,
|
||||
timesteps=timesteps,
|
||||
timesteps=timesteps.astype(int),
|
||||
derivatives=jnp.array([]),
|
||||
sigmas=sigmas,
|
||||
)
|
||||
|
||||
@@ -51,7 +51,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
return np.array(betas, dtype=np.float32)
|
||||
return torch.tensor(betas, dtype=torch.float32)
|
||||
|
||||
|
||||
class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
@@ -86,7 +86,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
an offset added to the inference steps. You can use a combination of `offset=1` and
|
||||
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
|
||||
stable diffusion.
|
||||
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays
|
||||
|
||||
"""
|
||||
|
||||
@@ -101,15 +100,16 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
skip_prk_steps: bool = False,
|
||||
set_alpha_to_one: bool = False,
|
||||
steps_offset: int = 0,
|
||||
tensor_format: str = "pt",
|
||||
):
|
||||
if trained_betas is not None:
|
||||
self.betas = np.asarray(trained_betas)
|
||||
self.betas = torch.from_numpy(trained_betas)
|
||||
if beta_schedule == "linear":
|
||||
self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
# this schedule is very specific to the latent diffusion model.
|
||||
self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
|
||||
self.betas = (
|
||||
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
||||
)
|
||||
elif beta_schedule == "squaredcos_cap_v2":
|
||||
# Glide cosine schedule
|
||||
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
||||
@@ -117,9 +117,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
||||
|
||||
self.alphas = 1.0 - self.betas
|
||||
self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
|
||||
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
||||
|
||||
self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
|
||||
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
|
||||
|
||||
# For now we only support F-PNDM, i.e. the runge-kutta method
|
||||
# For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
|
||||
@@ -139,9 +139,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.plms_timesteps = None
|
||||
self.timesteps = None
|
||||
|
||||
self.tensor_format = tensor_format
|
||||
self.set_format(tensor_format=tensor_format)
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor:
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
@@ -189,13 +186,12 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self.ets = []
|
||||
self.counter = 0
|
||||
self.set_format(tensor_format=self.tensor_format)
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: Union[torch.FloatTensor, np.ndarray],
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: int,
|
||||
sample: Union[torch.FloatTensor, np.ndarray],
|
||||
sample: torch.FloatTensor,
|
||||
return_dict: bool = True,
|
||||
) -> Union[SchedulerOutput, Tuple]:
|
||||
"""
|
||||
@@ -205,9 +201,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`.
|
||||
|
||||
Args:
|
||||
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
|
||||
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
|
||||
timestep (`int`): current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor` or `np.ndarray`):
|
||||
sample (`torch.FloatTensor`):
|
||||
current instance of sample being created by diffusion process.
|
||||
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
||||
|
||||
@@ -224,9 +220,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
def step_prk(
|
||||
self,
|
||||
model_output: Union[torch.FloatTensor, np.ndarray],
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: int,
|
||||
sample: Union[torch.FloatTensor, np.ndarray],
|
||||
sample: torch.FloatTensor,
|
||||
return_dict: bool = True,
|
||||
) -> Union[SchedulerOutput, Tuple]:
|
||||
"""
|
||||
@@ -234,9 +230,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
solution to the differential equation.
|
||||
|
||||
Args:
|
||||
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
|
||||
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
|
||||
timestep (`int`): current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor` or `np.ndarray`):
|
||||
sample (`torch.FloatTensor`):
|
||||
current instance of sample being created by diffusion process.
|
||||
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
||||
|
||||
@@ -279,9 +275,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
def step_plms(
|
||||
self,
|
||||
model_output: Union[torch.FloatTensor, np.ndarray],
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: int,
|
||||
sample: Union[torch.FloatTensor, np.ndarray],
|
||||
sample: torch.FloatTensor,
|
||||
return_dict: bool = True,
|
||||
) -> Union[SchedulerOutput, Tuple]:
|
||||
"""
|
||||
@@ -289,9 +285,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
times to approximate the solution.
|
||||
|
||||
Args:
|
||||
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
|
||||
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
|
||||
timestep (`int`): current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor` or `np.ndarray`):
|
||||
sample (`torch.FloatTensor`):
|
||||
current instance of sample being created by diffusion process.
|
||||
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
||||
|
||||
@@ -381,16 +377,27 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: Union[torch.FloatTensor, np.ndarray],
|
||||
noise: Union[torch.FloatTensor, np.ndarray],
|
||||
timesteps: Union[torch.IntTensor, np.ndarray],
|
||||
original_samples: torch.FloatTensor,
|
||||
noise: torch.FloatTensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.Tensor:
|
||||
if self.tensor_format == "pt":
|
||||
timesteps = timesteps.to(self.alphas_cumprod.device)
|
||||
if self.alphas_cumprod.device != original_samples.device:
|
||||
self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)
|
||||
|
||||
if timesteps.device != original_samples.device:
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
timesteps = timesteps.to(self.alphas_cumprod.device)
|
||||
|
||||
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
||||
|
||||
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
|
||||
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
||||
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_samples
|
||||
|
||||
@@ -14,11 +14,11 @@
|
||||
|
||||
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
@@ -65,7 +65,6 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
sampling_eps (`float`): the end value of sampling, where timesteps decrease progressively from 1 to
|
||||
epsilon.
|
||||
correct_steps (`int`): number of correction steps performed on a produced sample.
|
||||
tensor_format (`str`): "np" or "pt" for the expected format of samples passed to the Scheduler.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
@@ -77,16 +76,12 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigma_max: float = 1348.0,
|
||||
sampling_eps: float = 1e-5,
|
||||
correct_steps: int = 1,
|
||||
tensor_format: str = "pt",
|
||||
):
|
||||
# setable values
|
||||
self.timesteps = None
|
||||
|
||||
self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)
|
||||
|
||||
self.tensor_format = tensor_format
|
||||
self.set_format(tensor_format=tensor_format)
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, sampling_eps: float = None):
|
||||
"""
|
||||
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
@@ -98,13 +93,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
"""
|
||||
sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
|
||||
tensor_format = getattr(self, "tensor_format", "pt")
|
||||
if tensor_format == "np":
|
||||
self.timesteps = np.linspace(1, sampling_eps, num_inference_steps)
|
||||
elif tensor_format == "pt":
|
||||
self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps)
|
||||
else:
|
||||
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
|
||||
|
||||
self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps)
|
||||
|
||||
def set_sigmas(
|
||||
self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None, sampling_eps: float = None
|
||||
@@ -129,28 +119,16 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
if self.timesteps is None:
|
||||
self.set_timesteps(num_inference_steps, sampling_eps)
|
||||
|
||||
tensor_format = getattr(self, "tensor_format", "pt")
|
||||
if tensor_format == "np":
|
||||
self.discrete_sigmas = np.exp(np.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps))
|
||||
self.sigmas = np.array([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])
|
||||
elif tensor_format == "pt":
|
||||
self.discrete_sigmas = torch.exp(torch.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps))
|
||||
self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])
|
||||
else:
|
||||
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
|
||||
self.sigmas = sigma_min * (sigma_max / sigma_min) ** (self.timesteps / sampling_eps)
|
||||
self.discrete_sigmas = torch.exp(torch.linspace(math.log(sigma_min), math.log(sigma_max), num_inference_steps))
|
||||
self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])
|
||||
|
||||
def get_adjacent_sigma(self, timesteps, t):
|
||||
tensor_format = getattr(self, "tensor_format", "pt")
|
||||
if tensor_format == "np":
|
||||
return np.where(timesteps == 0, np.zeros_like(t), self.discrete_sigmas[timesteps - 1])
|
||||
elif tensor_format == "pt":
|
||||
return torch.where(
|
||||
timesteps == 0,
|
||||
torch.zeros_like(t.to(timesteps.device)),
|
||||
self.discrete_sigmas[timesteps - 1].to(timesteps.device),
|
||||
)
|
||||
|
||||
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
|
||||
return torch.where(
|
||||
timesteps == 0,
|
||||
torch.zeros_like(t.to(timesteps.device)),
|
||||
self.discrete_sigmas[timesteps - 1].to(timesteps.device),
|
||||
)
|
||||
|
||||
def set_seed(self, seed):
|
||||
warnings.warn(
|
||||
@@ -158,19 +136,13 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
" generator instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
tensor_format = getattr(self, "tensor_format", "pt")
|
||||
if tensor_format == "np":
|
||||
np.random.seed(seed)
|
||||
elif tensor_format == "pt":
|
||||
torch.manual_seed(seed)
|
||||
else:
|
||||
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
|
||||
torch.manual_seed(seed)
|
||||
|
||||
def step_pred(
|
||||
self,
|
||||
model_output: Union[torch.FloatTensor, np.ndarray],
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: int,
|
||||
sample: Union[torch.FloatTensor, np.ndarray],
|
||||
sample: torch.FloatTensor,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
return_dict: bool = True,
|
||||
**kwargs,
|
||||
@@ -180,9 +152,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
process from the learned model outputs (most often the predicted noise).
|
||||
|
||||
Args:
|
||||
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
|
||||
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
|
||||
timestep (`int`): current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor` or `np.ndarray`):
|
||||
sample (`torch.FloatTensor`):
|
||||
current instance of sample being created by diffusion process.
|
||||
generator: random number generator.
|
||||
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
||||
@@ -210,18 +182,21 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
sigma = self.discrete_sigmas[timesteps].to(sample.device)
|
||||
adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device)
|
||||
drift = self.zeros_like(sample)
|
||||
drift = torch.zeros_like(sample)
|
||||
diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5
|
||||
|
||||
# equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
|
||||
# also equation 47 shows the analog from SDE models to ancestral sampling methods
|
||||
drift = drift - diffusion[:, None, None, None] ** 2 * model_output
|
||||
diffusion = diffusion.flatten()
|
||||
while len(diffusion.shape) < len(sample.shape):
|
||||
diffusion = diffusion.unsqueeze(-1)
|
||||
drift = drift - diffusion**2 * model_output
|
||||
|
||||
# equation 6: sample noise for the diffusion term of
|
||||
noise = self.randn_like(sample, generator=generator)
|
||||
noise = torch.randn(sample.shape, layout=sample.layout, generator=generator).to(sample.device)
|
||||
prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep
|
||||
# TODO is the variable diffusion the correct scaling term for the noise?
|
||||
prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g
|
||||
prev_sample = prev_sample_mean + diffusion * noise # add impact of diffusion field g
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample, prev_sample_mean)
|
||||
@@ -230,8 +205,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
def step_correct(
|
||||
self,
|
||||
model_output: Union[torch.FloatTensor, np.ndarray],
|
||||
sample: Union[torch.FloatTensor, np.ndarray],
|
||||
model_output: torch.FloatTensor,
|
||||
sample: torch.FloatTensor,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
return_dict: bool = True,
|
||||
**kwargs,
|
||||
@@ -241,8 +216,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
after making the prediction for the previous timestep.
|
||||
|
||||
Args:
|
||||
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
|
||||
sample (`torch.FloatTensor` or `np.ndarray`):
|
||||
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
|
||||
sample (`torch.FloatTensor`):
|
||||
current instance of sample being created by diffusion process.
|
||||
generator: random number generator.
|
||||
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
||||
@@ -262,18 +237,21 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
|
||||
# sample noise for correction
|
||||
noise = self.randn_like(sample, generator=generator)
|
||||
noise = torch.randn(sample.shape, layout=sample.layout, generator=generator).to(sample.device)
|
||||
|
||||
# compute step size from the model_output, the noise, and the snr
|
||||
grad_norm = self.norm(model_output)
|
||||
noise_norm = self.norm(noise)
|
||||
grad_norm = torch.norm(model_output.reshape(model_output.shape[0], -1), dim=-1).mean()
|
||||
noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
|
||||
step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
|
||||
step_size = step_size * torch.ones(sample.shape[0]).to(sample.device)
|
||||
# self.repeat_scalar(step_size, sample.shape[0])
|
||||
|
||||
# compute corrected sample: model_output term and noise term
|
||||
prev_sample_mean = sample + step_size[:, None, None, None] * model_output
|
||||
prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise
|
||||
step_size = step_size.flatten()
|
||||
while len(step_size.shape) < len(sample.shape):
|
||||
step_size = step_size.unsqueeze(-1)
|
||||
prev_sample_mean = sample + step_size * model_output
|
||||
prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
@@ -16,7 +16,8 @@
|
||||
|
||||
# TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean-up a bit
|
||||
|
||||
import numpy as np
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
@@ -39,7 +40,7 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"):
|
||||
def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3):
|
||||
self.sigmas = None
|
||||
self.discrete_sigmas = None
|
||||
self.timesteps = None
|
||||
@@ -47,7 +48,7 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
|
||||
def set_timesteps(self, num_inference_steps):
|
||||
self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)
|
||||
|
||||
def step_pred(self, score, x, t):
|
||||
def step_pred(self, score, x, t, generator=None):
|
||||
if self.timesteps is None:
|
||||
raise ValueError(
|
||||
"`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
|
||||
@@ -59,20 +60,27 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
|
||||
-0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min
|
||||
)
|
||||
std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
|
||||
score = -score / std[:, None, None, None]
|
||||
std = std.flatten()
|
||||
while len(std.shape) < len(score.shape):
|
||||
std = std.unsqueeze(-1)
|
||||
score = -score / std
|
||||
|
||||
# compute
|
||||
dt = -1.0 / len(self.timesteps)
|
||||
|
||||
beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min)
|
||||
drift = -0.5 * beta_t[:, None, None, None] * x
|
||||
beta_t = beta_t.flatten()
|
||||
while len(beta_t.shape) < len(x.shape):
|
||||
beta_t = beta_t.unsqueeze(-1)
|
||||
drift = -0.5 * beta_t * x
|
||||
|
||||
diffusion = torch.sqrt(beta_t)
|
||||
drift = drift - diffusion[:, None, None, None] ** 2 * score
|
||||
drift = drift - diffusion**2 * score
|
||||
x_mean = x + drift * dt
|
||||
|
||||
# add noise
|
||||
noise = torch.randn_like(x)
|
||||
x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * noise
|
||||
noise = torch.randn(x.shape, layout=x.layout, generator=generator).to(x.device)
|
||||
x = x_mean + diffusion * math.sqrt(-dt) * noise
|
||||
|
||||
return x, x_mean
|
||||
|
||||
|
||||
@@ -12,9 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..utils import BaseOutput
|
||||
@@ -43,83 +41,3 @@ class SchedulerMixin:
|
||||
"""
|
||||
|
||||
config_name = SCHEDULER_CONFIG_NAME
|
||||
ignore_for_config = ["tensor_format"]
|
||||
|
||||
def set_format(self, tensor_format="pt"):
|
||||
self.tensor_format = tensor_format
|
||||
if tensor_format == "pt":
|
||||
for key, value in vars(self).items():
|
||||
if isinstance(value, np.ndarray):
|
||||
setattr(self, key, torch.from_numpy(value))
|
||||
|
||||
return self
|
||||
|
||||
def clip(self, tensor, min_value=None, max_value=None):
|
||||
tensor_format = getattr(self, "tensor_format", "pt")
|
||||
|
||||
if tensor_format == "np":
|
||||
return np.clip(tensor, min_value, max_value)
|
||||
elif tensor_format == "pt":
|
||||
return torch.clamp(tensor, min_value, max_value)
|
||||
|
||||
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
|
||||
|
||||
def log(self, tensor):
|
||||
tensor_format = getattr(self, "tensor_format", "pt")
|
||||
|
||||
if tensor_format == "np":
|
||||
return np.log(tensor)
|
||||
elif tensor_format == "pt":
|
||||
return torch.log(tensor)
|
||||
|
||||
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
|
||||
|
||||
def match_shape(self, values: Union[np.ndarray, torch.Tensor], broadcast_array: Union[np.ndarray, torch.Tensor]):
|
||||
"""
|
||||
Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims.
|
||||
|
||||
Args:
|
||||
values: an array or tensor of values to extract.
|
||||
broadcast_array: an array with a larger shape of K dimensions with the batch
|
||||
dimension equal to the length of timesteps.
|
||||
Returns:
|
||||
a tensor of shape [batch_size, 1, ...] where the shape has K dims.
|
||||
"""
|
||||
|
||||
tensor_format = getattr(self, "tensor_format", "pt")
|
||||
values = values.flatten()
|
||||
|
||||
while len(values.shape) < len(broadcast_array.shape):
|
||||
values = values[..., None]
|
||||
if tensor_format == "pt":
|
||||
values = values.to(broadcast_array.device)
|
||||
|
||||
return values
|
||||
|
||||
def norm(self, tensor):
|
||||
tensor_format = getattr(self, "tensor_format", "pt")
|
||||
if tensor_format == "np":
|
||||
return np.linalg.norm(tensor)
|
||||
elif tensor_format == "pt":
|
||||
return torch.norm(tensor.reshape(tensor.shape[0], -1), dim=-1).mean()
|
||||
|
||||
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
|
||||
|
||||
def randn_like(self, tensor, generator=None):
|
||||
tensor_format = getattr(self, "tensor_format", "pt")
|
||||
if tensor_format == "np":
|
||||
return np.random.randn(*np.shape(tensor))
|
||||
elif tensor_format == "pt":
|
||||
# return torch.randn_like(tensor)
|
||||
return torch.randn(tensor.shape, layout=tensor.layout, generator=generator).to(tensor.device)
|
||||
|
||||
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
|
||||
|
||||
def zeros_like(self, tensor):
|
||||
tensor_format = getattr(self, "tensor_format", "pt")
|
||||
if tensor_format == "np":
|
||||
return np.zeros_like(tensor)
|
||||
elif tensor_format == "pt":
|
||||
return torch.zeros_like(tensor)
|
||||
|
||||
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
|
||||
|
||||
@@ -191,7 +191,7 @@ class PipelineFastTests(unittest.TestCase):
|
||||
|
||||
def test_ddim(self):
|
||||
unet = self.dummy_uncond_unet
|
||||
scheduler = DDIMScheduler(tensor_format="pt")
|
||||
scheduler = DDIMScheduler()
|
||||
|
||||
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
|
||||
ddpm.to(torch_device)
|
||||
@@ -220,7 +220,7 @@ class PipelineFastTests(unittest.TestCase):
|
||||
|
||||
def test_pndm_cifar10(self):
|
||||
unet = self.dummy_uncond_unet
|
||||
scheduler = PNDMScheduler(tensor_format="pt")
|
||||
scheduler = PNDMScheduler()
|
||||
|
||||
pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
|
||||
pndm.to(torch_device)
|
||||
@@ -242,7 +242,7 @@ class PipelineFastTests(unittest.TestCase):
|
||||
|
||||
def test_ldm_text2img(self):
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = DDIMScheduler(tensor_format="pt")
|
||||
scheduler = DDIMScheduler()
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
@@ -339,7 +339,7 @@ class PipelineFastTests(unittest.TestCase):
|
||||
def test_stable_diffusion_pndm(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = PNDMScheduler(tensor_format="pt", skip_prk_steps=True)
|
||||
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
@@ -460,7 +460,7 @@ class PipelineFastTests(unittest.TestCase):
|
||||
|
||||
def test_score_sde_ve_pipeline(self):
|
||||
unet = self.dummy_uncond_unet
|
||||
scheduler = ScoreSdeVeScheduler(tensor_format="pt")
|
||||
scheduler = ScoreSdeVeScheduler()
|
||||
|
||||
sde_ve = ScoreSdeVePipeline(unet=unet, scheduler=scheduler)
|
||||
sde_ve.to(torch_device)
|
||||
@@ -484,7 +484,7 @@ class PipelineFastTests(unittest.TestCase):
|
||||
|
||||
def test_ldm_uncond(self):
|
||||
unet = self.dummy_uncond_unet
|
||||
scheduler = DDIMScheduler(tensor_format="pt")
|
||||
scheduler = DDIMScheduler()
|
||||
vae = self.dummy_vq_model
|
||||
|
||||
ldm = LDMPipeline(unet=unet, vqvae=vae, scheduler=scheduler)
|
||||
@@ -512,7 +512,7 @@ class PipelineFastTests(unittest.TestCase):
|
||||
|
||||
def test_karras_ve_pipeline(self):
|
||||
unet = self.dummy_uncond_unet
|
||||
scheduler = KarrasVeScheduler(tensor_format="pt")
|
||||
scheduler = KarrasVeScheduler()
|
||||
|
||||
pipe = KarrasVePipeline(unet=unet, scheduler=scheduler)
|
||||
pipe.to(torch_device)
|
||||
@@ -535,7 +535,7 @@ class PipelineFastTests(unittest.TestCase):
|
||||
def test_stable_diffusion_img2img(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = PNDMScheduler(tensor_format="pt", skip_prk_steps=True)
|
||||
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
@@ -646,7 +646,7 @@ class PipelineFastTests(unittest.TestCase):
|
||||
def test_stable_diffusion_inpaint(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = PNDMScheduler(tensor_format="pt", skip_prk_steps=True)
|
||||
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
@@ -842,7 +842,6 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||
|
||||
unet = UNet2DModel.from_pretrained(model_id)
|
||||
scheduler = DDPMScheduler.from_config(model_id)
|
||||
scheduler = scheduler.set_format("pt")
|
||||
|
||||
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
|
||||
ddpm.to(torch_device)
|
||||
@@ -882,7 +881,7 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||
model_id = "google/ddpm-cifar10-32"
|
||||
|
||||
unet = UNet2DModel.from_pretrained(model_id)
|
||||
scheduler = DDIMScheduler(tensor_format="pt")
|
||||
scheduler = DDIMScheduler()
|
||||
|
||||
ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
|
||||
ddim.to(torch_device)
|
||||
@@ -902,7 +901,7 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||
model_id = "google/ddpm-cifar10-32"
|
||||
|
||||
unet = UNet2DModel.from_pretrained(model_id)
|
||||
scheduler = PNDMScheduler(tensor_format="pt")
|
||||
scheduler = PNDMScheduler()
|
||||
|
||||
pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
|
||||
pndm.to(torch_device)
|
||||
@@ -1043,8 +1042,8 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||
model_id = "google/ddpm-cifar10-32"
|
||||
|
||||
unet = UNet2DModel.from_pretrained(model_id)
|
||||
ddpm_scheduler = DDPMScheduler(tensor_format="pt")
|
||||
ddim_scheduler = DDIMScheduler(tensor_format="pt")
|
||||
ddpm_scheduler = DDPMScheduler()
|
||||
ddim_scheduler = DDIMScheduler()
|
||||
|
||||
ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
|
||||
ddpm.to(torch_device)
|
||||
@@ -1067,8 +1066,8 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||
model_id = "google/ddpm-cifar10-32"
|
||||
|
||||
unet = UNet2DModel.from_pretrained(model_id)
|
||||
ddpm_scheduler = DDPMScheduler(tensor_format="pt")
|
||||
ddim_scheduler = DDIMScheduler(tensor_format="pt")
|
||||
ddpm_scheduler = DDPMScheduler()
|
||||
ddim_scheduler = DDIMScheduler()
|
||||
|
||||
ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
|
||||
ddpm.to(torch_device)
|
||||
@@ -1093,7 +1092,7 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||
def test_karras_ve_pipeline(self):
|
||||
model_id = "google/ncsnpp-celebahq-256"
|
||||
model = UNet2DModel.from_pretrained(model_id)
|
||||
scheduler = KarrasVeScheduler(tensor_format="pt")
|
||||
scheduler = KarrasVeScheduler()
|
||||
|
||||
pipe = KarrasVePipeline(unet=model, scheduler=scheduler)
|
||||
pipe.to(torch_device)
|
||||
|
||||
@@ -173,34 +173,6 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
self.assertEqual(output_0.shape, sample.shape)
|
||||
self.assertEqual(output_0.shape, output_1.shape)
|
||||
|
||||
def test_pytorch_equal_numpy(self):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
sample_pt = self.dummy_sample
|
||||
residual_pt = 0.1 * sample_pt
|
||||
|
||||
sample = sample_pt.numpy()
|
||||
residual = 0.1 * sample
|
||||
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(tensor_format="np", **scheduler_config)
|
||||
|
||||
scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
|
||||
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
scheduler_pt.set_timesteps(num_inference_steps)
|
||||
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
|
||||
kwargs["num_inference_steps"] = num_inference_steps
|
||||
|
||||
output = scheduler.step(residual, 1, sample, **kwargs).prev_sample
|
||||
output_pt = scheduler_pt.step(residual_pt, 1, sample_pt, **kwargs).prev_sample
|
||||
|
||||
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
|
||||
|
||||
def test_scheduler_outputs_equivalence(self):
|
||||
def set_nan_tensor_to_zero(t):
|
||||
t[t != t] = 0
|
||||
@@ -266,7 +238,6 @@ class DDPMSchedulerTest(SchedulerCommonTest):
|
||||
"beta_schedule": "linear",
|
||||
"variance_type": "fixed_small",
|
||||
"clip_sample": True,
|
||||
"tensor_format": "pt",
|
||||
}
|
||||
|
||||
config.update(**kwargs)
|
||||
@@ -305,10 +276,6 @@ class DDPMSchedulerTest(SchedulerCommonTest):
|
||||
assert torch.sum(torch.abs(scheduler._get_variance(487) - 0.00979)) < 1e-5
|
||||
assert torch.sum(torch.abs(scheduler._get_variance(999) - 0.02)) < 1e-5
|
||||
|
||||
# TODO Make DDPM Numpy compatible
|
||||
def test_pytorch_equal_numpy(self):
|
||||
pass
|
||||
|
||||
def test_full_loop_no_noise(self):
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
@@ -387,7 +354,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
|
||||
scheduler_config = self.get_scheduler_config(steps_offset=1)
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
scheduler.set_timesteps(5)
|
||||
assert torch.equal(scheduler.timesteps, torch.tensor([801, 601, 401, 201, 1]))
|
||||
assert np.equal(scheduler.timesteps, np.array([801, 601, 401, 201, 1])).all()
|
||||
|
||||
def test_betas(self):
|
||||
for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
|
||||
@@ -556,72 +523,6 @@ class PNDMSchedulerTest(SchedulerCommonTest):
|
||||
|
||||
return sample
|
||||
|
||||
def test_pytorch_equal_numpy(self):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
sample_pt = self.dummy_sample
|
||||
residual_pt = 0.1 * sample_pt
|
||||
dummy_past_residuals_pt = [residual_pt + 0.2, residual_pt + 0.15, residual_pt + 0.1, residual_pt + 0.05]
|
||||
|
||||
sample = sample_pt.numpy()
|
||||
residual = 0.1 * sample
|
||||
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
|
||||
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(tensor_format="np", **scheduler_config)
|
||||
|
||||
scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
|
||||
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
scheduler_pt.set_timesteps(num_inference_steps)
|
||||
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
|
||||
kwargs["num_inference_steps"] = num_inference_steps
|
||||
|
||||
# copy over dummy past residuals (must be done after set_timesteps)
|
||||
scheduler.ets = dummy_past_residuals[:]
|
||||
scheduler_pt.ets = dummy_past_residuals_pt[:]
|
||||
|
||||
output = scheduler.step_prk(residual, 1, sample, **kwargs).prev_sample
|
||||
output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, **kwargs).prev_sample
|
||||
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
|
||||
|
||||
output = scheduler.step_plms(residual, 1, sample, **kwargs).prev_sample
|
||||
output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, **kwargs).prev_sample
|
||||
|
||||
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
|
||||
|
||||
def test_set_format(self):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(tensor_format="np", **scheduler_config)
|
||||
scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
|
||||
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
scheduler_pt.set_timesteps(num_inference_steps)
|
||||
|
||||
for key, value in vars(scheduler).items():
|
||||
# we only allow `ets` attr to be a list
|
||||
assert not isinstance(value, list) or key in [
|
||||
"ets"
|
||||
], f"Scheduler is not correctly set to np format, the attribute {key} is {type(value)}"
|
||||
|
||||
# check if `scheduler.set_format` does convert correctly attrs to pt format
|
||||
for key, value in vars(scheduler_pt).items():
|
||||
# we only allow `ets` attr to be a list
|
||||
assert not isinstance(value, list) or key in [
|
||||
"ets"
|
||||
], f"Scheduler is not correctly set to pt format, the attribute {key} is {type(value)}"
|
||||
assert not isinstance(
|
||||
value, np.ndarray
|
||||
), f"Scheduler is not correctly set to pt format, the attribute {key} is {type(value)}"
|
||||
|
||||
def test_step_shape(self):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
|
||||
@@ -667,12 +568,10 @@ class PNDMSchedulerTest(SchedulerCommonTest):
|
||||
scheduler_config = self.get_scheduler_config(steps_offset=1)
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
scheduler.set_timesteps(10)
|
||||
assert torch.equal(
|
||||
assert np.equal(
|
||||
scheduler.timesteps,
|
||||
torch.tensor(
|
||||
[901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]
|
||||
),
|
||||
)
|
||||
np.array([901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]),
|
||||
).all()
|
||||
|
||||
def test_betas(self):
|
||||
for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]):
|
||||
@@ -786,7 +685,6 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
|
||||
"sigma_min": 0.01,
|
||||
"sigma_max": 1348,
|
||||
"sampling_eps": 1e-5,
|
||||
"tensor_format": "pt", # TODO add test for tensor formats
|
||||
}
|
||||
|
||||
config.update(**kwargs)
|
||||
@@ -936,7 +834,6 @@ class LMSDiscreteSchedulerTest(SchedulerCommonTest):
|
||||
"beta_end": 0.02,
|
||||
"beta_schedule": "linear",
|
||||
"trained_betas": None,
|
||||
"tensor_format": "pt",
|
||||
}
|
||||
|
||||
config.update(**kwargs)
|
||||
@@ -958,28 +855,6 @@ class LMSDiscreteSchedulerTest(SchedulerCommonTest):
|
||||
for t in [0, 500, 800]:
|
||||
self.check_over_forward(time_step=t)
|
||||
|
||||
def test_pytorch_equal_numpy(self):
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
sample_pt = self.dummy_sample
|
||||
residual_pt = 0.1 * sample_pt
|
||||
|
||||
sample = sample_pt.numpy()
|
||||
residual = 0.1 * sample
|
||||
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler_config["tensor_format"] = "np"
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
scheduler_config["tensor_format"] = "pt"
|
||||
scheduler_pt = scheduler_class(**scheduler_config)
|
||||
|
||||
scheduler.set_timesteps(self.num_inference_steps)
|
||||
scheduler_pt.set_timesteps(self.num_inference_steps)
|
||||
|
||||
output = scheduler.step(residual, 1, sample).prev_sample
|
||||
output_pt = scheduler_pt.step(residual_pt, 1, sample_pt).prev_sample
|
||||
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
|
||||
|
||||
def test_full_loop_no_noise(self):
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
@@ -1001,5 +876,5 @@ class LMSDiscreteSchedulerTest(SchedulerCommonTest):
|
||||
result_sum = torch.sum(torch.abs(sample))
|
||||
result_mean = torch.mean(torch.abs(sample))
|
||||
|
||||
assert abs(result_sum.item() - 1006.388) < 1e-2
|
||||
assert abs(result_sum.item() - 1006.370) < 1e-2
|
||||
assert abs(result_mean.item() - 1.31) < 1e-3
|
||||
|
||||
@@ -41,7 +41,6 @@ class TrainingTests(unittest.TestCase):
|
||||
beta_end=0.02,
|
||||
beta_schedule="linear",
|
||||
clip_sample=True,
|
||||
tensor_format="pt",
|
||||
)
|
||||
ddim_scheduler = DDIMScheduler(
|
||||
num_train_timesteps=1000,
|
||||
@@ -49,7 +48,6 @@ class TrainingTests(unittest.TestCase):
|
||||
beta_end=0.02,
|
||||
beta_schedule="linear",
|
||||
clip_sample=True,
|
||||
tensor_format="pt",
|
||||
)
|
||||
|
||||
assert ddpm_scheduler.config.num_train_timesteps == ddim_scheduler.config.num_train_timesteps
|
||||
|
||||
Reference in New Issue
Block a user