mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
hardcore whats needed for jitting
This commit is contained in:
@@ -195,7 +195,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
|
||||
text_embeddings = self.text_encoder(text_input.input_ids.to("cuda"))[0]
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
@@ -207,7 +207,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
uncond_input = self.tokenizer(
|
||||
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
|
||||
)
|
||||
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
|
||||
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to("cuda"))[0]
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
@@ -219,8 +219,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
# Unlike in other pipelines, latents need to be generated in the target device
|
||||
# for 1-to-1 results reproducibility with the CompVis implementation.
|
||||
# However this currently doesn't work in `mps`.
|
||||
latents_device = "cpu" if self.device.type == "mps" else self.device
|
||||
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
|
||||
latents_device = "cuda"
|
||||
latents_shape = (batch_size, 4, height // 8, width // 8)
|
||||
if latents is None:
|
||||
latents = torch.randn(
|
||||
latents_shape,
|
||||
@@ -259,7 +259,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||
noise_pred = self.unet(latent_model_input, t, text_embeddings)[0] # TODO: fix for return_dict case
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
@@ -280,9 +280,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||
|
||||
# run safety checker
|
||||
safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
|
||||
safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to("cuda")
|
||||
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values.to(text_embeddings.dtype))
|
||||
|
||||
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user