From dd48802fa58a0d0da88d8dd18202f3ad32c563c3 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Fri, 23 Jun 2023 21:16:04 +0000
Subject: [PATCH] Fix more

---
 src/diffusers/models/unet_2d_condition.py |  1 +
 .../pipeline_stable_diffusion_xl.py        | 26 ++++++++++---------
 2 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index 5f53b55286..192f9663f2 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -820,6 +820,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
 
             add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+            add_embeds = add_embeds.to(emb.dtype)
             aug_emb = self.add_embedding(add_embeds)
             emb = emb + aug_emb
 
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index ef759adb8d..0a46aafae1 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -15,7 +15,7 @@
 import inspect
 import warnings
 from typing import Any, Callable, Dict, List, Optional, Union
 
-from pytorch_lightning import seed_everything
+# from pytorch_lightning import seed_everything
 import torch
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
@@ -117,7 +117,6 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
         scheduler: KarrasDiffusionSchedulers,
         # safety_checker: StableDiffusionSafetyChecker,
         # feature_extractor: CLIPImageProcessor,
-        # requires_safety_checker: bool = True,
     ):
         super().__init__()
 
@@ -151,7 +150,6 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.vae_scale_factor = 8
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
-        # self.register_to_config(requires_safety_checker=requires_safety_checker)
 
     def enable_vae_slicing(self):
         r"""
@@ -245,7 +243,6 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
         `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
         hooks.
         """
-        return "cpu"
         if not hasattr(self.unet, "_hf_hook"):
             return self.device
         for module in self.unet.modules():
@@ -635,12 +632,12 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
         height = height or self.unet.config.sample_size * self.vae_scale_factor
         width = width or self.unet.config.sample_size * self.vae_scale_factor
 
-        seed_everything(0)
+        # seed_everything(0)
 
         # 1. Check inputs. Raise error if not correct
-        # self.check_inputs(
-        #     prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
-        # )
+        self.check_inputs(
+            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+        )
 
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -650,8 +647,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
         else:
             batch_size = prompt_embeds.shape[0]
 
-        # device = self._execution_device
-        device = "cpu"
+        device = self._execution_device
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -697,7 +693,9 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
         add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
-        add_time_ids = torch.tensor(2 * [[128, 128, 0, 0, 1024, 1024]], dtype=torch.long)
+
+        # TODO - find better explanations where they come from
+        add_time_ids = torch.tensor(2 * [[128, 128, 0, 0, 1024, 1024]], dtype=torch.long, device=add_text_embeds.device)
 
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -736,8 +734,12 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
                         callback(i, t, latents)
 
         if not output_type == "latent":
+            # CHECK there is problem here (PVP)
+            # self.vae = self.vae.to(dtype=torch.float32)
+            # latents = latents.float()
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
-            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+            #image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+            has_nsfw_concept = None
         else:
             image = latents
             has_nsfw_concept = None
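
Note for reviewers, not part of the patch: a minimal sketch of the dtype mismatch the added `add_embeds = add_embeds.to(emb.dtype)` cast guards against when the UNet runs in fp16 while the pooled prompt embeds and projected time ids are still fp32. The tensor shapes are illustrative SDXL sizes and the snippet only checks dtypes rather than calling the real `self.add_embedding` layer.

import torch

emb = torch.randn(2, 1280, dtype=torch.float16)  # UNet time embedding (fp16 UNet)
text_embeds = torch.randn(2, 1280)                # pooled prompt embeds, still fp32
time_embeds = torch.randn(2, 1536)                # projected add_time_ids, still fp32

add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
assert add_embeds.dtype == torch.float32  # would not match the fp16 add_embedding weights
add_embeds = add_embeds.to(emb.dtype)     # the cast added in this patch
assert add_embeds.dtype == emb.dtype      # now safe to feed through the additional embedding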