1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

hardcore whats needed for jitting

This commit is contained in:
Nouamane Tazi
2022-09-22 14:44:11 +00:00
parent 31c58eadb8
commit 2fa9c698ea

View File

@@ -195,7 +195,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
truncation=True,
return_tensors="pt",
)
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
text_embeddings = self.text_encoder(text_input.input_ids.to("cuda"))[0]
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -207,7 +207,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
uncond_input = self.tokenizer(
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to("cuda"))[0]
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
@@ -219,8 +219,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
# Unlike in other pipelines, latents need to be generated in the target device
# for 1-to-1 results reproducibility with the CompVis implementation.
# However this currently doesn't work in `mps`.
latents_device = "cpu" if self.device.type == "mps" else self.device
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
latents_device = "cuda"
latents_shape = (batch_size, 4, height // 8, width // 8)
if latents is None:
latents = torch.randn(
latents_shape,
@@ -259,7 +259,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
noise_pred = self.unet(latent_model_input, t, text_embeddings)[0] # TODO: fix for return_dict case
# perform guidance
if do_classifier_free_guidance:
@@ -280,9 +280,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
image = image.cpu().permute(0, 2, 3, 1).numpy()
# run safety checker
safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to("cuda")
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values.to(text_embeddings.dtype))
if output_type == "pil":
image = self.numpy_to_pil(image)