mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
fix tests
This commit is contained in:
@@ -129,6 +129,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
an offset added to the inference steps. You can use a combination of `offset=1` and
|
||||
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
|
||||
stable diffusion.
|
||||
prediction_type (`Literal["epsilon", "sample", "v"]`, optional):
|
||||
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
|
||||
process), `sample` (directly predicting the noisy sample`) or `v` (see section 2.4
|
||||
https://imagen.research.google/video/paper.pdf)
|
||||
|
||||
"""
|
||||
|
||||
@@ -181,8 +185,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
# For the final step, there is no previous alphas_cumprod because we are already at 0
|
||||
# `set_alpha_to_one` decides whether we set this parameter simply to one or
|
||||
# whether we use the final alpha of the "non-previous" one.
|
||||
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
|
||||
self.final_sigma = torch.tensor(0.0) if set_alpha_to_one else self.sigmas[0]
|
||||
if set_alpha_to_one:
|
||||
self.final_alpha_cumprod = torch.tensor(1.0)
|
||||
self.final_sigma = torch.tensor(0.0) # TODO rename set_alpha_to_one for something general with sigma=0
|
||||
else:
|
||||
self.final_alpha_cumprod = self.alphas_cumprod[0]
|
||||
self.final_sigma = self.sigmas[0] if prediction_type == "v" else None
|
||||
|
||||
# standard deviation of the initial noise distribution
|
||||
self.init_noise_sigma = 1.0
|
||||
|
||||
@@ -114,6 +114,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
|
||||
process), `sample` (directly predicting the noisy sample`) or `v` (see section 2.4
|
||||
https://imagen.research.google/video/paper.pdf)
|
||||
predict_epsilon (`bool`, default `True`):
|
||||
depreciated flag (removing v0.10.0) for epsilon vs. direct sample prediction.
|
||||
"""
|
||||
|
||||
_compatible_classes = [
|
||||
@@ -136,6 +138,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
variance_type: str = "fixed_small",
|
||||
clip_sample: bool = True,
|
||||
prediction_type: Literal["epsilon", "sample", "v"] = "epsilon",
|
||||
predict_epsilon: bool = True,
|
||||
):
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.from_numpy(trained_betas)
|
||||
@@ -265,8 +268,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
if self.variance_type == "v_diffusion":
|
||||
assert self.prediction_type == "v", "Need to use v prediction with v_diffusion"
|
||||
message = (
|
||||
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
|
||||
" DDPMScheduler.from_config(<model_id>, predict_epsilon=True)`."
|
||||
"Please make sure to instantiate your scheduler with `prediction_type=epsilon` instead. E.g. `scheduler ="
|
||||
" DDPMScheduler.from_config(<model_id>, prediction_type=epsilon)`."
|
||||
)
|
||||
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
|
||||
if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:
|
||||
@@ -293,11 +296,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample * self.sqrt_alphas_cumprod[timestep]
|
||||
- model_output * self.sqrt_one_minus_alphas_cumprod[timestep]
|
||||
)
|
||||
|
||||
# not check on predict_epsilon for depreciation flag above
|
||||
elif self.prediction_type == "sample" or not self.config.predict_epsilon:
|
||||
pred_original_sample = model_output
|
||||
|
||||
elif self.prediction_type == "epsilon" or self.config.predict_epsilon:
|
||||
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
||||
|
||||
elif self.prediction_type == "sample":
|
||||
pred_original_sample = model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or `v`"
|
||||
|
||||
@@ -393,9 +393,9 @@ class DDPMSchedulerTest(SchedulerCommonTest):
|
||||
for clip_sample in [True, False]:
|
||||
self.check_over_configs(clip_sample=clip_sample)
|
||||
|
||||
def test_predict_epsilon(self):
|
||||
for predict_epsilon in [True, False]:
|
||||
self.check_over_configs(predict_epsilon=predict_epsilon)
|
||||
def test_prediction_type(self):
|
||||
for prediction_type in ["epsilon", "sample", "v"]:
|
||||
self.check_over_configs(prediction_type=prediction_type)
|
||||
|
||||
def test_deprecated_epsilon(self):
|
||||
deprecate("remove this test", "0.10.0", "remove")
|
||||
@@ -407,7 +407,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
|
||||
time_step = 4
|
||||
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
scheduler_eps = scheduler_class(predict_epsilon=False, **scheduler_config)
|
||||
scheduler_eps = scheduler_class(prediction_type="sample", **scheduler_config)
|
||||
|
||||
kwargs = {}
|
||||
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||
|
||||
Reference in New Issue
Block a user