ddpm custom timesteps (#3007)
* add custom timesteps test
* add custom timesteps descending order check
* docs
* timesteps -> custom_timesteps
* can only pass one of num_inference_steps and timesteps
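In short, `set_timesteps` now accepts either `num_inference_steps` (the existing equally spaced strategy) or an explicit, strictly descending list of `timesteps`. A minimal usage sketch of the new API (illustrative only, not part of this commit; assumes a freshly constructed scheduler with 1000 training timesteps):

    from diffusers import DDPMScheduler

    scheduler = DDPMScheduler(num_train_timesteps=1000)

    # Existing path: equally spaced timesteps derived from num_inference_steps.
    scheduler.set_timesteps(num_inference_steps=50)

    # New path: arbitrary spacing; must be strictly descending and start below num_train_timesteps.
    scheduler.set_timesteps(timesteps=[100, 87, 50, 1, 0])
    print(scheduler.timesteps)         # tensor([100,  87,  50,   1,   0])
    print(scheduler.custom_timesteps)  # True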
@@ -162,6 +162,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         self.init_noise_sigma = 1.0
 
         # setable values
+        self.custom_timesteps = False
         self.num_inference_steps = None
         self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
@@ -191,31 +192,62 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         """
         return sample
 
-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(
+        self,
+        num_inference_steps: Optional[int] = None,
+        device: Union[str, torch.device] = None,
+        timesteps: Optional[List[int]] = None,
+    ):
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
 
         Args:
-            num_inference_steps (`int`):
-                the number of diffusion steps used when generating samples with a pre-trained model.
+            num_inference_steps (`Optional[int]`):
+                the number of diffusion steps used when generating samples with a pre-trained model. If passed, then
+                `timesteps` must be `None`.
             device (`str` or `torch.device`, optional):
                 the device to which the timesteps are moved to.
+            timesteps (`List[int]`, optional):
+                custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+                timestep spacing strategy of equal spacing between timesteps is used. If passed, `num_inference_steps`
+                must be `None`.
 
         """
-        if num_inference_steps > self.config.num_train_timesteps:
-            raise ValueError(
-                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
-                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
-                f" maximal {self.config.num_train_timesteps} timesteps."
-            )
+        if num_inference_steps is not None and timesteps is not None:
+            raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
 
-        self.num_inference_steps = num_inference_steps
+        if timesteps is not None:
+            for i in range(1, len(timesteps)):
+                if timesteps[i] >= timesteps[i - 1]:
+                    raise ValueError("`custom_timesteps` must be in descending order.")
 
-        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
-        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+            if timesteps[0] >= self.config.num_train_timesteps:
+                raise ValueError(
+                    f"`timesteps` must start before `self.config.train_timesteps`:"
+                    f" {self.config.num_train_timesteps}."
+                )
+
+            timesteps = np.array(timesteps, dtype=np.int64)
+            self.custom_timesteps = True
+        else:
+            if num_inference_steps > self.config.num_train_timesteps:
+                raise ValueError(
+                    f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
+                    f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+                    f" maximal {self.config.num_train_timesteps} timesteps."
+                )
+
+            self.num_inference_steps = num_inference_steps
+
+            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+            self.custom_timesteps = False
+
         self.timesteps = torch.from_numpy(timesteps).to(device)
 
     def _get_variance(self, t, predicted_variance=None, variance_type=None):
-        num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
-        prev_t = t - self.config.num_train_timesteps // num_inference_steps
+        prev_t = self.previous_timestep(t)
 
         alpha_prod_t = self.alphas_cumprod[t]
         alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
         current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
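To make the validation rules above concrete, a small sketch of what passes and what raises (illustrative, assuming num_train_timesteps=1000):

    from diffusers import DDPMScheduler

    scheduler = DDPMScheduler(num_train_timesteps=1000)

    # ok: strictly descending and starts below num_train_timesteps
    scheduler.set_timesteps(timesteps=[999, 500, 3, 0])

    # each of the following raises a ValueError:
    # scheduler.set_timesteps(timesteps=[0, 500, 999])                    # not descending
    # scheduler.set_timesteps(timesteps=[1000, 0])                        # starts at num_train_timesteps
    # scheduler.set_timesteps(num_inference_steps=2, timesteps=[999, 0])  # both arguments passed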
@@ -314,8 +346,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
 
         """
         t = timestep
-        num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
-        prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
+
+        prev_t = self.previous_timestep(t)
 
         if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
             model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
@@ -428,3 +460,18 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
 
     def __len__(self):
         return self.config.num_train_timesteps
+
+    def previous_timestep(self, timestep):
+        if self.custom_timesteps:
+            index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
+            if index == self.timesteps.shape[0] - 1:
+                prev_t = torch.tensor(-1)
+            else:
+                prev_t = self.timesteps[index + 1]
+        else:
+            num_inference_steps = (
+                self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
+            )
+            prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
+
+        return prev_t
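A worked example of the new lookup (illustrative): with custom timesteps [100, 87, 50, 1, 0], previous_timestep returns the next entry in the list, and the last entry maps to -1, the same end-of-sampling sentinel the equally spaced path produces:

    from diffusers import DDPMScheduler

    scheduler = DDPMScheduler(num_train_timesteps=1000)
    scheduler.set_timesteps(timesteps=[100, 87, 50, 1, 0])

    print(scheduler.previous_timestep(87))  # tensor(50)
    print(scheduler.previous_timestep(0))   # tensor(-1)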
@@ -129,3 +129,59 @@ class DDPMSchedulerTest(SchedulerCommonTest):
 
         assert abs(result_sum.item() - 202.0296) < 1e-2
         assert abs(result_mean.item() - 0.2631) < 1e-3
+
+    def test_custom_timesteps(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config)
+
+        timesteps = [100, 87, 50, 1, 0]
+
+        scheduler.set_timesteps(timesteps=timesteps)
+
+        scheduler_timesteps = scheduler.timesteps
+
+        for i, timestep in enumerate(scheduler_timesteps):
+            if i == len(timesteps) - 1:
+                expected_prev_t = -1
+            else:
+                expected_prev_t = timesteps[i + 1]
+
+            prev_t = scheduler.previous_timestep(timestep)
+            prev_t = prev_t.item()
+
+            self.assertEqual(prev_t, expected_prev_t)
+
+    def test_custom_timesteps_increasing_order(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config)
+
+        timesteps = [100, 87, 50, 51, 0]
+
+        with self.assertRaises(ValueError, msg="`custom_timesteps` must be in descending order."):
+            scheduler.set_timesteps(timesteps=timesteps)
+
+    def test_custom_timesteps_passing_both_num_inference_steps_and_timesteps(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config)
+
+        timesteps = [100, 87, 50, 1, 0]
+        num_inference_steps = len(timesteps)
+
+        with self.assertRaises(ValueError, msg="Can only pass one of `num_inference_steps` or `custom_timesteps`."):
+            scheduler.set_timesteps(num_inference_steps=num_inference_steps, timesteps=timesteps)
+
+    def test_custom_timesteps_too_large(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config)
+
+        timesteps = [scheduler.config.num_train_timesteps]
+
+        with self.assertRaises(
+            ValueError,
+            msg=f"`timesteps` must start before `self.config.train_timesteps`: {scheduler.config.num_train_timesteps}",
+        ):
+            scheduler.set_timesteps(timesteps=timesteps)
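To run only the new tests (assuming the suite lives at tests/schedulers/test_scheduler_ddpm.py, as in the current repo layout):

    python -m pytest tests/schedulers/test_scheduler_ddpm.py -k custom_timesteps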