
Easily understandable error if inference steps not set before using scheduler (#263) (#264)

* Helpful exception if inference steps not set in schedulers (#263)

* Apply suggestions from code review by patrickvonplaten

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Author: Richard Löwenström
Date: 2022-08-30 19:47:24 +02:00
Committed by: GitHub
Parent: 76985bc87a
Commit: 170af08e7f

4 changed files with 30 additions and 0 deletions
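For context, this is how the change surfaces to a user. Below is a minimal sketch (not part of the commit) assuming the diffusers API at the time of this commit, i.e. DDIMScheduler, set_timesteps() and step(); the tensor shapes and the step count of 50 are illustrative only:

    import torch
    from diffusers import DDIMScheduler

    scheduler = DDIMScheduler()

    # Dummy inputs; the shapes only need to match each other for this illustration.
    model_output = torch.randn(1, 3, 8, 8)
    sample = torch.randn(1, 3, 8, 8)

    # Calling step() before set_timesteps() used to fail deep inside the scheduler,
    # because num_inference_steps is still None; it now raises a descriptive ValueError.
    try:
        scheduler.step(model_output, 0, sample)
    except ValueError as err:
        print(err)  # "Number of inference steps is 'None', you need to run 'set_timesteps' ..."

    # Correct usage: configure the number of inference steps first.
    scheduler.set_timesteps(50)
    scheduler.step(model_output, int(scheduler.timesteps[0]), sample)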

src/diffusers/schedulers/scheduling_ddim.py

@@ -117,6 +117,11 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         use_clipped_model_output: bool = False,
         generator=None,
     ):
+        if self.num_inference_steps is None:
+            raise ValueError(
+                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+            )
+
         # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
         # Ideally, read DDIM paper in-detail understanding

src/diffusers/schedulers/scheduling_pndm.py

@@ -145,6 +145,11 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
         solution to the differential equation.
         """
+        if self.num_inference_steps is None:
+            raise ValueError(
+                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+            )
+
         diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2
         prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1])
         timestep = self.prk_timesteps[self.counter // 4 * 4]
@@ -179,6 +184,11 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
         times to approximate the solution.
         """
+        if self.num_inference_steps is None:
+            raise ValueError(
+                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+            )
+
         if not self.config.skip_prk_steps and len(self.ets) < 3:
             raise ValueError(
                 f"{self.__class__} can only be run AFTER scheduler has been run "

src/diffusers/schedulers/scheduling_sde_ve.py

@@ -120,6 +120,11 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
             self.set_seed(seed)
         # TODO(Patrick) non-PyTorch

+        if self.timesteps is None:
+            raise ValueError(
+                "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
+            )
+
         timestep = timestep * torch.ones(
             sample.shape[0], device=sample.device
         )  # torch.repeat_interleave(timestep, sample.shape[0])
@@ -155,6 +160,11 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
         if seed is not None:
             self.set_seed(seed)

+        if self.timesteps is None:
+            raise ValueError(
+                "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
+            )
+
         # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
         # sample noise for correction
         noise = self.randn_like(sample)

src/diffusers/schedulers/scheduling_sde_vp.py

@@ -35,6 +35,11 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
         self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)

     def step_pred(self, score, x, t):
+        if self.timesteps is None:
+            raise ValueError(
+                "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
+            )
+
         # TODO(Patrick) better comments + non-PyTorch
         # postprocess model score
         log_mean_coeff = (
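
None of the four changed files adds a test for the new guard. A hypothetical pytest-style check along the lines below would exercise it; DDIMScheduler and the error text come from the hunks above, everything else (shapes, timestep value) is illustrative:

    import pytest
    import torch

    from diffusers import DDIMScheduler

    def test_step_before_set_timesteps_raises_value_error():
        scheduler = DDIMScheduler()
        model_output = torch.randn(1, 3, 8, 8)  # dummy predicted noise
        sample = torch.randn(1, 3, 8, 8)        # dummy latent
        # set_timesteps() was never called, so num_inference_steps is still None
        # and step() should now fail with the new, descriptive ValueError.
        with pytest.raises(ValueError, match="set_timesteps"):
            scheduler.step(model_output, 0, sample)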