From b64c5227595f5eed10f6ff3ac7953de0bb07ab2d Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Wed, 31 Aug 2022 13:12:08 +0100 Subject: [PATCH] [PNDM Scheduler] format timesteps attrs to np arrays (#273) * format timesteps attrs to np arrays in pndm scheduler because lists don't get formatted to tensors in `self.set_format` * convert to long type to use timesteps as indices for tensors * add scheduler set_format test * fix `_timesteps` type * make style with black 22.3.0 and isort 5.10.1 Co-authored-by: Patrick von Platen --- src/diffusers/schedulers/scheduling_pndm.py | 14 +++++----- tests/test_scheduler.py | 29 +++++++++++++++++++++ 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 103b9edf15..9473252061 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -103,22 +103,24 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps) ) self._offset = offset - self._timesteps = [t + self._offset for t in self._timesteps] + self._timesteps = np.array([t + self._offset for t in self._timesteps]) if self.config.skip_prk_steps: # for some models like stable diffusion the prk steps can/should be skipped to # produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51 - self.prk_timesteps = [] - self.plms_timesteps = list(reversed(self._timesteps[:-1] + self._timesteps[-2:-1] + self._timesteps[-1:])) + self.prk_timesteps = np.array([]) + self.plms_timesteps = (self._timesteps[:-1] + self._timesteps[-2:-1] + self._timesteps[-1:])[::-1].copy() else: prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile( np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order ) - self.prk_timesteps = list(reversed(prk_timesteps[:-1].repeat(2)[1:-1])) - self.plms_timesteps = list(reversed(self._timesteps[:-3])) + self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy() + self.plms_timesteps = self._timesteps[:-3][ + ::-1 + ].copy() # we copy to avoid having negative strides which are not supported by torch.from_numpy - self.timesteps = self.prk_timesteps + self.plms_timesteps + self.timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64) self.ets = [] self.counter = 0 diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 03cfc3aad7..0ce6715f55 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -485,6 +485,35 @@ class PNDMSchedulerTest(SchedulerCommonTest): assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical" + def test_set_format(self): + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(tensor_format="np", **scheduler_config) + scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + scheduler.set_timesteps(num_inference_steps) + scheduler_pt.set_timesteps(num_inference_steps) + + for key, value in vars(scheduler).items(): + # we only allow `ets` attr to be a list + assert not isinstance(value, list) or key in [ + "ets" + ], f"Scheduler is not correctly set to np format, the attribute {key} is {type(value)}" + + # check if `scheduler.set_format` does convert correctly attrs to pt format + for key, value in vars(scheduler_pt).items(): + # we only allow `ets` attr to be a list + assert not isinstance(value, list) or key in [ + "ets" + ], f"Scheduler is not correctly set to pt format, the attribute {key} is {type(value)}" + assert not isinstance( + value, np.ndarray + ), f"Scheduler is not correctly set to pt format, the attribute {key} is {type(value)}" + def test_step_shape(self): kwargs = dict(self.forward_default_kwargs)