support v prediction in other schedulers (#1505)
* support v prediction in other schedulers
* v heun
* add tests for v pred
* fix tests
* fix test euler a
* v ddpm
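Note: the diff below adds a `prediction_type` config field to EulerAncestralDiscreteScheduler, HeunDiscreteScheduler, LMSDiscreteScheduler, and PNDMScheduler, and wires `"v_prediction"` through DDPMScheduler's `step()` as well. When set, the scheduler interprets the model output as the v-target of Salimans & Ho (2022), "Progressive Distillation for Fast Sampling of Diffusion Models", instead of predicted noise. A minimal usage sketch (illustrative, not part of the commit):

import torch
from diffusers import DDPMScheduler, EulerAncestralDiscreteScheduler

# "epsilon" stays the default; "v_prediction" switches how model output is interpreted.
ddpm = DDPMScheduler(prediction_type="v_prediction")
euler_a = EulerAncestralDiscreteScheduler(prediction_type="v_prediction")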
@@ -280,10 +280,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
         elif self.config.prediction_type == "sample":
             pred_original_sample = model_output
+        elif self.config.prediction_type == "v_prediction":
+            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
         else:
             raise ValueError(
-                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` "
-                " for the DDPMScheduler."
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
+                " `v_prediction` for the DDPMScheduler."
             )

         # 3. Clip "predicted x_0"
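Note: the new `v_prediction` branch in DDPMScheduler is the v-parameterization identity from Salimans & Ho (2022): with x_t = sqrt(ᾱ_t)·x_0 + sqrt(1−ᾱ_t)·ε and v = sqrt(ᾱ_t)·ε − sqrt(1−ᾱ_t)·x_0, one gets x_0 = sqrt(ᾱ_t)·x_t − sqrt(1−ᾱ_t)·v, which is exactly the added line. A self-contained numerical check (a sketch; the tensor values are arbitrary):

import torch

alpha_prod_t = torch.tensor(0.7)  # example cumulative alpha product
x0 = torch.randn(4)               # clean sample
eps = torch.randn(4)              # Gaussian noise
x_t = alpha_prod_t**0.5 * x0 + (1 - alpha_prod_t) ** 0.5 * eps  # forward process
v = alpha_prod_t**0.5 * eps - (1 - alpha_prod_t) ** 0.5 * x0    # v-target

# Recover x0 the way the scheduler does and compare.
x0_rec = alpha_prod_t**0.5 * x_t - (1 - alpha_prod_t) ** 0.5 * v
assert torch.allclose(x0_rec, x0, atol=1e-5)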
@@ -78,6 +78,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+        prediction_type: str = "epsilon",
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -202,7 +203,16 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         sigma = self.sigmas[step_index]

         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
-        pred_original_sample = sample - sigma * model_output
+        if self.config.prediction_type == "epsilon":
+            pred_original_sample = sample - sigma * model_output
+        elif self.config.prediction_type == "v_prediction":
+            # * c_out + input * c_skip
+            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
+        else:
+            raise ValueError(
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
+            )
+
         sigma_from = self.sigmas[step_index]
         sigma_to = self.sigmas[step_index + 1]
         sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
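Note: in the sigma-parameterized (k-diffusion style) schedulers, the same identity shows up as the c_skip/c_out coefficients hinted at by the `# * c_out + input * c_skip` comment. With sigma = sqrt((1−ᾱ)/ᾱ) and the sample rescaled by 1/sqrt(ᾱ), x_0 = sqrt(ᾱ)·x_t − sqrt(1−ᾱ)·v becomes model_output·(−sigma/sqrt(sigma²+1)) + sample/(sigma²+1). A quick check of the equivalence (a sketch; variable names are illustrative), which also covers the identical branches added to HeunDiscreteScheduler and LMSDiscreteScheduler below:

import torch

alpha_prod = torch.tensor(0.7)
sigma = ((1 - alpha_prod) / alpha_prod) ** 0.5  # k-diffusion sigma for this alpha
x_t = torch.randn(4)                            # DDPM-space noisy sample
v = torch.randn(4)                              # stand-in for the model's v output
sample = x_t / alpha_prod**0.5                  # sigma-space sample used by these schedulers

ddpm_form = alpha_prod**0.5 * x_t - (1 - alpha_prod) ** 0.5 * v
sigma_form = v * (-sigma / (sigma**2 + 1) ** 0.5) + sample / (sigma**2 + 1)
assert torch.allclose(ddpm_form, sigma_form, atol=1e-5)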
@@ -54,6 +54,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         beta_end: float = 0.012,
         beta_schedule: str = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+        prediction_type: str = "epsilon",
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -184,7 +185,15 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         sigma_hat = sigma * (gamma + 1)  # Note: sigma_hat == sigma for now

         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
-        pred_original_sample = sample - sigma_hat * model_output
+        if self.config.prediction_type == "epsilon":
+            pred_original_sample = sample - sigma_hat * model_output
+        elif self.config.prediction_type == "v_prediction":
+            # * c_out + input * c_skip
+            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
+        else:
+            raise ValueError(
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
+            )

         if self.state_in_first_order:
             # 2. Convert to an ODE derivative
@@ -78,6 +78,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+        prediction_type: str = "epsilon",
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -215,7 +216,15 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         sigma = self.sigmas[step_index]

         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
-        pred_original_sample = sample - sigma * model_output
+        if self.config.prediction_type == "epsilon":
+            pred_original_sample = sample - sigma * model_output
+        elif self.config.prediction_type == "v_prediction":
+            # * c_out + input * c_skip
+            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
+        else:
+            raise ValueError(
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
+            )

         # 2. Convert to an ODE derivative
         derivative = (sample - pred_original_sample) / sigma
@@ -102,6 +102,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         skip_prk_steps: bool = False,
         set_alpha_to_one: bool = False,
+        prediction_type: str = "epsilon",
         steps_offset: int = 0,
     ):
         if trained_betas is not None:
@@ -368,6 +369,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev

+        if self.config.prediction_type == "v_prediction":
+            model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+        elif self.config.prediction_type != "epsilon":
+            raise ValueError(
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `v_prediction`"
+            )
+
         # corresponds to (α_(t−δ) - α_t) divided by
         # denominator of x_t in formula (9) and plus 1
         # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqrt(α_t))) =
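Note: PNDM keeps its multistep update written in terms of the noise residual, so instead of computing x_0 the v output is converted to ε: combining x_t = sqrt(ᾱ_t)·x_0 + sqrt(1−ᾱ_t)·ε with v = sqrt(ᾱ_t)·ε − sqrt(1−ᾱ_t)·x_0 gives ε = sqrt(ᾱ_t)·v + sqrt(1−ᾱ_t)·x_t, which is the added conversion. A self-contained check (a sketch; values are arbitrary):

import torch

alpha_prod_t = torch.tensor(0.7)
x0 = torch.randn(4)
eps = torch.randn(4)
x_t = alpha_prod_t**0.5 * x0 + (1 - alpha_prod_t) ** 0.5 * eps
v = alpha_prod_t**0.5 * eps - (1 - alpha_prod_t) ** 0.5 * x0

# Recover epsilon the way the PNDM branch does and compare.
eps_rec = alpha_prod_t**0.5 * v + (1 - alpha_prod_t) ** 0.5 * x_t
assert torch.allclose(eps_rec, eps, atol=1e-5)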
@@ -635,7 +635,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
             self.check_over_configs(clip_sample=clip_sample)

     def test_prediction_type(self):
-        for prediction_type in ["epsilon", "sample"]:
+        for prediction_type in ["epsilon", "sample", "v_prediction"]:
            self.check_over_configs(prediction_type=prediction_type)

     def test_deprecated_predict_epsilon(self):
@@ -711,6 +711,37 @@ class DDPMSchedulerTest(SchedulerCommonTest):
         assert abs(result_sum.item() - 258.9070) < 1e-2
         assert abs(result_mean.item() - 0.3374) < 1e-3

+    def test_full_loop_with_v_prediction(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config(prediction_type="v_prediction")
+        scheduler = scheduler_class(**scheduler_config)
+
+        num_trained_timesteps = len(scheduler)
+
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter
+        generator = torch.manual_seed(0)
+
+        for t in reversed(range(num_trained_timesteps)):
+            # 1. predict noise residual
+            residual = model(sample, t)
+
+            # 2. predict previous mean of sample x_t-1
+            pred_prev_sample = scheduler.step(residual, t, sample, generator=generator).prev_sample
+
+            # if t > 0:
+            #     noise = self.dummy_sample_deter
+            #     variance = scheduler.get_variance(t) ** (0.5) * noise
+            #
+            # sample = pred_prev_sample + variance
+            sample = pred_prev_sample
+
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_sum.item() - 201.9864) < 1e-2
+        assert abs(result_mean.item() - 0.2630) < 1e-3
+

 class DDIMSchedulerTest(SchedulerCommonTest):
     scheduler_classes = (DDIMScheduler,)
@@ -768,6 +799,10 @@ class DDIMSchedulerTest(SchedulerCommonTest):
         for schedule in ["linear", "squaredcos_cap_v2"]:
             self.check_over_configs(beta_schedule=schedule)

+    def test_prediction_type(self):
+        for prediction_type in ["epsilon", "v_prediction"]:
+            self.check_over_configs(prediction_type=prediction_type)
+
     def test_clip_sample(self):
         for clip_sample in [True, False]:
             self.check_over_configs(clip_sample=clip_sample)
@@ -805,6 +840,15 @@ class DDIMSchedulerTest(SchedulerCommonTest):
         assert abs(result_sum.item() - 172.0067) < 1e-2
         assert abs(result_mean.item() - 0.223967) < 1e-3

+    def test_full_loop_with_v_prediction(self):
+        sample = self.full_loop(prediction_type="v_prediction")
+
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_sum.item() - 52.5302) < 1e-2
+        assert abs(result_mean.item() - 0.0684) < 1e-3
+
     def test_full_loop_with_set_alpha_to_one(self):
         # We specify different beta, so that the first alpha is 0.99
         sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
@@ -971,6 +1015,10 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
                     solver_type=solver_type,
                 )

+    def test_prediction_type(self):
+        for prediction_type in ["epsilon", "v_prediction"]:
+            self.check_over_configs(prediction_type=prediction_type)
+
     def test_solver_order_and_type(self):
         for algorithm_type in ["dpmsolver", "dpmsolver++"]:
             for solver_type in ["midpoint", "heun"]:
@@ -1004,6 +1052,12 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):

         assert abs(result_mean.item() - 0.3301) < 1e-3

+    def test_full_loop_with_v_prediction(self):
+        sample = self.full_loop(prediction_type="v_prediction")
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.2251) < 1e-3
+
     def test_fp16_support(self):
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
@@ -1184,6 +1238,10 @@ class PNDMSchedulerTest(SchedulerCommonTest):
         for schedule in ["linear", "squaredcos_cap_v2"]:
             self.check_over_configs(beta_schedule=schedule)

+    def test_prediction_type(self):
+        for prediction_type in ["epsilon", "v_prediction"]:
+            self.check_over_configs(prediction_type=prediction_type)
+
     def test_time_indices(self):
         for t in [1, 5, 10]:
             self.check_over_forward(time_step=t)
@@ -1225,6 +1283,14 @@ class PNDMSchedulerTest(SchedulerCommonTest):
         assert abs(result_sum.item() - 198.1318) < 1e-2
         assert abs(result_mean.item() - 0.2580) < 1e-3

+    def test_full_loop_with_v_prediction(self):
+        sample = self.full_loop(prediction_type="v_prediction")
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_sum.item() - 67.3986) < 1e-2
+        assert abs(result_mean.item() - 0.0878) < 1e-3
+
     def test_full_loop_with_set_alpha_to_one(self):
         # We specify different beta, so that the first alpha is 0.99
         sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
@@ -1453,6 +1519,10 @@ class LMSDiscreteSchedulerTest(SchedulerCommonTest):
         for schedule in ["linear", "scaled_linear"]:
             self.check_over_configs(beta_schedule=schedule)

+    def test_prediction_type(self):
+        for prediction_type in ["epsilon", "v_prediction"]:
+            self.check_over_configs(prediction_type=prediction_type)
+
     def test_time_indices(self):
         for t in [0, 500, 800]:
             self.check_over_forward(time_step=t)
@@ -1481,6 +1551,30 @@ class LMSDiscreteSchedulerTest(SchedulerCommonTest):
         assert abs(result_sum.item() - 1006.388) < 1e-2
         assert abs(result_mean.item() - 1.31) < 1e-3

+    def test_full_loop_with_v_prediction(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config(prediction_type="v_prediction")
+        scheduler = scheduler_class(**scheduler_config)
+
+        scheduler.set_timesteps(self.num_inference_steps)
+
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
+
+        for i, t in enumerate(scheduler.timesteps):
+            sample = scheduler.scale_model_input(sample, t)
+
+            model_output = model(sample, t)
+
+            output = scheduler.step(model_output, t, sample)
+            sample = output.prev_sample
+
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_sum.item() - 0.0017) < 1e-2
+        assert abs(result_mean.item() - 2.2676e-06) < 1e-3
+
     def test_full_loop_device(self):
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
@@ -1534,6 +1628,10 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
         for schedule in ["linear", "scaled_linear"]:
             self.check_over_configs(beta_schedule=schedule)

+    def test_prediction_type(self):
+        for prediction_type in ["epsilon", "v_prediction"]:
+            self.check_over_configs(prediction_type=prediction_type)
+
     def test_full_loop_no_noise(self):
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
@@ -1565,6 +1663,37 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
         assert abs(result_sum.item() - 10.0807) < 1e-2
         assert abs(result_mean.item() - 0.0131) < 1e-3

+    def test_full_loop_with_v_prediction(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config(prediction_type="v_prediction")
+        scheduler = scheduler_class(**scheduler_config)
+
+        scheduler.set_timesteps(self.num_inference_steps)
+
+        if torch_device == "mps":
+            # device type MPS is not supported for torch.Generator() api.
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
+
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
+        sample = sample.to(torch_device)
+
+        for i, t in enumerate(scheduler.timesteps):
+            sample = scheduler.scale_model_input(sample, t)
+
+            model_output = model(sample, t)
+
+            output = scheduler.step(model_output, t, sample, generator=generator)
+            sample = output.prev_sample
+
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_sum.item() - 0.0002) < 1e-2
+        assert abs(result_mean.item() - 2.2676e-06) < 1e-3
+
     def test_full_loop_device(self):
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
@@ -1624,6 +1753,10 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
         for schedule in ["linear", "scaled_linear"]:
             self.check_over_configs(beta_schedule=schedule)

+    def test_prediction_type(self):
+        for prediction_type in ["epsilon", "v_prediction"]:
+            self.check_over_configs(prediction_type=prediction_type)
+
     def test_full_loop_no_noise(self):
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
@@ -1660,6 +1793,42 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
         assert abs(result_sum.item() - 144.8084) < 1e-2
         assert abs(result_mean.item() - 0.18855) < 1e-3

+    def test_full_loop_with_v_prediction(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config(prediction_type="v_prediction")
+        scheduler = scheduler_class(**scheduler_config)
+
+        scheduler.set_timesteps(self.num_inference_steps)
+
+        if torch_device == "mps":
+            # device type MPS is not supported for torch.Generator() api.
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
+
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
+        sample = sample.to(torch_device)
+
+        for i, t in enumerate(scheduler.timesteps):
+            sample = scheduler.scale_model_input(sample, t)
+
+            model_output = model(sample, t)
+
+            output = scheduler.step(model_output, t, sample, generator=generator)
+            sample = output.prev_sample
+
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+
+        if torch_device in ["cpu", "mps"]:
+            assert abs(result_sum.item() - 108.4439) < 1e-2
+            assert abs(result_mean.item() - 0.1412) < 1e-3
+        else:
+            # CUDA
+            assert abs(result_sum.item() - 102.5807) < 1e-2
+            assert abs(result_mean.item() - 0.1335) < 1e-3
+
     def test_full_loop_device(self):
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
@@ -1932,6 +2101,10 @@ class HeunDiscreteSchedulerTest(SchedulerCommonTest):
         for schedule in ["linear", "scaled_linear"]:
             self.check_over_configs(beta_schedule=schedule)

+    def test_prediction_type(self):
+        for prediction_type in ["epsilon", "v_prediction"]:
+            self.check_over_configs(prediction_type=prediction_type)
+
     def test_full_loop_no_noise(self):
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
@@ -1962,6 +2135,36 @@ class HeunDiscreteSchedulerTest(SchedulerCommonTest):
         assert abs(result_sum.item() - 0.1233) < 1e-2
         assert abs(result_mean.item() - 0.0002) < 1e-3

+    def test_full_loop_with_v_prediction(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config(prediction_type="v_prediction")
+        scheduler = scheduler_class(**scheduler_config)
+
+        scheduler.set_timesteps(self.num_inference_steps)
+
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
+        sample = sample.to(torch_device)
+
+        for i, t in enumerate(scheduler.timesteps):
+            sample = scheduler.scale_model_input(sample, t)
+
+            model_output = model(sample, t)
+
+            output = scheduler.step(model_output, t, sample)
+            sample = output.prev_sample
+
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+
+        if torch_device in ["cpu", "mps"]:
+            assert abs(result_sum.item() - 4.6934e-07) < 1e-2
+            assert abs(result_mean.item() - 6.1112e-10) < 1e-3
+        else:
+            # CUDA
+            assert abs(result_sum.item() - 4.693428650170972e-07) < 1e-2
+            assert abs(result_mean.item() - 0.0002) < 1e-3
+
     def test_full_loop_device(self):
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()