# Mirror of https://github.com/huggingface/diffusers.git
# (synced 2026-01-29 07:22:12 +03:00).
#
# Commit note: the `CLIPFeatureExtractor` class has been renamed to `CLIPImageProcessor`
# ahead of the old name's planned deprecation; this commit updates the affected files
# accordingly.
#
# File stats: 262 lines, 11 KiB, Python.
import inspect
|
|
from typing import Callable, List, Optional, Union
|
|
|
|
import torch
|
|
from transformers import (
|
|
CLIPImageProcessor,
|
|
CLIPTextModel,
|
|
CLIPTokenizer,
|
|
WhisperForConditionalGeneration,
|
|
WhisperProcessor,
|
|
)
|
|
|
|
from diffusers import (
|
|
AutoencoderKL,
|
|
DDIMScheduler,
|
|
DiffusionPipeline,
|
|
LMSDiscreteScheduler,
|
|
PNDMScheduler,
|
|
UNet2DConditionModel,
|
|
)
|
|
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
|
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
|
from diffusers.utils import logging
|
|
|
|
|
|
# Module-level logger from diffusers' logging utilities; the pipeline below uses it for the
# disabled-safety-checker warning and the CLIP prompt-truncation warning.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
|
|
|
|
|
|
class SpeechToImagePipeline(DiffusionPipeline):
    r"""
    Pipeline that transcribes speech with a Whisper model and then generates images from the
    transcription with a latent diffusion model (VAE + CLIP text encoder + conditional U-Net).

    Args:
        speech_model (`WhisperForConditionalGeneration`):
            Speech-recognition model whose `generate` output is decoded into the text prompt.
        speech_processor (`WhisperProcessor`):
            Its `feature_extractor` turns raw audio into input features and its `tokenizer`
            decodes the generated token ids back into text.
        vae (`AutoencoderKL`):
            Autoencoder used to decode the final latents into images.
        text_encoder (`CLIPTextModel`):
            Text encoder that embeds the transcribed prompt.
        tokenizer (`CLIPTokenizer`):
            Tokenizer matching `text_encoder`.
        unet (`UNet2DConditionModel`):
            Conditional U-Net that predicts the noise residual.
        scheduler (`DDIMScheduler`, `PNDMScheduler` or `LMSDiscreteScheduler`):
            Scheduler driving the denoising loop.
        safety_checker (`StableDiffusionSafetyChecker`):
            Only checked for `None` (to emit a warning); it is not registered as a module and is
            never run on the generated images.
        feature_extractor (`CLIPImageProcessor`):
            Registered with the pipeline; not used by `__call__`.
    """

    def __init__(
        self,
        speech_model: WhisperForConditionalGeneration,
        speech_processor: WhisperProcessor,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
    ):
        super().__init__()

        if safety_checker is None:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        # NOTE(review): `safety_checker` is deliberately left out of `register_modules` and is never
        # applied in `__call__`, so the warning above fires only for an explicit `None` even though a
        # provided checker is ignored as well — confirm whether it should be registered and run.
        self.register_modules(
            speech_model=speech_model,
            speech_processor=speech_processor,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
        )

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation in the U-Net, trading some speed for lower memory use.

        Args:
            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                `"auto"` uses half the U-Net attention head dimension or, when the U-Net config
                stores one head dimension per block (list/tuple), the smallest of them. Any other
                value is forwarded unchanged to `unet.set_attention_slice`
                (`disable_attention_slicing` passes `None`).
        """
        if slice_size == "auto":
            head_dim = self.unet.config.attention_head_dim
            if isinstance(head_dim, int):
                # Half the attention head size is usually a good trade-off between speed and memory.
                slice_size = head_dim // 2
            else:
                # Per-block head dimensions: `list // 2` would raise, so take the smallest head size.
                slice_size = min(head_dim)
        self.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        r"""Disable sliced attention by forwarding `None` to `enable_attention_slicing`."""
        self.enable_attention_slicing(None)

    @torch.no_grad()
    def __call__(
        self,
        audio,
        sampling_rate=16_000,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        **kwargs,
    ):
        r"""
        Transcribe `audio` with the speech model and generate images from the transcription.

        Args:
            audio:
                Raw audio forwarded to `speech_processor.feature_extractor` (presumably a waveform
                array — TODO confirm accepted formats against the Whisper feature extractor).
            sampling_rate (`int`, defaults to 16000):
                Sampling rate of `audio`, forwarded to the feature extractor.
            height (`int`, defaults to 512):
                Output image height in pixels; must be divisible by 8.
            width (`int`, defaults to 512):
                Output image width in pixels; must be divisible by 8.
            num_inference_steps (`int`, defaults to 50):
                Number of scheduler denoising steps.
            guidance_scale (`float`, defaults to 7.5):
                Classifier-free guidance weight; values `<= 1` disable guidance.
            negative_prompt (`str` or `List[str]`, *optional*):
                Prompt(s) for the unconditional branch of classifier-free guidance. Must have the
                same type as the transcribed prompt (a single `str`).
            num_images_per_prompt (`int`, defaults to 1):
                Number of images generated for the transcribed prompt.
            eta (`float`, defaults to 0.0):
                Forwarded to `scheduler.step` only when that method accepts an `eta` argument.
            generator (`torch.Generator`, *optional*):
                RNG used to sample the initial latents.
            latents (`torch.FloatTensor`, *optional*):
                Pre-sampled initial latents of shape
                `(batch_size * num_images_per_prompt, unet.in_channels, height // 8, width // 8)`;
                sampled from a standard normal distribution when omitted.
            output_type (`str`, defaults to `"pil"`):
                `"pil"` returns PIL images; any other value returns a NumPy array.
            return_dict (`bool`, defaults to `True`):
                Whether to wrap the images in a `StableDiffusionPipelineOutput`.
            callback (`Callable`, *optional*):
                Called as `callback(step_index, timestep, latents)` every `callback_steps` steps.
            callback_steps (`int`, defaults to 1):
                Positive interval between `callback` invocations.
            **kwargs:
                Accepted and ignored.

        Returns:
            A `StableDiffusionPipelineOutput` (with `nsfw_content_detected=None`) when `return_dict`
            is true, otherwise the images alone.

        Raises:
            ValueError: if `height`/`width` are not divisible by 8, `callback_steps` is not a positive
                `int`, or `latents` has an unexpected shape.
            TypeError: if `negative_prompt` is not of the same type as the transcribed prompt.
        """
        # Validate the prompt-independent arguments before running the speech model so invalid
        # input fails fast instead of after an expensive transcription pass.
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # `isinstance` already rejects `None`, so no separate `None` check is needed.
        if not isinstance(callback_steps, int) or callback_steps <= 0:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        # 1. Transcribe the audio into a text prompt.
        input_features = self.speech_processor.feature_extractor(
            audio, return_tensors="pt", sampling_rate=sampling_rate
        ).input_features.to(self.device)
        # NOTE(review): `max_length` bounds the number of generated tokens; 480_000 equals 30 s of
        # 16 kHz audio and may have been meant as a sample count — confirm against the decoder limit.
        predicted_ids = self.speech_model.generate(input_features, max_length=480_000)
        # Only the first decoded transcription is kept.
        prompt = self.speech_processor.tokenizer.batch_decode(
            predicted_ids, skip_special_tokens=True, normalize=True
        )[0]

        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        # 2. Encode the prompt with the CLIP text encoder.
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )
            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

        # Duplicate the text embeddings for each image per prompt (repeat + view is mps friendly).
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # `guidance_scale` is analogous to the guidance weight `w` of equation (2) of the Imagen
        # paper (https://arxiv.org/pdf/2205.11487.pdf); `guidance_scale = 1` means no guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        if do_classifier_free_guidance:
            # Build the unconditional token list for classifier-free guidance.
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

            # Duplicate the unconditional embeddings per image, again with the mps-friendly method.
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

            # Concatenate unconditional and text embeddings into one batch so a single U-Net
            # forward pass serves both guidance branches.
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # 3. Prepare the initial latents. Unlike in other pipelines, they are generated on the
        # target device for 1-to-1 reproducibility with the CompVis implementation; this does
        # not work on `mps`, so sampling happens on CPU there.
        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
        latents_dtype = text_embeddings.dtype
        if latents is None:
            if self.device.type == "mps":
                # randn does not exist on mps
                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
                    self.device
                )
            else:
                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
            latents = latents.to(self.device)

        # 4. Configure the scheduler.
        self.scheduler.set_timesteps(num_inference_steps)
        # Some schedulers (e.g. PNDM) keep timesteps as arrays; move them to the device once up front.
        timesteps_tensor = self.scheduler.timesteps.to(self.device)

        # Scale the initial noise by the standard deviation the scheduler expects.
        latents = latents * self.scheduler.init_noise_sigma

        # eta (η) is only forwarded to schedulers whose `step` accepts it (DDIM,
        # https://arxiv.org/abs/2010.02502); it should lie in [0, 1].
        accepts_eta = "eta" in inspect.signature(self.scheduler.step).parameters
        extra_step_kwargs = {"eta": eta} if accepts_eta else {}

        # 5. Denoising loop.
        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
            # Expand the latents when doing classifier-free guidance (uncond + text batch).
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # Predict the noise residual.
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # Combine the two predictions according to the guidance weight.
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # Compute the previous noisy sample x_t -> x_t-1.
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

        # 6. Decode the latents. Prefer the scaling factor stored in the VAE config when it exists,
        # falling back to the constant 0.18215.
        vae_scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215)
        latents = 1 / vae_scaling_factor * latents
        image = self.vae.decode(latents).sample

        # Map decoder output from [-1, 1] to [0, 1].
        image = (image / 2 + 0.5).clamp(0, 1)

        # Always cast to float32: it adds little overhead and is compatible with bfloat16.
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return image

        # The safety checker is never run here, so no NSFW flags can be reported.
        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
|