mirror of https://github.com/huggingface/diffusers.git

remove _yiyi_sigma_to_t

yiyixuxu committed 2023-06-23 02:54:41 +00:00
parent 68ef317a01
commit 6ec68eec40
2 changed files with 23 additions and 14 deletions

src/diffusers/pipelines/shap_e/pipeline_shap_e.py

@@ -441,9 +441,9 @@ class ShapEPipeline(DiffusionPipeline):
                 sample=latents,
                 step_index=i,
             ).prev_sample

-        # YiYi testing only: I don't think we need to return latent for this pipeline
-        if output_type == 'latent':
+        if output_type == "latent":
             return ShapEPipelineOutput(images=latents)

         # project the parameters from the generated latents
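Note on the branch above: when output_type == "latent" the pipeline returns the raw latents and skips the projection step that follows. A hypothetical usage sketch (the "openai/shap-e" checkpoint id and call arguments are assumptions, not part of this diff):

import torch
from diffusers import ShapEPipeline

# Assumed checkpoint id; output_type="latent" skips the projection/rendering step.
pipe = ShapEPipeline.from_pretrained("openai/shap-e", torch_dtype=torch.float16).to("cuda")
out = pipe("a shark", num_inference_steps=64, output_type="latent")
latents = out.images  # raw latents, not rendered images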

src/diffusers/schedulers/scheduling_heun_discrete.py

@@ -25,7 +25,7 @@ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, Schedul
 def betas_for_alpha_bar(
     num_diffusion_timesteps,
     max_beta=0.999,
-    alpha_transform_type="cosine", # cosine, exp
+    alpha_transform_type="cosine",  # cosine, exp
 ) -> torch.Tensor:
     """
     Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
@@ -44,11 +44,17 @@ def betas_for_alpha_bar(
         betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
     """
     if alpha_transform_type == "cosine":
-        alpha_bar_fn = lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
-    elif alpha_transform_type == 'exp':
-        alpha_bar_fn = lambda t: math.exp(t * -12.0)
+
+        def alpha_bar_fn(t):
+            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+    elif alpha_transform_type == "exp":
+
+        def alpha_bar_fn(t):
+            return math.exp(t * -12.0)
+
     else:
-        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_tranform_type}")
+        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")

     betas = []
     for i in range(num_diffusion_timesteps):
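For reference, the loop that follows converts alpha_bar into per-step betas. A minimal standalone sketch of the whole function's logic (the loop body is not shown in this hunk and is reconstructed from the surrounding definitions):

import math

def betas_for_alpha_bar_sketch(num_diffusion_timesteps, max_beta=0.999, alpha_transform_type="cosine"):
    if alpha_transform_type == "cosine":
        alpha_bar_fn = lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
    elif alpha_transform_type == "exp":
        alpha_bar_fn = lambda t: math.exp(t * -12.0)
    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # beta_t = 1 - alpha_bar(t2) / alpha_bar(t1), clipped at max_beta
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return betas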
@@ -111,9 +117,9 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
             )
         elif beta_schedule == "squaredcos_cap_v2":
             # Glide cosine schedule
-            self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type='cosine')
+            self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="cosine")
         elif beta_schedule == "exp":
-            self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type='exp')
+            self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="exp")
         else:
             raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
@@ -181,12 +187,13 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
             use_karras_sigmas = self.use_karras_sigmas

         if sigma_min is not None and sigma_max is not None:
             if use_karras_sigmas is not None:
                 sigmas = torch.tensor([sigma_max, sigma_min])
                 log_sigmas = None
             else:
-                raise ValueError(f"`sigma_min` and `sigma_max` arguments are only supported when `use_karras_sigma` is not None")
+                raise ValueError(
+                    "`sigma_min` and `sigma_max` arguments are only supported when `use_karras_sigma` is not None"
+                )
         else:
             timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
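The hunk above adds an early-exit path: when explicit sigma_min/sigma_max endpoints are given (and Karras sigmas are enabled), only the [sigma_max, sigma_min] pair is needed and no log-sigma table is built; otherwise timesteps come from an evenly spaced, descending linspace over the training range. A sketch of that control flow, mirroring the hunk (the function name is hypothetical):

import numpy as np
import torch

def pick_schedule_inputs(sigma_min, sigma_max, use_karras_sigmas, num_train_timesteps, num_inference_steps):
    if sigma_min is not None and sigma_max is not None:
        if use_karras_sigmas is not None:
            # Endpoints alone are enough to build the Karras sigmas later.
            return torch.tensor([sigma_max, sigma_min]), None
        raise ValueError(
            "`sigma_min` and `sigma_max` arguments are only supported when `use_karras_sigma` is not None"
        )
    # Default path: descending, evenly spaced training timesteps.
    timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
    return None, timesteps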
@@ -220,7 +227,6 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self.dt = None

     def _sigma_to_t(self, sigma, log_sigmas):
-
         # perform interpolation on sigmas if log_sigmas is not None
         if log_sigmas is not None:
             # get log sigma
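The body of _sigma_to_t is cut off above; when log_sigmas is available, schedulers in this family interpolate in log space. A sketch of that standard lookup (reconstructed from the usual k-diffusion-style convention, not from this diff):

import numpy as np

def sigma_to_t_sketch(sigma, log_sigmas):
    sigma = np.asarray(sigma)
    log_sigma = np.log(sigma)
    # Distance of log_sigma to every table entry.
    dists = log_sigma - log_sigmas[:, np.newaxis]
    # Indices of the two table entries bracketing log_sigma.
    low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
    high_idx = low_idx + 1
    low, high = log_sigmas[low_idx], log_sigmas[high_idx]
    # Linear interpolation weight between the bracketing entries.
    w = np.clip((low - log_sigma) / (low - high), 0, 1)
    return ((1 - w) * low_idx + w * high_idx).reshape(sigma.shape)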
@@ -255,12 +261,15 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
             t = len(self.alphas_cumprod) - 1
         else:
-            t = np.interp(alpha_cumprod, self.alphas_cumprod.numpy()[::-1].copy(), np.arange(0, len(self.alphas_cumprod))[::-1])
+            t = np.interp(
+                alpha_cumprod,
+                self.alphas_cumprod.numpy()[::-1].copy(),
+                np.arange(0, len(self.alphas_cumprod))[::-1],
+            )
             t = int(t)

         return t
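When log_sigmas is None (the explicit sigma_min/sigma_max path above), sigma is instead mapped through alphas_cumprod. Since sigma = sqrt((1 - a) / a) for cumulative product a, we have a = 1 / (sigma**2 + 1), and because alphas_cumprod decreases with t, both arrays are reversed so np.interp sees increasing x-coordinates. A standalone sketch of this mapping (the function name and the exact clamp condition are assumptions):

import numpy as np

def sigma_to_t_via_alphas_sketch(sigma, alphas_cumprod):
    alpha_cumprod = 1.0 / (sigma**2 + 1.0)  # invert sigma = sqrt((1 - a) / a)
    if alpha_cumprod <= alphas_cumprod[-1]:
        # Assumed clamp: noisier than the last training step maps to the final timestep.
        return len(alphas_cumprod) - 1
    t = np.interp(
        alpha_cumprod,
        alphas_cumprod[::-1].copy(),  # increasing x-coordinates, as np.interp requires
        np.arange(0, len(alphas_cumprod))[::-1],
    )
    return int(t)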
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
     def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
         """Constructs the noise schedule of Karras et al. (2022)."""