1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

ddpm custom timesteps (#3007)

add custom timesteps test

add custom timesteps descending order check

docs

timesteps -> custom_timesteps

can only pass one of num_inference_steps and timesteps
This commit is contained in:
Will Berman
2023-04-14 12:39:38 -07:00
committed by Daniel Gu
parent fa6a6b470d
commit a256f8446d
2 changed files with 119 additions and 16 deletions

View File

@@ -162,6 +162,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
self.init_noise_sigma = 1.0
# setable values
self.custom_timesteps = False
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
@@ -191,31 +192,62 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
"""
return sample
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
def set_timesteps(
self,
num_inference_steps: Optional[int] = None,
device: Union[str, torch.device] = None,
timesteps: Optional[List[int]] = None,
):
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
num_inference_steps (`Optional[int]`):
the number of diffusion steps used when generating samples with a pre-trained model. If passed, then
`timesteps` must be `None`.
device (`str` or `torch.device`, optional):
the device to which the timesteps are moved to.
custom_timesteps (`List[int]`, optional):
custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
timestep spacing strategy of equal spacing between timesteps is used. If passed, `num_inference_steps`
must be `None`.
"""
if num_inference_steps is not None and timesteps is not None:
raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
if num_inference_steps > self.config.num_train_timesteps:
raise ValueError(
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
f" maximal {self.config.num_train_timesteps} timesteps."
)
if timesteps is not None:
for i in range(1, len(timesteps)):
if timesteps[i] >= timesteps[i - 1]:
raise ValueError("`custom_timesteps` must be in descending order.")
self.num_inference_steps = num_inference_steps
if timesteps[0] >= self.config.num_train_timesteps:
raise ValueError(
f"`timesteps` must start before `self.config.train_timesteps`:"
f" {self.config.num_train_timesteps}."
)
timesteps = np.array(timesteps, dtype=np.int64)
self.custom_timesteps = True
else:
if num_inference_steps > self.config.num_train_timesteps:
raise ValueError(
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
f" maximal {self.config.num_train_timesteps} timesteps."
)
self.num_inference_steps = num_inference_steps
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
self.custom_timesteps = False
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
self.timesteps = torch.from_numpy(timesteps).to(device)
def _get_variance(self, t, predicted_variance=None, variance_type=None):
num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
prev_t = t - self.config.num_train_timesteps // num_inference_steps
prev_t = self.previous_timestep(t)
alpha_prod_t = self.alphas_cumprod[t]
alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
@@ -314,8 +346,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
"""
t = timestep
num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
prev_t = self.previous_timestep(t)
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
@@ -428,3 +460,18 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
def __len__(self):
return self.config.num_train_timesteps
def previous_timestep(self, timestep):
if self.custom_timesteps:
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
if index == self.timesteps.shape[0] - 1:
prev_t = torch.tensor(-1)
else:
prev_t = self.timesteps[index + 1]
else:
num_inference_steps = (
self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
)
prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
return prev_t

View File

@@ -129,3 +129,59 @@ class DDPMSchedulerTest(SchedulerCommonTest):
assert abs(result_sum.item() - 202.0296) < 1e-2
assert abs(result_mean.item() - 0.2631) < 1e-3
def test_custom_timesteps(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
timesteps = [100, 87, 50, 1, 0]
scheduler.set_timesteps(timesteps=timesteps)
scheduler_timesteps = scheduler.timesteps
for i, timestep in enumerate(scheduler_timesteps):
if i == len(timesteps) - 1:
expected_prev_t = -1
else:
expected_prev_t = timesteps[i + 1]
prev_t = scheduler.previous_timestep(timestep)
prev_t = prev_t.item()
self.assertEqual(prev_t, expected_prev_t)
def test_custom_timesteps_increasing_order(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
timesteps = [100, 87, 50, 51, 0]
with self.assertRaises(ValueError, msg="`custom_timesteps` must be in descending order."):
scheduler.set_timesteps(timesteps=timesteps)
def test_custom_timesteps_passing_both_num_inference_steps_and_timesteps(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
timesteps = [100, 87, 50, 1, 0]
num_inference_steps = len(timesteps)
with self.assertRaises(ValueError, msg="Can only pass one of `num_inference_steps` or `custom_timesteps`."):
scheduler.set_timesteps(num_inference_steps=num_inference_steps, timesteps=timesteps)
def test_custom_timesteps_too_large(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
timesteps = [scheduler.config.num_train_timesteps]
with self.assertRaises(
ValueError,
msg="`timesteps` must start before `self.config.train_timesteps`: {scheduler.config.num_train_timesteps}}",
):
scheduler.set_timesteps(timesteps=timesteps)