1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
Patrick von Platen
2023-06-23 21:16:04 +00:00
parent 51ab97a2f7
commit dd48802fa5
2 changed files with 15 additions and 12 deletions

View File

@@ -820,6 +820,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)
emb = emb + aug_emb

View File

@@ -15,7 +15,7 @@
import inspect
import warnings
from typing import Any, Callable, Dict, List, Optional, Union
from pytorch_lightning import seed_everything
# from pytorch_lightning import seed_everything
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
@@ -117,7 +117,6 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
scheduler: KarrasDiffusionSchedulers,
# safety_checker: StableDiffusionSafetyChecker,
# feature_extractor: CLIPImageProcessor,
# requires_safety_checker: bool = True,
):
super().__init__()
@@ -151,7 +150,6 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# self.register_to_config(requires_safety_checker=requires_safety_checker)
def enable_vae_slicing(self):
r"""
@@ -245,7 +243,6 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
return "cpu"
if not hasattr(self.unet, "_hf_hook"):
return self.device
for module in self.unet.modules():
@@ -635,12 +632,12 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
seed_everything(0)
# seed_everything(0)
# 1. Check inputs. Raise error if not correct
# self.check_inputs(
# prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
# )
self.check_inputs(
prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
@@ -650,8 +647,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
else:
batch_size = prompt_embeds.shape[0]
# device = self._execution_device
device = "cpu"
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -697,7 +693,9 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
add_time_ids = torch.tensor(2 * [[128, 128, 0, 0, 1024, 1024]], dtype=torch.long)
# TODO - find better explanations where they come from
add_time_ids = torch.tensor(2 * [[128, 128, 0, 0, 1024, 1024]], dtype=torch.long, device=add_text_embeds.device)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
@@ -736,8 +734,12 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
callback(i, t, latents)
if not output_type == "latent":
# CHECK there is problem here (PVP)
# self.vae = self.vae.to(dtype=torch.float32)
# latents = latents.float()
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
#image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
has_nsfw_concept = None
else:
image = latents
has_nsfw_concept = None