Mirror of https://github.com/huggingface/diffusers.git
Fix more

@@ -820,6 +820,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)

emb = emb + aug_emb
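
The single added line above casts the concatenated conditioning tensor to the dtype of the base timestep embedding before it goes through `self.add_embedding`. A minimal sketch of why the cast matters, using illustrative shapes and a stand-in fp16 projection layer (none of these exact values are read from a model config):

    import torch

    # Illustrative shapes only; the real ones come from the UNet config.
    text_embeds = torch.randn(2, 1280)            # pooled text embeddings, fp32
    time_embeds = torch.randn(2, 6, 256)          # per-id time embeddings, fp32
    time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

    add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)   # (2, 2816)

    # Stand-in for self.add_embedding: a projection running in half precision.
    add_embedding = torch.nn.Linear(2816, 1280).half()
    emb = torch.zeros(2, 1280, dtype=torch.float16)                 # base timestep embedding

    add_embeds = add_embeds.to(emb.dtype)   # without this cast, the fp16 layer below
    aug_emb = add_embedding(add_embeds)     # rejects the fp32 input with a dtype error
    emb = emb + aug_emb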

@@ -15,7 +15,7 @@
import inspect
import warnings
from typing import Any, Callable, Dict, List, Optional, Union
- from pytorch_lightning import seed_everything
+ # from pytorch_lightning import seed_everything

import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection

@@ -117,7 +117,6 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
scheduler: KarrasDiffusionSchedulers,
# safety_checker: StableDiffusionSafetyChecker,
# feature_extractor: CLIPImageProcessor,
# requires_safety_checker: bool = True,
):
super().__init__()

@@ -151,7 +150,6 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor = 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# self.register_to_config(requires_safety_checker=requires_safety_checker)

def enable_vae_slicing(self):
r"""
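
Both `vae_scale_factor` assignments in this hunk yield the same value for the stock VAE: the config ships four `block_out_channels` entries, so the computed form reduces to the hard-coded 8. A quick check, assuming the usual config values:

    # Typical AutoencoderKL config for SD/SDXL checkpoints (assumed values).
    block_out_channels = (128, 256, 512, 512)
    vae_scale_factor = 2 ** (len(block_out_channels) - 1)
    assert vae_scale_factor == 8   # matches the hard-coded fallback above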

@@ -245,7 +243,6 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
- return "cpu"
if not hasattr(self.unet, "_hf_hook"):
return self.device
for module in self.unet.modules():
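
The removed `return "cpu"` short-circuited the lookup before the Accelerate hook check could run. For reference, a sketch of how the hook walk visible in this hunk typically continues in this pipeline family; attribute names follow the public accelerate hook API, and the exact body is an approximation, not the file's code:

    import torch

    def _execution_device(self):
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            # An offload hook attached by accelerate records the device the
            # module will actually execute on.
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device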

@@ -635,12 +632,12 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor

- seed_everything(0)
+ # seed_everything(0)

# 1. Check inputs. Raise error if not correct
- # self.check_inputs(
- #     prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
- # )
+ self.check_inputs(
+     prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+ )

# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
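
Commenting out `seed_everything(0)` drops the hard dependency on pytorch_lightning (and the forced deterministic run). If equivalent seeding is still wanted without that dependency, a minimal sketch covering the same generators follows; the helper name is made up here:

    import random

    import numpy as np
    import torch

    def seed_all(seed: int = 0) -> None:
        # Roughly what pytorch_lightning's seed_everything touches, minus its
        # environment-variable bookkeeping.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

In practice diffusers pipelines also accept a `generator` argument (e.g. `torch.Generator(device).manual_seed(0)`), which keeps determinism local to the call instead of mutating global RNG state.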

@@ -650,8 +647,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
else:
batch_size = prompt_embeds.shape[0]

- # device = self._execution_device
- device = "cpu"
+ device = self._execution_device

# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
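
The trailing comment refers to classifier-free guidance from the Imagen paper; the blend it describes is the standard one applied to the stacked unconditional/text-conditioned predictions later in the denoising loop. A self-contained illustration with placeholder tensors (shapes are arbitrary):

    import torch

    guidance_scale = 7.5                      # w > 1 pushes samples toward the prompt
    noise_pred = torch.randn(2, 4, 128, 128)  # stacked [uncond, text] predictions
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    # guidance_scale == 1 reduces this to the text-conditioned prediction alone.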

@@ -697,7 +693,9 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
- add_time_ids = torch.tensor(2 * [[128, 128, 0, 0, 1024, 1024]], dtype=torch.long)

+ # TODO - find better explanations where they come from
+ add_time_ids = torch.tensor(2 * [[128, 128, 0, 0, 1024, 1024]], dtype=torch.long, device=add_text_embeds.device)

with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
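
On the TODO: in the released SDXL pipeline the six ids encode size/crop micro-conditioning, built as original_size + crops_coords_top_left + target_size and duplicated for the unconditional and conditional halves of the batch. Whether the hard-coded (128, 128, 0, 0, 1024, 1024) above is meant to follow that layout is exactly what the TODO leaves open; the sketch below uses assumed values:

    import torch

    original_size = (1024, 1024)       # size the source image is treated as coming from
    crops_coords_top_left = (0, 0)     # top-left corner of the training-style crop
    target_size = (1024, 1024)         # requested output resolution

    add_time_ids = list(original_size) + list(crops_coords_top_left) + list(target_size)
    # One row per classifier-free-guidance branch, as in the hunk above.
    add_time_ids = torch.tensor(2 * [add_time_ids], dtype=torch.long)
    print(add_time_ids.shape)          # torch.Size([2, 6])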

@@ -736,8 +734,12 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
callback(i, t, latents)

if not output_type == "latent":
+ # CHECK there is problem here (PVP)
+ # self.vae = self.vae.to(dtype=torch.float32)
+ # latents = latents.float()
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
- image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ #image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ has_nsfw_concept = None
else:
image = latents
has_nsfw_concept = None
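
The "CHECK there is problem here (PVP)" note and the commented-out float32 casts point at the fp16 VAE decode instability: decoding SDXL latents in half precision can overflow and return black or NaN images. A hedged sketch of the upcast-before-decode workaround those comments gesture at, written against the public diffusers attributes (`vae.config.scaling_factor`, `vae.decode`); it is not necessarily what this branch ends up doing:

    import torch

    def decode_latents_fp32(pipe, latents):
        # Decode in float32 to dodge fp16 overflow, then restore the VAE's dtype.
        original_dtype = pipe.vae.dtype
        pipe.vae = pipe.vae.to(dtype=torch.float32)
        latents = latents.to(torch.float32) / pipe.vae.config.scaling_factor
        image = pipe.vae.decode(latents, return_dict=False)[0]
        pipe.vae = pipe.vae.to(dtype=original_dtype)
        return image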