mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
fix: DDPMScheduler.set_timesteps() (#1912)
This commit is contained in:
@@ -201,6 +201,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
num_inference_steps (`int`):
|
||||
the number of diffusion steps used when generating samples with a pre-trained model.
|
||||
"""
|
||||
|
||||
if num_inference_steps > self.config.num_train_timesteps:
|
||||
raise ValueError(
|
||||
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
|
||||
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
|
||||
f" maximal {self.config.num_train_timesteps} timesteps."
|
||||
)
|
||||
|
||||
self.num_inference_steps = num_inference_steps
|
||||
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
|
||||
# creates integer timesteps by multiplying by ratio
|
||||
|
||||
@@ -184,11 +184,18 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
num_inference_steps (`int`):
|
||||
the number of diffusion steps used when generating samples with a pre-trained model.
|
||||
"""
|
||||
num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
|
||||
|
||||
if num_inference_steps > self.config.num_train_timesteps:
|
||||
raise ValueError(
|
||||
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
|
||||
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
|
||||
f" maximal {self.config.num_train_timesteps} timesteps."
|
||||
)
|
||||
|
||||
self.num_inference_steps = num_inference_steps
|
||||
timesteps = np.arange(
|
||||
0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
|
||||
)[::-1].copy()
|
||||
|
||||
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
|
||||
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device)
|
||||
|
||||
def _get_variance(self, t, predicted_variance=None, variance_type=None):
|
||||
|
||||
Reference in New Issue
Block a user