fix tests

This commit is contained in:
Nathan Lambert
2022-11-17 14:43:14 -08:00
parent 11362ae5d2
commit e39198306b
3 changed files with 24 additions and 10 deletions

@@ -129,6 +129,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         an offset added to the inference steps. You can use a combination of `offset=1` and
         `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
         stable diffusion.
+        prediction_type (`Literal["epsilon", "sample", "v"]`, optional):
+            prediction type of the scheduler function: one of `epsilon` (predicting the noise of the diffusion
+            process), `sample` (directly predicting the noisy sample) or `v` (see section 2.4 of
+            https://imagen.research.google/video/paper.pdf)
     """
@@ -181,8 +185,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         # For the final step, there is no previous alphas_cumprod because we are already at 0
         # `set_alpha_to_one` decides whether we set this parameter simply to one or
         # whether we use the final alpha of the "non-previous" one.
-        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
-        self.final_sigma = torch.tensor(0.0) if set_alpha_to_one else self.sigmas[0]
+        if set_alpha_to_one:
+            self.final_alpha_cumprod = torch.tensor(1.0)
+            self.final_sigma = torch.tensor(0.0)  # TODO: rename set_alpha_to_one to something more general now that it also forces sigma=0
+        else:
+            self.final_alpha_cumprod = self.alphas_cumprod[0]
+            self.final_sigma = self.sigmas[0] if prediction_type == "v" else None

         # standard deviation of the initial noise distribution
         self.init_noise_sigma = 1.0
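
A hypothetical instantiation showing how the option plays out (argument values are illustrative, assuming the constructor accepts the `prediction_type` kwarg documented above):

from diffusers import DDIMScheduler

# set_alpha_to_one=True pins final_alpha_cumprod to 1.0 and final_sigma to 0.0;
# with set_alpha_to_one=False the scheduler falls back to alphas_cumprod[0]
# (and, for v prediction, sigmas[0]).
scheduler = DDIMScheduler(set_alpha_to_one=False, prediction_type="v")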

@@ -114,6 +114,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             prediction type of the scheduler function: one of `epsilon` (predicting the noise of the diffusion
             process), `sample` (directly predicting the noisy sample) or `v` (see section 2.4 of
             https://imagen.research.google/video/paper.pdf)
+        predict_epsilon (`bool`, default `True`):
+            deprecated flag (to be removed in v0.10.0) selecting epsilon vs. direct sample prediction.
     """

     _compatible_classes = [
@@ -136,6 +138,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         variance_type: str = "fixed_small",
         clip_sample: bool = True,
         prediction_type: Literal["epsilon", "sample", "v"] = "epsilon",
+        predict_epsilon: bool = True,
     ):
         if trained_betas is not None:
             self.betas = torch.from_numpy(trained_betas)
@@ -265,8 +268,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         if self.variance_type == "v_diffusion":
             assert self.prediction_type == "v", "Need to use v prediction with v_diffusion"
         message = (
-            "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
-            " DDPMScheduler.from_config(<model_id>, predict_epsilon=True)`."
+            "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
+            " DDPMScheduler.from_config(<model_id>, prediction_type='epsilon')`."
         )
         predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
         if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:
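
To illustrate the migration path (a sketch; `<model_id>` above stands for any pretrained checkpoint name):

from diffusers import DDPMScheduler

# Deprecated style, removed in v0.10.0; emits the warning built from `message` above:
# scheduler = DDPMScheduler(predict_epsilon=False)

# Replacement: predict_epsilon=False corresponds to prediction_type="sample",
# as the test change at the end of this commit also reflects.
scheduler = DDPMScheduler(prediction_type="sample")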
@@ -293,11 +296,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
                 sample * self.sqrt_alphas_cumprod[timestep]
                 - model_output * self.sqrt_one_minus_alphas_cumprod[timestep]
             )
-        elif self.prediction_type == "epsilon":
+        # no check on predict_epsilon here; the deprecation flag is handled above
+        elif self.prediction_type == "sample" or not self.config.predict_epsilon:
+            pred_original_sample = model_output
+        elif self.prediction_type == "epsilon" or self.config.predict_epsilon:
             pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
-        elif self.prediction_type == "sample":
-            pred_original_sample = model_output
         else:
             raise ValueError(
                 f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or `v`"
             )

@@ -393,9 +393,9 @@ class DDPMSchedulerTest(SchedulerCommonTest):
         for clip_sample in [True, False]:
             self.check_over_configs(clip_sample=clip_sample)

-    def test_predict_epsilon(self):
-        for predict_epsilon in [True, False]:
-            self.check_over_configs(predict_epsilon=predict_epsilon)
+    def test_prediction_type(self):
+        for prediction_type in ["epsilon", "sample", "v"]:
+            self.check_over_configs(prediction_type=prediction_type)

     def test_deprecated_epsilon(self):
         deprecate("remove this test", "0.10.0", "remove")
@@ -407,7 +407,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
         time_step = 4

         scheduler = scheduler_class(**scheduler_config)
-        scheduler_eps = scheduler_class(predict_epsilon=False, **scheduler_config)
+        scheduler_eps = scheduler_class(prediction_type="sample", **scheduler_config)

         kwargs = {}
         if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):