mirror of https://github.com/huggingface/diffusers.git
test the full pipeline
@@ -28,11 +28,11 @@ from diffusers import (
    EulerDiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
    VersatileDiffusionPipeline,
)
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer
from transformers import CLIPProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection


SCHEDULER_CONFIG = Namespace(
@@ -44,7 +44,7 @@ SCHEDULER_CONFIG = Namespace(
    }
)

UNET_IMAGE_CONFIG = Namespace(
IMAGE_UNET_CONFIG = Namespace(
    **{
        "input_channels": 4,
        "model_channels": 320,
@@ -58,7 +58,7 @@ UNET_IMAGE_CONFIG = Namespace(
    }
)

UNET_TEXT_CONFIG = Namespace(
TEXT_UNET_CONFIG = Namespace(
    **{
        "input_channels": 768,
        "model_channels": 320,
@@ -750,21 +750,20 @@ if __name__ == "__main__":

    # Convert the UNet2DConditionModel model.
    if args.unet_checkpoint_path is not None:
        unet_image_config = create_unet_diffusers_config(UNET_IMAGE_CONFIG)
        image_unet_config = create_unet_diffusers_config(IMAGE_UNET_CONFIG)
        checkpoint = torch.load(args.unet_checkpoint_path)
        converted_unet_image_checkpoint = convert_vd_unet_checkpoint(
            checkpoint, unet_image_config, unet_key="model.diffusion_model.unet_image.", extract_ema=args.extract_ema
        converted_image_unet_checkpoint = convert_vd_unet_checkpoint(
            checkpoint, image_unet_config, unet_key="model.diffusion_model.unet_image.", extract_ema=args.extract_ema
        )
        unet_image = UNet2DConditionModel(**unet_image_config)
        unet_image.load_state_dict(converted_unet_image_checkpoint)
        unet_image.save_pretrained(os.path.join(args.dump_path, "unet_image"))
        image_unet = UNet2DConditionModel(**image_unet_config)
        image_unet.load_state_dict(converted_image_unet_checkpoint)

    # unet_text_config = create_unet_diffusers_config(UNET_TEXT_CONFIG)
    # converted_unet_text_checkpoint = convert_vd_unet_checkpoint(
    #     checkpoint, unet_text_config, unet_key="model.diffusion_model.unet_text.", extract_ema=args.extract_ema
    # text_unet_config = create_unet_diffusers_config(TEXT_UNET_CONFIG)
    # converted_text_unet_checkpoint = convert_vd_unet_checkpoint(
    #     checkpoint, text_unet_config, unet_key="model.diffusion_model.unet_text.", extract_ema=args.extract_ema
    # )
    # unet_text = UNet2DConditionModel(**unet_text_config)
    # unet_text.load_state_dict(converted_unet_text_checkpoint)
    # text_unet = UNet2DConditionModel(**text_unet_config)
    # text_unet.load_state_dict(converted_text_unet_checkpoint)

    # Convert the VAE model.
    if args.vae_checkpoint_path is not None:
@@ -774,28 +773,20 @@ if __name__ == "__main__":

        vae = AutoencoderKL(**vae_config)
        vae.load_state_dict(converted_vae_checkpoint)
        vae.save_pretrained(os.path.join(args.dump_path, "vae"))

    # Convert the text model.
    # text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
    # if text_model_type == "FrozenCLIPEmbedder":
    #     text_model = convert_ldm_clip_checkpoint(checkpoint)
    #     tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    #     safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
    #     feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
    #     pipe = StableDiffusionPipeline(
    #         vae=vae,
    #         text_encoder=text_model,
    #         tokenizer=tokenizer,
    #         unet=unet,
    #         scheduler=scheduler,
    #         safety_checker=safety_checker,
    #         feature_extractor=feature_extractor,
    #     )
    # else:
    #     text_config = create_ldm_bert_config(original_config)
    #     text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
    #     tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
    #     pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
    #
    # pipe.save_pretrained(args.dump_path)
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    image_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
    image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")

    pipe = VersatileDiffusionPipeline(
        scheduler=scheduler,
        tokenizer=tokenizer,
        image_processor=image_processor,
        text_encoder=text_encoder,
        image_encoder=image_encoder,
        image_unet=image_unet,
        # text_unet=text_unet,
        vae=vae,
    )
    pipe.save_pretrained(args.dump_path)
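
For orientation, a minimal sketch of how the directory written by `pipe.save_pretrained(args.dump_path)` is meant to be loaded back; the `scripts/vd-diffusers` path and the sampling settings are borrowed from the integration test further down, and the CUDA device is an assumption.

import torch

from diffusers import VersatileDiffusionPipeline

# Load the converted dump produced by the conversion step above
# (the integration test below uses "scripts/vd-diffusers" as the dump path).
pipe = VersatileDiffusionPipeline.from_pretrained("scripts/vd-diffusers")
pipe.to("cuda")  # assumption: a CUDA device is available

# Text-to-image sampling through the image_unet branch;
# guidance_scale > 1 enables classifier-free guidance in the pipeline's __call__.
generator = torch.manual_seed(0)
image = pipe(
    "A painting of a squirrel eating a burger",
    generator=generator,
    guidance_scale=7.5,
    num_inference_steps=50,
).images[0]
image.save("squirrel.png")
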
@@ -71,6 +71,7 @@ if is_torch_available() and is_transformers_available():
        StableDiffusionInpaintPipeline,
        StableDiffusionInpaintPipelineLegacy,
        StableDiffusionPipeline,
        VersatileDiffusionPipeline,
        VQDiffusionPipeline,
    )
else:

@@ -23,6 +23,7 @@ if is_torch_available() and is_transformers_available():
        StableDiffusionInpaintPipelineLegacy,
        StableDiffusionPipeline,
    )
    from .versatile_diffusion import VersatileDiffusionPipeline
    from .vq_diffusion import VQDiffusionPipeline

if is_transformers_available() and is_onnx_available():
src/diffusers/pipelines/versatile_diffusion/__init__.py (new file, +1)
@@ -0,0 +1 @@
from .pipeline_versatile_diffusion import VersatileDiffusionPipeline
@@ -0,0 +1,215 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint

from transformers import CLIPProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection

from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler


class VersatileDiffusionPipeline(DiffusionPipeline):
    r"""
    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Parameters:
        tokenizer (`transformers.CLIPTokenizer`):
            Tokenizer used to tokenize the text prompt for the text encoder.
        image_processor (`transformers.CLIPProcessor`):
            Processor used to prepare input images for the image encoder.
        text_encoder (`transformers.CLIPTextModelWithProjection`):
            Frozen CLIP text encoder whose projected hidden states condition `image_unet`.
        image_encoder (`transformers.CLIPVisionModelWithProjection`):
            Frozen CLIP image encoder used for image conditioning.
        image_unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        vae ([`AutoencoderKL`] or [`VQModel`]):
            Auto-encoder model to encode and decode images to and from latent representations.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `image_unet` to denoise the encoded image latents. Can be one
            of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
    """
    tokenizer: CLIPTokenizer
    image_processor: CLIPProcessor
    text_encoder: CLIPTextModelWithProjection
    image_encoder: CLIPVisionModelWithProjection
    image_unet: UNet2DConditionModel
    vae: Union[VQModel, AutoencoderKL]
    scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]

    def __init__(
        self,
        tokenizer: CLIPTokenizer,
        image_processor: CLIPProcessor,
        text_encoder: CLIPTextModelWithProjection,
        image_encoder: CLIPVisionModelWithProjection,
        image_unet: UNet2DConditionModel,
        vae: Union[VQModel, AutoencoderKL],
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
    ):
        super().__init__()
        self.register_modules(
            tokenizer=tokenizer,
            image_processor=image_processor,
            text_encoder=text_encoder,
            image_encoder=image_encoder,
            image_unet=image_unet,
            vae=vae,
            scheduler=scheduler,
        )

    def _encode_prompt(self, prompt, do_classifier_free_guidance):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`):
                prompt to be encoded
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
        """

        def _normalize_embeddings(encoder_output):
            # project the hidden states and scale them by the norm of the pooled, projected text embedding
            embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state)
            embeds_pooled = encoder_output.text_embeds
            embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True)
            return embeds

        batch_size = len(prompt) if isinstance(prompt, list) else 1

        if do_classifier_free_guidance:
            uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))
            uncond_embeddings = _normalize_embeddings(uncond_embeddings)

        # get prompt text embeddings
        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))
        text_embeddings = _normalize_embeddings(text_embeddings)

        if do_classifier_free_guidance:
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        return text_embeddings

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 1.0,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        **kwargs,
    ) -> Union[Tuple, ImagePipelineOutput]:
        r"""
        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 1.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only passed on
                to schedulers whose `step` method accepts `eta`, and ignored otherwise.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.

        Returns:
            [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
            generated images.
        """

        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        do_classifier_free_guidance = guidance_scale > 1.0

        text_embeddings = self._encode_prompt(prompt, do_classifier_free_guidance)

        latents = torch.randn(
            (batch_size, self.image_unet.in_channels, height // 8, width // 8),
            generator=generator,
        )
        latents = latents.to(self.device)

        self.scheduler.set_timesteps(num_inference_steps)

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())

        extra_kwargs = {}
        if accepts_eta:
            extra_kwargs["eta"] = eta

        for t in self.progress_bar(self.scheduler.timesteps):
            if not do_classifier_free_guidance:
                latents_input = latents
            else:
                # duplicate the latents so the unconditional and conditional predictions share one forward pass
                latents_input = torch.cat([latents] * 2)

            # predict the noise residual
            noise_pred = self.image_unet(latents_input, t, encoder_hidden_states=text_embeddings).sample
            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample

        # scale and decode the image latents with vae
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
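
The guidance step in the loop above is the standard classifier-free guidance combination. A small self-contained sketch of just that arithmetic, using dummy tensors in place of the `image_unet` output (shapes and values here are illustrative assumptions, not taken from the checkpoint):

import torch

guidance_scale = 7.5  # acts as the weight `w` from the Imagen formulation cited in the docstring
noise_pred = torch.randn(2, 4, 64, 64)  # stand-in for the U-Net output on a doubled (uncond + text) batch

# split the doubled batch and push the prediction away from the unconditional branch
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
assert noise_pred.shape == (1, 4, 64, 64)
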
tests/pipelines/versatile_diffusion/__init__.py (new empty file)
@@ -0,0 +1,52 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import torch

from diffusers import VersatileDiffusionPipeline
from diffusers.utils.testing_utils import require_torch, slow, torch_device

from ...test_pipelines_common import PipelineTesterMixin


torch.backends.cuda.matmul.allow_tf32 = False


class VersatileDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pass


@slow
@require_torch
class VersatileDiffusionPipelineIntegrationTests(unittest.TestCase):
    def test_inference_text2img(self):
        pipe = VersatileDiffusionPipeline.from_pretrained("scripts/vd-diffusers")
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.manual_seed(0)
        image = pipe(
            [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=50, output_type="numpy"
        ).images

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 256, 256, 3)
        expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2