Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-29 07:22:12 +03:00)

fix: DDPMScheduler.set_timesteps() (#1912)

Author: Joqsan
Date: 2023-01-04 15:02:50 +03:00
Committed by: GitHub
Parent: d67c305120
Commit: 675ef1ffbd
2 changed files with 19 additions and 4 deletions


@@ -201,6 +201,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
             num_inference_steps (`int`):
                 the number of diffusion steps used when generating samples with a pre-trained model.
         """
+        if num_inference_steps > self.config.num_train_timesteps:
+            raise ValueError(
+                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
+                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+                f" maximal {self.config.num_train_timesteps} timesteps."
+            )
+
         self.num_inference_steps = num_inference_steps
         step_ratio = self.config.num_train_timesteps // self.num_inference_steps
         # creates integer timesteps by multiplying by ratio
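
The guard above rejects requests for more inference steps than the model was trained with. A minimal usage sketch of the behaviour after this commit is shown below; the scheduler construction and step counts are illustrative, not taken from the commit.

# Minimal sketch of the guard added above (illustrative values, not from the commit).
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps=50)       # ok: 50 <= 1000

try:
    scheduler.set_timesteps(num_inference_steps=2000)  # 2000 > 1000
except ValueError as err:
    print(err)  # num_inference_steps cannot be larger than num_train_timesteps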


@@ -184,11 +184,18 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             num_inference_steps (`int`):
                 the number of diffusion steps used when generating samples with a pre-trained model.
         """
-        num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
+        if num_inference_steps > self.config.num_train_timesteps:
+            raise ValueError(
+                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
+                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+                f" maximal {self.config.num_train_timesteps} timesteps."
+            )
+
         self.num_inference_steps = num_inference_steps
-        timesteps = np.arange(
-            0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
-        )[::-1].copy()
+        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
         self.timesteps = torch.from_numpy(timesteps).to(device)

     def _get_variance(self, t, predicted_variance=None, variance_type=None):
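
The arithmetic difference between the removed and the added timestep computation can be reproduced with plain NumPy. The values below are illustrative; the point is that the old `np.arange(0, num_train_timesteps, step)` form can return more than `num_inference_steps` entries whenever `num_train_timesteps` is not an exact multiple of `num_inference_steps`, while the new form always returns exactly `num_inference_steps` entries.

# Plain-NumPy reproduction of the two computations in the hunk above
# (values are illustrative, chosen so the difference is visible).
import numpy as np

num_train_timesteps = 1000
num_inference_steps = 3          # deliberately not a divisor of 1000

# Removed version: stepping with arange can overshoot the requested count.
old = np.arange(0, num_train_timesteps, num_train_timesteps // num_inference_steps)[::-1].copy()
print(len(old), old)             # 4 [999 666 333 0]

# Added version: scale an index array, so the length is exactly num_inference_steps.
step_ratio = num_train_timesteps // num_inference_steps
new = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
print(len(new), new)             # 3 [666 333 0]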