diff --git a/examples/community/clip_guided_stable_diffusion.py b/examples/community/clip_guided_stable_diffusion.py
index 1129e4b3bd..db08a124d9 100644
--- a/examples/community/clip_guided_stable_diffusion.py
+++ b/examples/community/clip_guided_stable_diffusion.py
@@ -175,6 +175,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         width: Optional[int] = 512,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
+        num_images_per_prompt: Optional[int] = 1,
         clip_guidance_scale: Optional[float] = 100,
         clip_prompt: Optional[Union[str, List[str]]] = None,
         num_cutouts: Optional[int] = 4,
@@ -203,6 +204,8 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
             return_tensors="pt",
         )
         text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+        # duplicate text embeddings for each generation per prompt
+        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
 
         if clip_guidance_scale > 0:
             if clip_prompt is not None:
@@ -217,6 +220,8 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
                 clip_text_input = text_input.input_ids.to(self.device)
             text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
             text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
+            # duplicate text embeddings clip for each generation per prompt
+            text_embeddings_clip = text_embeddings_clip.repeat_interleave(num_images_per_prompt, dim=0)
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -225,10 +230,10 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
             max_length = text_input.input_ids.shape[-1]
-            uncond_input = self.tokenizer(
-                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
-            )
+            uncond_input = self.tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt")
             uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+            # duplicate unconditional embeddings for each generation per prompt
+            uncond_embeddings = uncond_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
 
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
@@ -240,18 +245,20 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_device = "cpu" if self.device.type == "mps" else self.device
-        latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+        latents_dtype = text_embeddings.dtype
         if latents is None:
-            latents = torch.randn(
-                latents_shape,
-                generator=generator,
-                device=latents_device,
-            )
+            if self.device.type == "mps":
+                # randn does not exist on mps
+                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+                    self.device
+                )
+            else:
+                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-        latents = latents.to(self.device)
+            latents = latents.to(self.device)
 
         # set timesteps
         accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
@@ -261,17 +268,17 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
 
         self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
 
-        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
-        if isinstance(self.scheduler, LMSDiscreteScheduler):
-            latents = latents * self.scheduler.sigmas[0]
+        # Some schedulers like PNDM have timesteps as arrays
+        # It's more optimized to move all timesteps to correct device beforehand
+        timesteps_tensor = self.scheduler.timesteps.to(self.device)
 
-        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma
+
+        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                sigma = self.scheduler.sigmas[i]
-                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
             # predict the noise residual
             noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -299,10 +306,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
                 )
 
             # compute the previous noisy sample x_t -> x_t-1
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, i, latents).prev_sample
-            else:
-                latents = self.scheduler.step(noise_pred, t, latents).prev_sample
+            latents = self.scheduler.step(noise_pred, t, latents).prev_sample
 
         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
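For reference, a minimal usage sketch of the new `num_images_per_prompt` argument. The checkpoint names, CLIP model choice, and the other call arguments below are illustrative assumptions and are not part of this diff; only the community-pipeline loading path (`custom_pipeline="clip_guided_stable_diffusion"`) and the new argument come from the repo.

    import torch
    from diffusers import DiffusionPipeline
    from transformers import CLIPFeatureExtractor, CLIPModel

    # hypothetical checkpoints; any Stable Diffusion weights and CLIP model should work
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
    feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch16")

    # load the community pipeline modified by this diff, passing the extra CLIP components
    pipe = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        custom_pipeline="clip_guided_stable_diffusion",
        clip_model=clip_model,
        feature_extractor=feature_extractor,
    ).to("cuda")

    # generate several images for a single prompt in one call
    out = pipe(
        prompt="a watercolor painting of a lighthouse at dawn",
        num_images_per_prompt=4,
        num_inference_steps=50,
        clip_guidance_scale=100,
        generator=torch.Generator(device="cuda").manual_seed(0),
    )
    for i, image in enumerate(out.images):
        image.save(f"lighthouse_{i}.png")

With the change above, the prompt and unconditional embeddings are repeated with `repeat_interleave` and the latent batch grows to `batch_size * num_images_per_prompt`, so a single forward pass of the pipeline returns all requested variations.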