Fix lpw stable diffusion pipeline compatibility (#1622)
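This change makes the long-prompt-weighting (lpw) community pipelines work across diffusers releases: imports that only exist in newer diffusers (PIL_INTERPOLATION, ORT_TO_NP_TYPE) gain try/except fallbacks, __init__ is version-gated on the requires_safety_checker argument introduced in diffusers 0.9.0, and the denoising loops return to the plain self.progress_bar(iterable) form that older releases support.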
@@ -5,14 +5,37 @@ from typing import Callable, List, Optional, Union
 import numpy as np
 import torch

+import diffusers
+import PIL
 from diffusers import SchedulerMixin, StableDiffusionPipeline
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
-from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
+from diffusers.utils import deprecate, logging
+from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

+
+try:
+    from diffusers.utils import PIL_INTERPOLATION
+except ImportError:
+    if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+        PIL_INTERPOLATION = {
+            "linear": PIL.Image.Resampling.BILINEAR,
+            "bilinear": PIL.Image.Resampling.BILINEAR,
+            "bicubic": PIL.Image.Resampling.BICUBIC,
+            "lanczos": PIL.Image.Resampling.LANCZOS,
+            "nearest": PIL.Image.Resampling.NEAREST,
+        }
+    else:
+        PIL_INTERPOLATION = {
+            "linear": PIL.Image.LINEAR,
+            "bilinear": PIL.Image.BILINEAR,
+            "bicubic": PIL.Image.BICUBIC,
+            "lanczos": PIL.Image.LANCZOS,
+            "nearest": PIL.Image.NEAREST,
+        }
 # ------------------------------------------------------------------------------

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 re_attention = re.compile(
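For reference, the PIL_INTERPOLATION table above is what the pipeline's preprocessing uses when resizing init and mask images. A minimal sketch of that use, runnable once the table is defined (the image sizes are illustrative, not from the commit):

import PIL.Image

img = PIL.Image.new("RGB", (768, 512))  # stand-in for a user-supplied init image
# resize with the same resampling filter the preprocess helpers pass
img = img.resize((512, 512), resample=PIL_INTERPOLATION["lanczos"])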
@@ -404,27 +427,75 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """

-    def __init__(
-        self,
-        vae: AutoencoderKL,
-        text_encoder: CLIPTextModel,
-        tokenizer: CLIPTokenizer,
-        unet: UNet2DConditionModel,
-        scheduler: SchedulerMixin,
-        safety_checker: StableDiffusionSafetyChecker,
-        feature_extractor: CLIPFeatureExtractor,
-        requires_safety_checker: bool = True,
-    ):
-        super().__init__(
-            vae=vae,
-            text_encoder=text_encoder,
-            tokenizer=tokenizer,
-            unet=unet,
-            scheduler=scheduler,
-            safety_checker=safety_checker,
-            feature_extractor=feature_extractor,
-            requires_safety_checker=requires_safety_checker,
-        )
+    if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
+
+        def __init__(
+            self,
+            vae: AutoencoderKL,
+            text_encoder: CLIPTextModel,
+            tokenizer: CLIPTokenizer,
+            unet: UNet2DConditionModel,
+            scheduler: SchedulerMixin,
+            safety_checker: StableDiffusionSafetyChecker,
+            feature_extractor: CLIPFeatureExtractor,
+            requires_safety_checker: bool = True,
+        ):
+            super().__init__(
+                vae=vae,
+                text_encoder=text_encoder,
+                tokenizer=tokenizer,
+                unet=unet,
+                scheduler=scheduler,
+                safety_checker=safety_checker,
+                feature_extractor=feature_extractor,
+                requires_safety_checker=requires_safety_checker,
+            )
+            self.__init__additional__()
+
+    else:
+
+        def __init__(
+            self,
+            vae: AutoencoderKL,
+            text_encoder: CLIPTextModel,
+            tokenizer: CLIPTokenizer,
+            unet: UNet2DConditionModel,
+            scheduler: SchedulerMixin,
+            safety_checker: StableDiffusionSafetyChecker,
+            feature_extractor: CLIPFeatureExtractor,
+        ):
+            super().__init__(
+                vae=vae,
+                text_encoder=text_encoder,
+                tokenizer=tokenizer,
+                unet=unet,
+                scheduler=scheduler,
+                safety_checker=safety_checker,
+                feature_extractor=feature_extractor,
+            )
+            self.__init__additional__()
+
+    def __init__additional__(self):
+        if not hasattr(self, "vae_scale_factor"):
+            setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
+
+    @property
+    def _execution_device(self):
+        r"""
+        Returns the device on which the pipeline's models will be executed. After calling
+        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+        hooks.
+        """
+        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+            return self.device
+        for module in self.unet.modules():
+            if (
+                hasattr(module, "_hf_hook")
+                and hasattr(module._hf_hook, "execution_device")
+                and module._hf_hook.execution_device is not None
+            ):
+                return torch.device(module._hf_hook.execution_device)
+        return self.device

     def _encode_prompt(
         self,
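Community pipelines are fetched from the hub at run time, so this one file must import and construct cleanly on whichever diffusers release the user has installed; the version gate above picks the __init__ whose signature matches that release, and __init__additional__ backfills vae_scale_factor where the base class does not set it. A sketch of loading the pipeline (model id and dtype are illustrative choices, not part of the commit):

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="lpw_stable_diffusion",  # resolves to this community file
    torch_dtype=torch.float16,
).to("cuda")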
@@ -752,37 +823,33 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

         # 8. Denoising loop
-        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(timesteps):
-                # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+        for i, t in enumerate(self.progress_bar(timesteps)):
+            # expand the latents if we are doing classifier free guidance
+            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

-                # predict the noise residual
-                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+            # predict the noise residual
+            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

-                # perform guidance
-                if do_classifier_free_guidance:
-                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+            # perform guidance
+            if do_classifier_free_guidance:
+                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

-                # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            # compute the previous noisy sample x_t -> x_t-1
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

-                if mask is not None:
-                    # masking
-                    init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
-                    latents = (init_latents_proper * mask) + (latents * (1 - mask))
+            if mask is not None:
+                # masking
+                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
+                latents = (init_latents_proper * mask) + (latents * (1 - mask))

-                # call the callback, if provided
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
-                    if i % callback_steps == 0:
-                        if callback is not None:
-                            callback(i, t, latents)
-                        if is_cancelled_callback is not None and is_cancelled_callback():
-                            return None
+            # call the callback, if provided
+            if i % callback_steps == 0:
+                if callback is not None:
+                    callback(i, t, latents)
+                if is_cancelled_callback is not None and is_cancelled_callback():
+                    return None

         # 9. Post-processing
         image = self.decode_latents(latents)
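The simplified loop keeps the pipeline's is_cancelled_callback hook: when the callback returns True, __call__ returns None instead of a pipeline output. A sketch of cooperative cancellation, assuming the pipe object from the loading sketch above:

import time

deadline = time.time() + 60  # illustrative 60-second budget

result = pipe(
    "a photo of an astronaut riding a horse on mars",
    is_cancelled_callback=lambda: time.time() > deadline,
)
if result is None:
    print("generation was cancelled before it finished")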
@@ -5,14 +5,55 @@ from typing import Callable, List, Optional, Union
 import numpy as np
 import torch

+import diffusers
+import PIL
 from diffusers import OnnxStableDiffusionPipeline, SchedulerMixin
-from diffusers.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
+from diffusers.onnx_utils import OnnxRuntimeModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
-from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
+from diffusers.utils import deprecate, logging
+from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPTokenizer

+
+try:
+    from diffusers.onnx_utils import ORT_TO_NP_TYPE
+except ImportError:
+    ORT_TO_NP_TYPE = {
+        "tensor(bool)": np.bool_,
+        "tensor(int8)": np.int8,
+        "tensor(uint8)": np.uint8,
+        "tensor(int16)": np.int16,
+        "tensor(uint16)": np.uint16,
+        "tensor(int32)": np.int32,
+        "tensor(uint32)": np.uint32,
+        "tensor(int64)": np.int64,
+        "tensor(uint64)": np.uint64,
+        "tensor(float16)": np.float16,
+        "tensor(float)": np.float32,
+        "tensor(double)": np.float64,
+    }
+
+try:
+    from diffusers.utils import PIL_INTERPOLATION
+except ImportError:
+    if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+        PIL_INTERPOLATION = {
+            "linear": PIL.Image.Resampling.BILINEAR,
+            "bilinear": PIL.Image.Resampling.BILINEAR,
+            "bicubic": PIL.Image.Resampling.BICUBIC,
+            "lanczos": PIL.Image.Resampling.LANCZOS,
+            "nearest": PIL.Image.Resampling.NEAREST,
+        }
+    else:
+        PIL_INTERPOLATION = {
+            "linear": PIL.Image.LINEAR,
+            "bilinear": PIL.Image.BILINEAR,
+            "bicubic": PIL.Image.BICUBIC,
+            "lanczos": PIL.Image.LANCZOS,
+            "nearest": PIL.Image.NEAREST,
+        }
 # ------------------------------------------------------------------------------

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 re_attention = re.compile(
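The ORT_TO_NP_TYPE fallback reproduces the mapping from ONNX Runtime type strings to numpy dtypes; further down, the pipeline looks up the UNet's declared timestep input type through it before building the timestep array (the timestep_dtype seen in the denoising loop). An isolated sketch of the mapping:

import numpy as np

dtype = ORT_TO_NP_TYPE["tensor(float16)"]  # e.g. what an fp16-exported unet declares
timestep = np.array([999], dtype=dtype)  # 999 is an arbitrary example timestep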
@@ -390,30 +431,59 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
     """

-    def __init__(
-        self,
-        vae_encoder: OnnxRuntimeModel,
-        vae_decoder: OnnxRuntimeModel,
-        text_encoder: OnnxRuntimeModel,
-        tokenizer: CLIPTokenizer,
-        unet: OnnxRuntimeModel,
-        scheduler: SchedulerMixin,
-        safety_checker: OnnxRuntimeModel,
-        feature_extractor: CLIPFeatureExtractor,
-        requires_safety_checker: bool = True,
-    ):
-        super().__init__(
-            vae_encoder=vae_encoder,
-            vae_decoder=vae_decoder,
-            text_encoder=text_encoder,
-            tokenizer=tokenizer,
-            unet=unet,
-            scheduler=scheduler,
-            safety_checker=safety_checker,
-            feature_extractor=feature_extractor,
-            requires_safety_checker=requires_safety_checker,
-        )
+    if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
+
+        def __init__(
+            self,
+            vae_encoder: OnnxRuntimeModel,
+            vae_decoder: OnnxRuntimeModel,
+            text_encoder: OnnxRuntimeModel,
+            tokenizer: CLIPTokenizer,
+            unet: OnnxRuntimeModel,
+            scheduler: SchedulerMixin,
+            safety_checker: OnnxRuntimeModel,
+            feature_extractor: CLIPFeatureExtractor,
+            requires_safety_checker: bool = True,
+        ):
+            super().__init__(
+                vae_encoder=vae_encoder,
+                vae_decoder=vae_decoder,
+                text_encoder=text_encoder,
+                tokenizer=tokenizer,
+                unet=unet,
+                scheduler=scheduler,
+                safety_checker=safety_checker,
+                feature_extractor=feature_extractor,
+                requires_safety_checker=requires_safety_checker,
+            )
+            self.__init__additional__()
+
+    else:
+
+        def __init__(
+            self,
+            vae_encoder: OnnxRuntimeModel,
+            vae_decoder: OnnxRuntimeModel,
+            text_encoder: OnnxRuntimeModel,
+            tokenizer: CLIPTokenizer,
+            unet: OnnxRuntimeModel,
+            scheduler: SchedulerMixin,
+            safety_checker: OnnxRuntimeModel,
+            feature_extractor: CLIPFeatureExtractor,
+        ):
+            super().__init__(
+                vae_encoder=vae_encoder,
+                vae_decoder=vae_decoder,
+                text_encoder=text_encoder,
+                tokenizer=tokenizer,
+                unet=unet,
+                scheduler=scheduler,
+                safety_checker=safety_checker,
+                feature_extractor=feature_extractor,
+            )
+            self.__init__additional__()
+
+    def __init__additional__(self):
+        self.unet_in_channels = 4
+        self.vae_scale_factor = 8
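This mirrors the torch variant: diffusers 0.9.0 added the requires_safety_checker constructor argument, so the file defines whichever __init__ the installed release expects and routes shared setup through __init__additional__. Distilled to its core, the gate is:

import diffusers
from packaging import version

# base_version strips dev/post suffixes, so "0.10.0.dev0" still compares as "0.10.0"
if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
    print("define the __init__ that accepts requires_safety_checker")
else:
    print("define the pre-0.9.0 __init__ without it")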
@@ -741,49 +811,47 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

         # 8. Denoising loop
-        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(timesteps):
-                # expand the latents if we are doing classifier free guidance
-                latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
-                latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
-                latent_model_input = latent_model_input.numpy()
+        for i, t in enumerate(self.progress_bar(timesteps)):
+            # expand the latents if we are doing classifier free guidance
+            latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
+            latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
+            latent_model_input = latent_model_input.numpy()

-                # predict the noise residual
-                noise_pred = self.unet(
-                    sample=latent_model_input,
-                    timestep=np.array([t], dtype=timestep_dtype),
-                    encoder_hidden_states=text_embeddings,
-                )
-                noise_pred = noise_pred[0]
+            # predict the noise residual
+            noise_pred = self.unet(
+                sample=latent_model_input,
+                timestep=np.array([t], dtype=timestep_dtype),
+                encoder_hidden_states=text_embeddings,
+            )
+            noise_pred = noise_pred[0]

-                # perform guidance
-                if do_classifier_free_guidance:
-                    noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+            # perform guidance
+            if do_classifier_free_guidance:
+                noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
+                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

-                # compute the previous noisy sample x_t -> x_t-1
-                scheduler_output = self.scheduler.step(
-                    torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
-                )
-                latents = scheduler_output.prev_sample.numpy()
+            # compute the previous noisy sample x_t -> x_t-1
+            scheduler_output = self.scheduler.step(
+                torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
+            )
+            latents = scheduler_output.prev_sample.numpy()

-                if mask is not None:
-                    # masking
-                    init_latents_proper = self.scheduler.add_noise(
-                        torch.from_numpy(init_latents_orig),
-                        torch.from_numpy(noise),
-                        t,
-                    ).numpy()
-                    latents = (init_latents_proper * mask) + (latents * (1 - mask))
+            if mask is not None:
+                # masking
+                init_latents_proper = self.scheduler.add_noise(
+                    torch.from_numpy(init_latents_orig),
+                    torch.from_numpy(noise),
+                    t,
+                ).numpy()
+                latents = (init_latents_proper * mask) + (latents * (1 - mask))

-                # call the callback, if provided
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
-                    if i % callback_steps == 0:
-                        if callback is not None:
-                            callback(i, t, latents)
-                        if is_cancelled_callback is not None and is_cancelled_callback():
-                            return None
+            # call the callback, if provided
+            if i % callback_steps == 0:
+                if callback is not None:
+                    callback(i, t, latents)
+                if is_cancelled_callback is not None and is_cancelled_callback():
+                    return None

         # 9. Post-processing
         image = self.decode_latents(latents)
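As in the torch pipeline, the reverted loop still fires callback every callback_steps steps and honors is_cancelled_callback. A sketch of a progress callback, assuming pipe was loaded with custom_pipeline="lpw_stable_diffusion_onnx" (an assumption; the ONNX community file is not named in this excerpt):

def on_step(step, timestep, latents):
    # invoked every callback_steps steps; latents is a numpy array in the ONNX pipeline
    print(f"step {step}: timestep {timestep}")

output = pipe(
    "a photo of an astronaut riding a horse on mars",
    callback=on_step,
    callback_steps=5,
)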