mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
* Helpful exception if inference steps not set in schedulers (#263) * Apply suggestions from codereview by patrickvonplaten * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
committed by
GitHub
parent
76985bc87a
commit
170af08e7f
@@ -117,6 +117,11 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
use_clipped_model_output: bool = False,
|
||||
generator=None,
|
||||
):
|
||||
if self.num_inference_steps is None:
|
||||
raise ValueError(
|
||||
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
||||
)
|
||||
|
||||
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
|
||||
# Ideally, read DDIM paper in-detail understanding
|
||||
|
||||
|
||||
@@ -145,6 +145,11 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
|
||||
solution to the differential equation.
|
||||
"""
|
||||
if self.num_inference_steps is None:
|
||||
raise ValueError(
|
||||
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
||||
)
|
||||
|
||||
diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2
|
||||
prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1])
|
||||
timestep = self.prk_timesteps[self.counter // 4 * 4]
|
||||
@@ -179,6 +184,11 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
|
||||
times to approximate the solution.
|
||||
"""
|
||||
if self.num_inference_steps is None:
|
||||
raise ValueError(
|
||||
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
||||
)
|
||||
|
||||
if not self.config.skip_prk_steps and len(self.ets) < 3:
|
||||
raise ValueError(
|
||||
f"{self.__class__} can only be run AFTER scheduler has been run "
|
||||
|
||||
@@ -120,6 +120,11 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.set_seed(seed)
|
||||
# TODO(Patrick) non-PyTorch
|
||||
|
||||
if self.timesteps is None:
|
||||
raise ValueError(
|
||||
"`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
|
||||
)
|
||||
|
||||
timestep = timestep * torch.ones(
|
||||
sample.shape[0], device=sample.device
|
||||
) # torch.repeat_interleave(timestep, sample.shape[0])
|
||||
@@ -155,6 +160,11 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
if seed is not None:
|
||||
self.set_seed(seed)
|
||||
|
||||
if self.timesteps is None:
|
||||
raise ValueError(
|
||||
"`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
|
||||
)
|
||||
|
||||
# For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
|
||||
# sample noise for correction
|
||||
noise = self.randn_like(sample)
|
||||
|
||||
@@ -35,6 +35,11 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)
|
||||
|
||||
def step_pred(self, score, x, t):
|
||||
if self.timesteps is None:
|
||||
raise ValueError(
|
||||
"`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
|
||||
)
|
||||
|
||||
# TODO(Patrick) better comments + non-PyTorch
|
||||
# postprocess model score
|
||||
log_mean_coeff = (
|
||||
|
||||
Reference in New Issue
Block a user