From 2fa9c698eae2890ac5f8e367ca80532ecf94df9a Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Thu, 22 Sep 2022 14:44:11 +0000 Subject: [PATCH] hardcode what's needed for jitting --- .../stable_diffusion/pipeline_stable_diffusion.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 1272fe64e7..b6034bed8a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -195,7 +195,7 @@ class StableDiffusionPipeline(DiffusionPipeline): truncation=True, return_tensors="pt", ) - text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] + text_embeddings = self.text_encoder(text_input.input_ids.to("cuda"))[0] # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` @@ -207,7 +207,7 @@ class StableDiffusionPipeline(DiffusionPipeline): uncond_input = self.tokenizer( [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" ) - uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to("cuda"))[0] # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch @@ -219,8 +219,8 @@ class StableDiffusionPipeline(DiffusionPipeline): # Unlike in other pipelines, latents need to be generated in the target device # for 1-to-1 results reproducibility with the CompVis implementation. # However this currently doesn't work in `mps`. 
- latents_device = "cpu" if self.device.type == "mps" else self.device - latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) + latents_device = "cuda" + latents_shape = (batch_size, 4, height // 8, width // 8) if latents is None: latents = torch.randn( latents_shape, @@ -259,7 +259,7 @@ class StableDiffusionPipeline(DiffusionPipeline): latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + noise_pred = self.unet(latent_model_input, t, text_embeddings)[0] # TODO: fix for return_dict case # perform guidance if do_classifier_free_guidance: @@ -280,9 +280,9 @@ class StableDiffusionPipeline(DiffusionPipeline): image = image.cpu().permute(0, 2, 3, 1).numpy() # run safety checker - safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) + safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to("cuda") image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values.to(text_embeddings.dtype)) - + if output_type == "pil": image = self.numpy_to_pil(image)