diff --git a/scripts/convert_versatile_diffusion_to_diffusers.py b/scripts/convert_versatile_diffusion_to_diffusers.py
index 825c31970f..27fe7c1ffa 100644
--- a/scripts/convert_versatile_diffusion_to_diffusers.py
+++ b/scripts/convert_versatile_diffusion_to_diffusers.py
@@ -28,11 +28,11 @@ from diffusers import (
     EulerDiscreteScheduler,
     LMSDiscreteScheduler,
     PNDMScheduler,
-    StableDiffusionPipeline,
     UNet2DConditionModel,
+    VersatileDiffusionPipeline,
 )
 from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
-from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
 
 
 SCHEDULER_CONFIG = Namespace(
@@ -44,7 +44,7 @@ SCHEDULER_CONFIG = Namespace(
     }
 )
 
-UNET_IMAGE_CONFIG = Namespace(
+IMAGE_UNET_CONFIG = Namespace(
     **{
         "input_channels": 4,
         "model_channels": 320,
@@ -58,7 +58,7 @@ UNET_IMAGE_CONFIG = Namespace(
     }
 )
 
-UNET_TEXT_CONFIG = Namespace(
+TEXT_UNET_CONFIG = Namespace(
     **{
         "input_channels": 768,
         "model_channels": 320,
@@ -750,21 +750,20 @@ if __name__ == "__main__":
 
     # Convert the UNet2DConditionModel model.
     if args.unet_checkpoint_path is not None:
-        unet_image_config = create_unet_diffusers_config(UNET_IMAGE_CONFIG)
+        image_unet_config = create_unet_diffusers_config(IMAGE_UNET_CONFIG)
         checkpoint = torch.load(args.unet_checkpoint_path)
-        converted_unet_image_checkpoint = convert_vd_unet_checkpoint(
-            checkpoint, unet_image_config, unet_key="model.diffusion_model.unet_image.", extract_ema=args.extract_ema
+        converted_image_unet_checkpoint = convert_vd_unet_checkpoint(
+            checkpoint, image_unet_config, unet_key="model.diffusion_model.unet_image.", extract_ema=args.extract_ema
         )
-        unet_image = UNet2DConditionModel(**unet_image_config)
-        unet_image.load_state_dict(converted_unet_image_checkpoint)
-        unet_image.save_pretrained(os.path.join(args.dump_path, "unet_image"))
+        image_unet = UNet2DConditionModel(**image_unet_config)
+        image_unet.load_state_dict(converted_image_unet_checkpoint)
 
-        # unet_text_config = create_unet_diffusers_config(UNET_TEXT_CONFIG)
-        # converted_unet_text_checkpoint = convert_vd_unet_checkpoint(
-        #     checkpoint, unet_text_config, unet_key="model.diffusion_model.unet_text.", extract_ema=args.extract_ema
+        # text_unet_config = create_unet_diffusers_config(TEXT_UNET_CONFIG)
+        # converted_text_unet_checkpoint = convert_vd_unet_checkpoint(
+        #     checkpoint, text_unet_config, unet_key="model.diffusion_model.unet_text.", extract_ema=args.extract_ema
         # )
-        # unet_text = UNet2DConditionModel(**unet_text_config)
-        # unet_text.load_state_dict(converted_unet_text_checkpoint)
+        # text_unet = UNet2DConditionModel(**text_unet_config)
+        # text_unet.load_state_dict(converted_text_unet_checkpoint)
 
     # Convert the VAE model.
     if args.vae_checkpoint_path is not None:
@@ -774,28 +773,20 @@ if __name__ == "__main__":
 
         vae = AutoencoderKL(**vae_config)
         vae.load_state_dict(converted_vae_checkpoint)
-        vae.save_pretrained(os.path.join(args.dump_path, "vae"))
 
-    # Convert the text model.
-# text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
-# if text_model_type == "FrozenCLIPEmbedder":
-# text_model = convert_ldm_clip_checkpoint(checkpoint)
-# tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
-# safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
-# feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
-# pipe = StableDiffusionPipeline(
-# vae=vae,
-# text_encoder=text_model,
-# tokenizer=tokenizer,
-# unet=unet,
-# scheduler=scheduler,
-# safety_checker=safety_checker,
-# feature_extractor=feature_extractor,
-# )
-# else:
-# text_config = create_ldm_bert_config(original_config)
-# text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
-# tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
-# pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
-#
-# pipe.save_pretrained(args.dump_path)
+    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+    image_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+    text_encoder = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
+    image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
+
+    pipe = VersatileDiffusionPipeline(
+        scheduler=scheduler,
+        tokenizer=tokenizer,
+        image_processor=image_processor,
+        text_encoder=text_encoder,
+        image_encoder=image_encoder,
+        image_unet=image_unet,
+        # text_unet=text_unet,
+        vae=vae,
+    )
+    pipe.save_pretrained(args.dump_path)
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 86eda7371f..19558334af 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -71,6 +71,7 @@ if is_torch_available() and is_transformers_available():
         StableDiffusionInpaintPipeline,
         StableDiffusionInpaintPipelineLegacy,
         StableDiffusionPipeline,
+        VersatileDiffusionPipeline,
         VQDiffusionPipeline,
     )
 else:
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index ef4d23e5e6..abb09605e3 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -23,6 +23,7 @@ if is_torch_available() and is_transformers_available():
         StableDiffusionInpaintPipelineLegacy,
         StableDiffusionPipeline,
     )
+    from .versatile_diffusion import VersatileDiffusionPipeline
     from .vq_diffusion import VQDiffusionPipeline
 
 if is_transformers_available() and is_onnx_available():
diff --git a/src/diffusers/pipelines/versatile_diffusion/__init__.py b/src/diffusers/pipelines/versatile_diffusion/__init__.py
new file mode 100644
index 0000000000..cd63bbfc28
--- /dev/null
+++ b/src/diffusers/pipelines/versatile_diffusion/__init__.py
@@ -0,0 +1 @@
+from .pipeline_versatile_diffusion import VersatileDiffusionPipeline
diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py
new file mode 100644
index 0000000000..3e56e4b57e
--- /dev/null
+++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py
@@ -0,0 +1,215 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+
+from transformers import CLIPProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+
+
+class VersatileDiffusionPipeline(DiffusionPipeline):
+    r"""
+    Pipeline for text-to-image generation using Versatile Diffusion.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+    Parameters:
+        tokenizer (`transformers.CLIPTokenizer`):
+            Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTokenizer).
+        image_processor (`transformers.CLIPProcessor`):
+            Processor that prepares conditioning images for the image encoder.
+        text_encoder (`transformers.CLIPTextModelWithProjection`):
+            Frozen CLIP text encoder used to condition the image U-Net on text prompts.
+        image_encoder (`transformers.CLIPVisionModelWithProjection`):
+            Frozen CLIP image encoder, registered alongside the text encoder for image conditioning.
+        image_unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+        vae ([`VQModel`] or [`AutoencoderKL`]):
+            Model to encode and decode images to and from latent representations.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `image_unet` to denoise the encoded image latents. Can be one
+            of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+    """
+    tokenizer: CLIPTokenizer
+    image_processor: CLIPProcessor
+    text_encoder: CLIPTextModelWithProjection
+    image_encoder: CLIPVisionModelWithProjection
+    image_unet: UNet2DConditionModel
+    vae: Union[VQModel, AutoencoderKL]
+    scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
+
+    def __init__(
+        self,
+        tokenizer: CLIPTokenizer,
+        image_processor: CLIPProcessor,
+        text_encoder: CLIPTextModelWithProjection,
+        image_encoder: CLIPVisionModelWithProjection,
+        image_unet: UNet2DConditionModel,
+        vae: Union[VQModel, AutoencoderKL],
+        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+    ):
+        super().__init__()
+        self.register_modules(
+            tokenizer=tokenizer,
+            image_processor=image_processor,
+            text_encoder=text_encoder,
+            image_encoder=image_encoder,
+            image_unet=image_unet,
+            vae=vae,
+            scheduler=scheduler,
+        )
+
+    def _encode_prompt(self, prompt, do_classifier_free_guidance):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `List[str]`):
+                prompt to be encoded
+            do_classifier_free_guidance (`bool`):
+                whether to use classifier-free guidance or not
+        """
+
+        def _normalize_embeddings(encoder_output):
+            embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state)
+            embeds_pooled = encoder_output.text_embeds
+            embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True)
+            return embeds
+
+        batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+        if do_classifier_free_guidance:
+            uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
+            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))
+            uncond_embeddings = _normalize_embeddings(uncond_embeddings)
+
+        # get prompt text embeddings
+        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
+        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))
+        text_embeddings = _normalize_embeddings(text_embeddings)
+
+        if do_classifier_free_guidance:
+            # For classifier-free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+        return text_embeddings
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt: Union[str, List[str]],
+        height: Optional[int] = 512,
+        width: Optional[int] = 512,
+        num_inference_steps: Optional[int] = 50,
+        guidance_scale: Optional[float] = 1.0,
+        eta: Optional[float] = 0.0,
+        generator: Optional[torch.Generator] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        **kwargs,
+    ) -> Union[Tuple, ImagePipelineOutput]:
+        r"""
+        Args:
+            prompt (`str` or `List[str]`):
+                The prompt or prompts to guide the image generation.
+            height (`int`, *optional*, defaults to 512):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to 512):
+                The width in pixels of the generated image.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 1.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages the model to generate images that are closely linked to the text
+                `prompt`, usually at the expense of lower image quality.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the [DDIM paper](https://arxiv.org/abs/2010.02502). Only applies
+                to [`DDIMScheduler`] and is ignored by schedulers whose `step` does not accept it.
+            generator (`torch.Generator`, *optional*):
+                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+                deterministic.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipeline_utils.ImagePipelineOutput`] if
+            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+            generated images.
+ """ + + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + do_classifier_free_guidance = guidance_scale > 1.0 + + text_embeddings = self._encode_prompt(prompt, do_classifier_free_guidance) + + latents = torch.randn( + (batch_size, self.image_unet.in_channels, height // 8, width // 8), + generator=generator, + ) + latents = latents.to(self.device) + + self.scheduler.set_timesteps(num_inference_steps) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + + extra_kwargs = {} + if accepts_eta: + extra_kwargs["eta"] = eta + + for t in self.progress_bar(self.scheduler.timesteps): + if not do_classifier_free_guidance: + latents_input = latents + else: + latents_input = torch.cat([latents] * 2) + + # predict the noise residual + noise_pred = self.image_unet(latents_input, t, encoder_hidden_states=text_embeddings).sample + # perform guidance + if guidance_scale != 1.0: + noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample + + # scale and decode the image latents with vae + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/tests/pipelines/versatile_diffusion/__init__.py b/tests/pipelines/versatile_diffusion/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/pipelines/versatile_diffusion/test_versatile_diffusion.py b/tests/pipelines/versatile_diffusion/test_versatile_diffusion.py new file mode 100644 index 0000000000..e0f7618034 --- /dev/null +++ b/tests/pipelines/versatile_diffusion/test_versatile_diffusion.py @@ -0,0 +1,52 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import unittest
+
+import numpy as np
+import torch
+
+from diffusers import VersatileDiffusionPipeline
+from diffusers.utils.testing_utils import require_torch, slow, torch_device
+
+from ...test_pipelines_common import PipelineTesterMixin
+
+
+torch.backends.cuda.matmul.allow_tf32 = False
+
+
+class VersatileDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+    pass
+
+
+@slow
+@require_torch
+class VersatileDiffusionPipelineIntegrationTests(unittest.TestCase):
+    def test_inference_text2img(self):
+        pipe = VersatileDiffusionPipeline.from_pretrained("scripts/vd-diffusers")
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+        generator = torch.manual_seed(0)
+        image = pipe(
+            [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=50, output_type="numpy"
+        ).images
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 512, 512, 3)
+        expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
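
Note for reviewers: a minimal text-to-image sketch of the converted pipeline, useful for a quick end-to-end sanity check. It assumes the conversion script has already written a local `scripts/vd-diffusers` directory (the same local path used by the integration test above); the path, prompt, and generation settings are illustrative only, not part of this patch.

import torch

from diffusers import VersatileDiffusionPipeline

# Load the converted weights from the local dump_path produced by
# scripts/convert_versatile_diffusion_to_diffusers.py (the directory name is an assumption).
pipe = VersatileDiffusionPipeline.from_pretrained("scripts/vd-diffusers")
pipe.to("cuda" if torch.cuda.is_available() else "cpu")

# guidance_scale > 1 enables the classifier-free guidance branch in __call__.
generator = torch.manual_seed(0)
image = pipe(
    "A painting of a squirrel eating a burger",
    guidance_scale=7.5,
    num_inference_steps=50,
    generator=generator,
).images[0]
image.save("squirrel.png")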