1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Merge branch 'main' of https://github.com/NouamaneTazi/diffusers into stable_diff_opti

This commit is contained in:
Nouamane Tazi
2022-09-13 17:36:44 +00:00
3 changed files with 12 additions and 21 deletions

View File

@@ -35,7 +35,7 @@ def get_timestep_embedding(
exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent)
emb = torch.exp(exponent).to(device=timesteps.device, non_blocking=True)
emb = timesteps[:, None].float() * emb[None, :]
# scale embeddings

View File

@@ -178,7 +178,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
truncation=True,
return_tensors="pt",
)
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device, non_blocking=True))[0]
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -190,7 +190,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
uncond_input = self.tokenizer(
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device, non_blocking=True))[0]
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
@@ -208,7 +208,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
else:
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
latents = latents.to(self.device)
latents = latents.to(self.device, non_blocking=True)
# set timesteps
accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
@@ -217,8 +217,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
extra_set_kwargs["offset"] = 1
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
self.scheduler.timesteps = self.scheduler.timesteps.to(self.device)
self.scheduler.timesteps = torch.tensor(self.scheduler.timesteps, device=self.device)
# if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
if isinstance(self.scheduler, LMSDiscreteScheduler):
latents = latents * self.scheduler.sigmas[0]

View File

@@ -147,19 +147,15 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
# produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation
# is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51
self.prk_timesteps = np.array([])
self.plms_timesteps = np.concatenate([self._timesteps[:-1], self._timesteps[-2:-1], self._timesteps[-1:]])[
::-1
].copy()
self.plms_timesteps = (self._timesteps[:-1] + self._timesteps[-2:-1] + self._timesteps[-1:])[::-1].copy()
else:
prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile(
np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
)
self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy()
self.plms_timesteps = self._timesteps[:-3][
::-1
].copy() # we copy to avoid having negative strides which are not supported by torch.from_numpy
self.plms_timesteps = self._timesteps[:-3][::-1].copy() # we copy to avoid having negative strides which are not supported by torch.from_numpy
self.timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64)
self.timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps])
self.ets = []
self.counter = 0
@@ -216,13 +212,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain.
"""
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2
prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1])
diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2 #TODO: check if we still have condition in cuda graph
prev_timestep = torch.max(torch.tensor(timestep - diff_to_prev), self.prk_timesteps[-1])
timestep = self.prk_timesteps[self.counter // 4 * 4]
if self.counter % 4 == 0:
@@ -283,7 +274,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
"for more information."
)
prev_timestep = max(timestep - self.config.num_train_timesteps // self.num_inference_steps, 0)
prev_timestep = torch.max(timestep - self.config.num_train_timesteps // self.num_inference_steps, torch.tensor(0))
if self.counter != 1:
self.ets.append(model_output)