mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
Merge branch 'main' of https://github.com/NouamaneTazi/diffusers into stable_diff_opti
This commit is contained in:
@@ -35,7 +35,7 @@ def get_timestep_embedding(
|
||||
exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
|
||||
exponent = exponent / (half_dim - downscale_freq_shift)
|
||||
|
||||
emb = torch.exp(exponent)
|
||||
emb = torch.exp(exponent).to(device=timesteps.device, non_blocking=True)
|
||||
emb = timesteps[:, None].float() * emb[None, :]
|
||||
|
||||
# scale embeddings
|
||||
|
||||
@@ -178,7 +178,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
|
||||
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device, non_blocking=True))[0]
|
||||
|
||||
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
@@ -190,7 +190,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
uncond_input = self.tokenizer(
|
||||
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
|
||||
)
|
||||
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
|
||||
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device, non_blocking=True))[0]
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
@@ -208,7 +208,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
else:
|
||||
if latents.shape != latents_shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
||||
latents = latents.to(self.device)
|
||||
latents = latents.to(self.device, non_blocking=True)
|
||||
|
||||
# set timesteps
|
||||
accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
|
||||
@@ -217,8 +217,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
extra_set_kwargs["offset"] = 1
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
|
||||
self.scheduler.timesteps = self.scheduler.timesteps.to(self.device)
|
||||
|
||||
self.scheduler.timesteps = torch.tensor(self.scheduler.timesteps, device=self.device)
|
||||
|
||||
# if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
|
||||
if isinstance(self.scheduler, LMSDiscreteScheduler):
|
||||
latents = latents * self.scheduler.sigmas[0]
|
||||
|
||||
@@ -147,19 +147,15 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
# produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation
|
||||
# is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51
|
||||
self.prk_timesteps = np.array([])
|
||||
self.plms_timesteps = np.concatenate([self._timesteps[:-1], self._timesteps[-2:-1], self._timesteps[-1:]])[
|
||||
::-1
|
||||
].copy()
|
||||
self.plms_timesteps = (self._timesteps[:-1] + self._timesteps[-2:-1] + self._timesteps[-1:])[::-1].copy()
|
||||
else:
|
||||
prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile(
|
||||
np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
|
||||
)
|
||||
self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy()
|
||||
self.plms_timesteps = self._timesteps[:-3][
|
||||
::-1
|
||||
].copy() # we copy to avoid having negative strides which are not supported by torch.from_numpy
|
||||
self.plms_timesteps = self._timesteps[:-3][::-1].copy() # we copy to avoid having negative strides which are not supported by torch.from_numpy
|
||||
|
||||
self.timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64)
|
||||
self.timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps])
|
||||
|
||||
self.ets = []
|
||||
self.counter = 0
|
||||
@@ -216,13 +212,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain.
|
||||
|
||||
"""
|
||||
if self.num_inference_steps is None:
|
||||
raise ValueError(
|
||||
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
||||
)
|
||||
|
||||
diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2
|
||||
prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1])
|
||||
diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2 #TODO: check if we still have condition in cuda graph
|
||||
prev_timestep = torch.max(torch.tensor(timestep - diff_to_prev), self.prk_timesteps[-1])
|
||||
timestep = self.prk_timesteps[self.counter // 4 * 4]
|
||||
|
||||
if self.counter % 4 == 0:
|
||||
@@ -283,7 +274,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
"for more information."
|
||||
)
|
||||
|
||||
prev_timestep = max(timestep - self.config.num_train_timesteps // self.num_inference_steps, 0)
|
||||
prev_timestep = torch.max(timestep - self.config.num_train_timesteps // self.num_inference_steps, torch.tensor(0))
|
||||
|
||||
if self.counter != 1:
|
||||
self.ets.append(model_output)
|
||||
|
||||
Reference in New Issue
Block a user