
Würstchen model (#3849)

* initial

* initial

* added initial convert script for paella vqmodel

* initial wuerstchen pipeline

* add LayerNorm2d

* added modules

* fix typo

* use model_v2

* embed clip caption and negative_caption

* fixed name of var

* initial modules in one place

* WuerstchenPriorPipeline

* initial shape

* initial denoising prior loop

* fix output

* add WuerstchenPriorPipeline to __init__.py

* use the noise ratio in the Prior

* try to save pipeline

* save_pretrained working

* Few additions

* add _execution_device

* shape is int

* fix batch size

* fix shape of ratio

* fix shape of ratio

* fix output dataclass

* tests folder

* fix formatting

* fix float16 + started with generator

* Update pipeline_wuerstchen.py

* removed vqgan code

* add WuerstchenGeneratorPipeline

* fix WuerstchenGeneratorPipeline

* fix docstrings

* fix imports

* convert generator pipeline

* fix convert

* Work on Generator Pipeline. WIP

* Pipeline works with our diffuzz code

* apply scale factor

* removed vqgan.py

* use cosine schedule

* redo the denoising loop

* Update src/diffusers/models/resnet.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* use torch.lerp

* use warp-diffusion org

* clip_sample=False,

* some refactoring

* use model_v3_stage_c

* c_cond size

* use clip-bigG

* allow stage b clip to be None

* add dummy

* würstchen scheduler

* minor changes

* set clip=None in the pipeline

* fix attention mask

* add attention_masks to text_encoder

* make fix-copies

* add back clip

* add text_encoder

* gen_text_encoder and tokenizer

* fix import

* updated pipeline test

* undo changes to pipeline test

* nip

* fix typo

* fix output name

* set guidance_scale=0 and remove diffuzz

* fix doc strings

* make style

* nip

* removed unused

* initial docs

* rename

* toc

* cleanup

* remove test script

* fix-copies

* fix multi images

* remove dup

* remove unused modules

* undo changes for debugging

* no new line

* remove dup conversion script

* fix doc string

* cleanup

* pass default args

* dup permute

* fix some tests

* fix prepare_latents

* move Prior class to modules

* offload only the text encoder and vqgan

* fix resolution calculation for prior

* nip

* removed testing script

* fix shape

* fix argument to set_timesteps

* do not change .gitignore

* fix resolution calculations + readme

* resolution calculation fix + readme

* small fixes

* Add combined pipeline

* rename generator -> decoder

* Update .gitignore

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* removed efficient_net

* create combined WuerstchenPipeline

* make arguments consistent with VQ model

* fix var names

* no need to return text_encoder_hidden_states

* add latent_dim_scale to config

* split model into its own file

* add WuerschenPipeline to docs

* remove unused latent_size

* register latent_dim_scale

* update script

* update docstring

* use Attention preprocessor

* concat with normed input

* fix-copies

* add docs

* fix test

* fix style

* add to cpu_offloaded_model

* updated type

* remove 1-line func

* updated type

* initial decoder test

* formatting

* formatting

* fix autodoc link

* num_inference_steps is int

* remove comments

* fix example in docs

* Update src/diffusers/pipelines/wuerstchen/diffnext.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* rename layernorm to WuerstchenLayerNorm

* rename DiffNext to WuerstchenDiffNeXt

* added comment about MixingResidualBlock

* move paella vq-vae to pipelines' folder

* initial decoder test

* increased test_float16_inference expected diff

* self_attn is always true

* more passing decoder tests

* batch image_embeds

* fix failing tests

* set the correct dtype

* relax inference test

* update prior

* added combined pipeline test

* faster test

* faster test

* Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix issues from review

* update wuerstchen.md + change generator name

* resolve issues

* fix copied from usage and add back batch_size

* fix API

* fix arguments

* fix combined test

* Added timesteps argument + fixes

* Update tests/pipelines/test_pipelines_common.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update tests/pipelines/wuerstchen/test_wuerstchen_prior.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py

* up

* Fix more

* failing tests

* up

* up

* correct naming

* correct docs

* correct docs

* fix test params

* correct docs

* fix classifier free guidance

* fix classifier free guidance

* fix more

* fix all

* make tests faster

---------

Co-authored-by: Dominic Rampas <d6582533@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Dominic Rampas <61938694+dome272@users.noreply.github.com>
This commit is contained in:
Kashif Rasul
2023-09-06 16:15:51 +02:00
committed by GitHub
parent b76274cb53
commit 541bb6ee63
26 changed files with 2838 additions and 15 deletions

View File

@@ -310,6 +310,8 @@
title: Versatile Diffusion
- local: api/pipelines/vq_diffusion
title: VQ Diffusion
- local: api/pipelines/wuerstchen
title: Wuerstchen
title: Pipelines
- sections:
- local: api/schedulers/overview

View File

@@ -0,0 +1,140 @@
# Würstchen
<img src="https://github.com/dome272/Wuerstchen/assets/61938694/0617c863-165a-43ee-9303-2a17299a0cf9">
[Würstchen: Efficient Pretraining of Text-to-Image Models](https://huggingface.co/papers/2306.00637) is by Pablo Pernias, Dominic Rampas, and Marc Aubreville.
The abstract from the paper is:
*We introduce Würstchen, a novel technique for text-to-image synthesis that unites competitive performance with unprecedented cost-effectiveness and ease of training on constrained hardware. Building on recent advancements in machine learning, our approach, which utilizes latent diffusion strategies at strong latent image compression rates, significantly reduces the computational burden, typically associated with state-of-the-art models, while preserving, if not enhancing, the quality of generated images. Wuerstchen achieves notable speed improvements at inference time, thereby rendering real-time applications more viable. One of the key advantages of our method lies in its modest training requirements of only 9,200 GPU hours, slashing the usual costs significantly without compromising the end performance. In a comparison against the state-of-the-art, we found the approach to yield strong competitiveness. This paper opens the door to a new line of research that prioritizes both performance and computational accessibility, hence democratizing the use of sophisticated AI technologies. Through Wuerstchen, we demonstrate a compelling stride forward in the realm of text-to-image synthesis, offering an innovative path to explore in future research.*
## Würstchen v2 comes to Diffusers
After the initial paper release, we have improved numerous things in the architecture, training and sampling, making Würstchen competitive with current state-of-the-art models in many ways. We are excited to release this new version together with Diffusers. Here is a list of the improvements:
- Higher resolution (1024x1024 up to 2048x2048)
- Faster inference
- Multi Aspect Resolution Sampling
- Better quality
We are releasing 3 checkpoints for the text-conditional image generation model (Stage C). Those are:
- v2-base
- v2-aesthetic
- v2-interpolated (50% interpolation between v2-base and v2-aesthetic)
We recommend using v2-interpolated, as it offers a nice touch of both photorealism and aesthetics. Use v2-base for fine-tuning, as it does not have a style bias, and use v2-aesthetic for very artistic generations.
A comparison can be seen here:
<img src="https://github.com/dome272/Wuerstchen/assets/61938694/2914830f-cbd3-461c-be64-d50734f4b49d" width=500>
## Text-to-Image Generation
For the sake of usability, Würstchen can be used with a single pipeline. This pipeline is called `WuerstchenCombinedPipeline` and can be used as follows:
```python
import torch
from diffusers import AutoPipelineForText2Image
device = "cuda"
dtype = torch.float16
num_images_per_prompt = 2
pipeline = AutoPipelineForText2Image.from_pretrained(
"warp-diffusion/wuerstchen", torch_dtype=dtype
).to(device)
caption = "Anthropomorphic cat dressed as a fire fighter"
negative_prompt = ""
output = pipeline(
prompt=caption,
height=1024,
width=1024,
negative_prompt=negative_prompt,
prior_guidance_scale=4.0,
decoder_guidance_scale=0.0,
num_images_per_prompt=num_images_per_prompt,
output_type="pil",
).images
```
For explanation purposes, we can also initialize the two main pipelines of Würstchen individually. Würstchen consists of 3 stages: Stage C, Stage B and Stage A. They all have different jobs and only work together. When generating text-conditional images, Stage C first generates the latents in a very compressed latent space. This is what happens in the `prior_pipeline`. Afterwards, the generated latents are passed to Stage B, which decompresses the latents into the larger latent space of a VQGAN. These latents can then be decoded by Stage A, which is a VQGAN, into pixel space. Stage B & Stage A are both encapsulated in the `decoder_pipeline`. For more details, take a look at the [paper](https://huggingface.co/papers/2306.00637).
```python
import torch
from diffusers import WuerstchenDecoderPipeline, WuerstchenPriorPipeline
device = "cuda"
dtype = torch.float16
num_images_per_prompt = 2
prior_pipeline = WuerstchenPriorPipeline.from_pretrained(
"warp-diffusion/wuerstchen-prior", torch_dtype=dtype
).to(device)
decoder_pipeline = WuerstchenDecoderPipeline.from_pretrained(
"warp-diffusion/wuerstchen", torch_dtype=dtype
).to(device)
caption = "A captivating artwork of a mysterious stone golem"
negative_prompt = ""
prior_output = prior_pipeline(
prompt=caption,
height=1024,
width=1024,
negative_prompt=negative_prompt,
guidance_scale=4.0,
num_images_per_prompt=num_images_per_prompt,
)
decoder_output = decoder_pipeline(
image_embeddings=prior_output.image_embeddings,
prompt=caption,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
guidance_scale=0.0,
output_type="pil",
).images
```
## Speed-Up Inference
You can make use of the `torch.compile` function and gain a speed-up of about 2-3x:
```python
pipeline.prior_prior = torch.compile(pipeline.prior_prior, mode="reduce-overhead", fullgraph=True)
pipeline.decoder = torch.compile(pipeline.decoder, mode="reduce-overhead", fullgraph=True)
```
## Limitations
- Due to the high compression employed by Würstchen, generations can lack a good amount
of detail. To the human eye, this is especially noticeable in faces, hands, etc.
- **Images can only be generated in 128-pixel steps**, e.g. the next higher resolution
after 1024x1024 is 1152x1152 (see the helper snippet below)
- The model lacks the ability to render correct text in images
- The model often does not achieve photorealism
- Difficult compositional prompts are hard for the model
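Since supported resolutions land on multiples of 128, a small helper can snap a requested size to the nearest valid value before calling the pipeline. This is only a sketch (the helper is not part of the library):
```python
def snap_resolution(size: int, step: int = 128) -> int:
    """Round a requested edge length to the nearest multiple of `step` (128 for Würstchen)."""
    return max(step, round(size / step) * step)

# e.g. 1100 -> 1152, 1500 -> 1536
height = snap_resolution(1100)
width = snap_resolution(1500)
```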
The original codebase, as well as experimental ideas, can be found at [dome272/Wuerstchen](https://github.com/dome272/Wuerstchen).
## WuerstchenCombinedPipeline
[[autodoc]] WuerstchenCombinedPipeline
- all
- __call__
## WuerstchenPriorPipeline
[[autodoc]] WuerstchenPriorPipeline
- all
- __call__
## WuerstchenPriorPipelineOutput
[[autodoc]] pipelines.wuerstchen.pipeline_wuerstchen_prior.WuerstchenPriorPipelineOutput
## WuerstchenDecoderPipeline
[[autodoc]] WuerstchenDecoderPipeline
- all
- __call__

View File

@@ -0,0 +1,115 @@
# Run inside root directory of official source code: https://github.com/dome272/wuerstchen/
import os
import torch
from transformers import AutoTokenizer, CLIPTextModel
from vqgan import VQModel
from diffusers import (
DDPMWuerstchenScheduler,
WuerstchenCombinedPipeline,
WuerstchenDecoderPipeline,
WuerstchenPriorPipeline,
)
from diffusers.pipelines.wuerstchen import PaellaVQModel, WuerstchenDiffNeXt, WuerstchenPrior
model_path = "models/"
device = "cpu"
paella_vqmodel = VQModel()
state_dict = torch.load(os.path.join(model_path, "vqgan_f4_v1_500k.pt"), map_location=device)["state_dict"]
paella_vqmodel.load_state_dict(state_dict)
state_dict["vquantizer.embedding.weight"] = state_dict["vquantizer.codebook.weight"]
state_dict.pop("vquantizer.codebook.weight")
vqmodel = PaellaVQModel(num_vq_embeddings=paella_vqmodel.codebook_size, latent_channels=paella_vqmodel.c_latent)
vqmodel.load_state_dict(state_dict)
# Clip Text encoder and tokenizer
text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
# Generator
gen_text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K").to("cpu")
gen_tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
orig_state_dict = torch.load(os.path.join(model_path, "model_v2_stage_b.pt"), map_location=device)["state_dict"]
state_dict = {}
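# The original checkpoint stores fused q/k/v attention projections (in_proj); split them into separate
# to_q/to_k/to_v tensors and map out_proj to to_out.0 so the keys match diffusers' Attention module.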
for key in orig_state_dict.keys():
if key.endswith("in_proj_weight"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
elif key.endswith("in_proj_bias"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
elif key.endswith("out_proj.weight"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
elif key.endswith("out_proj.bias"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
else:
state_dict[key] = orig_state_dict[key]
decoder = WuerstchenDiffNeXt()
decoder.load_state_dict(state_dict)
# Prior
orig_state_dict = torch.load(os.path.join(model_path, "model_v3_stage_c.pt"), map_location=device)["ema_state_dict"]
state_dict = {}
for key in orig_state_dict.keys():
if key.endswith("in_proj_weight"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
elif key.endswith("in_proj_bias"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
elif key.endswith("out_proj.weight"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
elif key.endswith("out_proj.bias"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
else:
state_dict[key] = orig_state_dict[key]
prior_model = WuerstchenPrior(c_in=16, c=1536, c_cond=1280, c_r=64, depth=32, nhead=24).to(device)
prior_model.load_state_dict(state_dict)
# scheduler
scheduler = DDPMWuerstchenScheduler()
# Prior pipeline
prior_pipeline = WuerstchenPriorPipeline(
prior=prior_model, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler
)
prior_pipeline.save_pretrained("warp-diffusion/wuerstchen-prior")
decoder_pipeline = WuerstchenDecoderPipeline(
text_encoder=gen_text_encoder, tokenizer=gen_tokenizer, vqgan=vqmodel, decoder=decoder, scheduler=scheduler
)
decoder_pipeline.save_pretrained("warp-diffusion/wuerstchen")
# Wuerstchen pipeline
wuerstchen_pipeline = WuerstchenCombinedPipeline(
# Decoder
text_encoder=gen_text_encoder,
tokenizer=gen_tokenizer,
decoder=decoder,
scheduler=scheduler,
vqgan=vqmodel,
# Prior
prior_tokenizer=tokenizer,
prior_text_encoder=text_encoder,
prior_prior=prior_model,
prior_scheduler=scheduler,
)
wuerstchen_pipeline.save_pretrained("warp-diffusion/WuerstchenCombinedPipeline")

View File

@@ -88,6 +88,7 @@ else:
DDIMScheduler,
DDPMParallelScheduler,
DDPMScheduler,
DDPMWuerstchenScheduler,
DEISMultistepScheduler,
DPMSolverMultistepInverseScheduler,
DPMSolverMultistepScheduler,
@@ -216,6 +217,9 @@ else:
VersatileDiffusionTextToImagePipeline,
VideoToVideoSDPipeline,
VQDiffusionPipeline,
WuerstchenCombinedPipeline,
WuerstchenDecoderPipeline,
WuerstchenPriorPipeline,
)
try:

View File

@@ -188,7 +188,7 @@ class Attention(nn.Module):
if use_memory_efficient_attention_xformers:
if is_added_kv_processor and (is_lora or is_custom_diffusion):
raise NotImplementedError(
f"Memory efficient attention is currently not supported for LoRA or custom diffuson for attention processor type {self.processor}"
f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}"
)
if not is_xformers_available():
raise ModuleNotFoundError(

View File

@@ -52,7 +52,7 @@ class Encoder(nn.Module):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = torch.nn.Conv2d(
self.conv_in = nn.Conv2d(
in_channels,
block_out_channels[0],
kernel_size=3,

View File

@@ -132,7 +132,7 @@ class VQModel(ModelMixin, ConfigMixin):
) -> Union[DecoderOutput, torch.FloatTensor]:
# also go through quantization layer
if not force_not_quantize:
quant, emb_loss, info = self.quantize(h)
quant, _, _ = self.quantize(h)
else:
quant = h
quant2 = self.post_quant_conv(quant)

View File

@@ -132,6 +132,7 @@ else:
VersatileDiffusionTextToImagePipeline,
)
from .vq_diffusion import VQDiffusionPipeline
from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline, WuerstchenPriorPipeline
try:

View File

@@ -52,6 +52,7 @@ from .stable_diffusion_xl import (
StableDiffusionXLInpaintPipeline,
StableDiffusionXLPipeline,
)
from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline
AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
@@ -63,6 +64,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("kandinsky22", KandinskyV22CombinedPipeline),
("stable-diffusion-controlnet", StableDiffusionControlNetPipeline),
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline),
("wuerstchen", WuerstchenCombinedPipeline),
]
)
@@ -93,6 +95,7 @@ _AUTO_TEXT2IMAGE_DECODER_PIPELINES_MAPPING = OrderedDict(
[
("kandinsky", KandinskyPipeline),
("kandinsky22", KandinskyV22Pipeline),
("wuerstchen", WuerstchenDecoderPipeline),
]
)
_AUTO_IMAGE2IMAGE_DECODER_PIPELINES_MAPPING = OrderedDict(
@@ -305,8 +308,6 @@ class AutoPipelineForText2Image(ConfigMixin):
use_auth_token = kwargs.pop("use_auth_token", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
user_agent = kwargs.pop("user_agent", {})
load_config_kwargs = {
"cache_dir": cache_dir,
@@ -316,8 +317,6 @@ class AutoPipelineForText2Image(ConfigMixin):
"use_auth_token": use_auth_token,
"local_files_only": local_files_only,
"revision": revision,
"subfolder": subfolder,
"user_agent": user_agent,
}
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
@@ -580,8 +579,6 @@ class AutoPipelineForImage2Image(ConfigMixin):
use_auth_token = kwargs.pop("use_auth_token", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
user_agent = kwargs.pop("user_agent", {})
load_config_kwargs = {
"cache_dir": cache_dir,
@@ -591,8 +588,6 @@ class AutoPipelineForImage2Image(ConfigMixin):
"use_auth_token": use_auth_token,
"local_files_only": local_files_only,
"revision": revision,
"subfolder": subfolder,
"user_agent": user_agent,
}
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
@@ -856,8 +851,6 @@ class AutoPipelineForInpainting(ConfigMixin):
use_auth_token = kwargs.pop("use_auth_token", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
user_agent = kwargs.pop("user_agent", {})
load_config_kwargs = {
"cache_dir": cache_dir,
@@ -867,8 +860,6 @@ class AutoPipelineForInpainting(ConfigMixin):
"use_auth_token": use_auth_token,
"local_files_only": local_files_only,
"revision": revision,
"subfolder": subfolder,
"user_agent": user_agent,
}
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)

View File

@@ -0,0 +1,10 @@
from ...utils import is_torch_available, is_transformers_available
if is_transformers_available() and is_torch_available():
from .modeling_paella_vq_model import PaellaVQModel
from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt
from .modeling_wuerstchen_prior import WuerstchenPrior
from .pipeline_wuerstchen import WuerstchenDecoderPipeline
from .pipeline_wuerstchen_combined import WuerstchenCombinedPipeline
from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline

View File

@@ -0,0 +1,172 @@
# Copyright (c) 2022 Dominic Rampas MIT License
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.modeling_utils import ModelMixin
from ...models.vae import DecoderOutput, VectorQuantizer
from ...models.vq_model import VQEncoderOutput
from ...utils import apply_forward_hook
class MixingResidualBlock(nn.Module):
"""
Residual block with mixing used by Paella's VQ-VAE.
"""
def __init__(self, inp_channels, embed_dim):
super().__init__()
# depthwise
self.norm1 = nn.LayerNorm(inp_channels, elementwise_affine=False, eps=1e-6)
self.depthwise = nn.Sequential(
nn.ReplicationPad2d(1), nn.Conv2d(inp_channels, inp_channels, kernel_size=3, groups=inp_channels)
)
# channelwise
self.norm2 = nn.LayerNorm(inp_channels, elementwise_affine=False, eps=1e-6)
self.channelwise = nn.Sequential(
nn.Linear(inp_channels, embed_dim), nn.GELU(), nn.Linear(embed_dim, inp_channels)
)
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
def forward(self, x):
mods = self.gammas
x_temp = self.norm1(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * (1 + mods[0]) + mods[1]
x = x + self.depthwise(x_temp) * mods[2]
x_temp = self.norm2(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * (1 + mods[3]) + mods[4]
x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5]
return x
class PaellaVQModel(ModelMixin, ConfigMixin):
r"""VQ-VAE model from Paella model.
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
implements for all models (such as downloading or saving, etc.)
Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
up_down_scale_factor (int, *optional*, defaults to 2): Up and Downscale factor of the input image.
levels (int, *optional*, defaults to 2): Number of levels in the model.
bottleneck_blocks (int, *optional*, defaults to 12): Number of bottleneck blocks in the model.
embed_dim (int, *optional*, defaults to 384): Number of hidden channels in the model.
latent_channels (int, *optional*, defaults to 4): Number of latent channels in the VQ-VAE model.
num_vq_embeddings (int, *optional*, defaults to 8192): Number of codebook vectors in the VQ-VAE.
scale_factor (float, *optional*, defaults to 0.3764): Scaling factor of the latent space.
"""
@register_to_config
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
up_down_scale_factor: int = 2,
levels: int = 2,
bottleneck_blocks: int = 12,
embed_dim: int = 384,
latent_channels: int = 4,
num_vq_embeddings: int = 8192,
scale_factor: float = 0.3764,
):
super().__init__()
c_levels = [embed_dim // (2**i) for i in reversed(range(levels))]
# Encoder blocks
self.in_block = nn.Sequential(
nn.PixelUnshuffle(up_down_scale_factor),
nn.Conv2d(in_channels * up_down_scale_factor**2, c_levels[0], kernel_size=1),
)
down_blocks = []
for i in range(levels):
if i > 0:
down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
block = MixingResidualBlock(c_levels[i], c_levels[i] * 4)
down_blocks.append(block)
down_blocks.append(
nn.Sequential(
nn.Conv2d(c_levels[-1], latent_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(latent_channels), # then normalize them to have mean 0 and std 1
)
)
self.down_blocks = nn.Sequential(*down_blocks)
# Vector Quantizer
self.vquantizer = VectorQuantizer(num_vq_embeddings, vq_embed_dim=latent_channels, legacy=False, beta=0.25)
# Decoder blocks
up_blocks = [nn.Sequential(nn.Conv2d(latent_channels, c_levels[-1], kernel_size=1))]
for i in range(levels):
for j in range(bottleneck_blocks if i == 0 else 1):
block = MixingResidualBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4)
up_blocks.append(block)
if i < levels - 1:
up_blocks.append(
nn.ConvTranspose2d(
c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2, padding=1
)
)
self.up_blocks = nn.Sequential(*up_blocks)
self.out_block = nn.Sequential(
nn.Conv2d(c_levels[0], out_channels * up_down_scale_factor**2, kernel_size=1),
nn.PixelShuffle(up_down_scale_factor),
)
@apply_forward_hook
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
h = self.in_block(x)
h = self.down_blocks(h)
if not return_dict:
return (h,)
return VQEncoderOutput(latents=h)
@apply_forward_hook
def decode(
self, h: torch.FloatTensor, force_not_quantize: bool = True, return_dict: bool = True
) -> Union[DecoderOutput, torch.FloatTensor]:
if not force_not_quantize:
quant, _, _ = self.vquantizer(h)
else:
quant = h
x = self.up_blocks(quant)
dec = self.out_block(x)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
r"""
Args:
sample (`torch.FloatTensor`): Input sample.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
h = self.encode(x).latents
dec = self.decode(h).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)

View File

@@ -0,0 +1,88 @@
# Copyright (c) 2023 Dominic Rampas MIT License
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from ...models.attention_processor import Attention
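# LayerNorm for NCHW feature maps: permute to channels-last, normalize, permute back.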
class WuerstchenLayerNorm(nn.LayerNorm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, x):
x = x.permute(0, 2, 3, 1)
x = super().forward(x)
return x.permute(0, 3, 1, 2)
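# Projects the timestep (noise-ratio) embedding to a per-channel scale and shift: x * (1 + a) + b.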
class TimestepBlock(nn.Module):
def __init__(self, c, c_timestep):
super().__init__()
self.mapper = nn.Linear(c_timestep, c * 2)
def forward(self, x, t):
a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1)
return x * (1 + a) + b
class ResBlock(nn.Module):
def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
super().__init__()
self.depthwise = nn.Conv2d(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
self.channelwise = nn.Sequential(
nn.Linear(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), nn.Linear(c * 4, c)
)
def forward(self, x, x_skip=None):
x_res = x
if x_skip is not None:
x = torch.cat([x, x_skip], dim=1)
x = self.norm(self.depthwise(x)).permute(0, 2, 3, 1)
x = self.channelwise(x).permute(0, 3, 1, 2)
return x + x_res
# from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
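# Global response normalization: rescales channels by their spatial L2 norm relative to the mean norm
# across channels, with learnable gamma/beta and a residual connection (input is channels-last here).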
class GlobalResponseNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
def forward(self, x):
agg_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
stand_div_norm = agg_norm / (agg_norm.mean(dim=-1, keepdim=True) + 1e-6)
return self.gamma * (x * stand_div_norm) + self.beta + x
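# Attention block: kv_mapper projects the conditioning to the block width; with self_attn=True the
# normalized image tokens are concatenated to the conditioning so the block attends over both.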
class AttnBlock(nn.Module):
def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
super().__init__()
self.self_attn = self_attn
self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True)
self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c))
def forward(self, x, kv):
kv = self.kv_mapper(kv)
norm_x = self.norm(x)
if self.self_attn:
batch_size, channel, _, _ = x.shape
kv = torch.cat([norm_x.view(batch_size, channel, -1).transpose(1, 2), kv], dim=1)
x = x + self.attention(norm_x, encoder_hidden_states=kv)
return x

View File

@@ -0,0 +1,254 @@
# Copyright (c) 2023 Dominic Rampas MIT License
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.modeling_utils import ModelMixin
from .modeling_wuerstchen_common import AttnBlock, GlobalResponseNorm, TimestepBlock, WuerstchenLayerNorm
class WuerstchenDiffNeXt(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
c_in=4,
c_out=4,
c_r=64,
patch_size=2,
c_cond=1024,
c_hidden=[320, 640, 1280, 1280],
nhead=[-1, 10, 20, 20],
blocks=[4, 4, 14, 4],
level_config=["CT", "CTA", "CTA", "CTA"],
inject_effnet=[False, True, True, True],
effnet_embd=16,
clip_embd=1024,
kernel_size=3,
dropout=0.1,
):
super().__init__()
self.c_r = c_r
self.c_cond = c_cond
if not isinstance(dropout, list):
dropout = [dropout] * len(c_hidden)
# CONDITIONING
self.clip_mapper = nn.Linear(clip_embd, c_cond)
self.effnet_mappers = nn.ModuleList(
[
nn.Conv2d(effnet_embd, c_cond, kernel_size=1) if inject else None
for inject in inject_effnet + list(reversed(inject_effnet))
]
)
self.seq_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6)
self.embedding = nn.Sequential(
nn.PixelUnshuffle(patch_size),
nn.Conv2d(c_in * (patch_size**2), c_hidden[0], kernel_size=1),
WuerstchenLayerNorm(c_hidden[0], elementwise_affine=False, eps=1e-6),
)
def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0):
if block_type == "C":
return ResBlockStageB(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout)
elif block_type == "A":
return AttnBlock(c_hidden, c_cond, nhead, self_attn=True, dropout=dropout)
elif block_type == "T":
return TimestepBlock(c_hidden, c_r)
else:
raise ValueError(f"Block type {block_type} not supported")
# BLOCKS
# -- down blocks
self.down_blocks = nn.ModuleList()
for i in range(len(c_hidden)):
down_block = nn.ModuleList()
if i > 0:
down_block.append(
nn.Sequential(
WuerstchenLayerNorm(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
nn.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2),
)
)
for _ in range(blocks[i]):
for block_type in level_config[i]:
c_skip = c_cond if inject_effnet[i] else 0
down_block.append(get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i]))
self.down_blocks.append(down_block)
# -- up blocks
self.up_blocks = nn.ModuleList()
for i in reversed(range(len(c_hidden))):
up_block = nn.ModuleList()
for j in range(blocks[i]):
for k, block_type in enumerate(level_config[i]):
c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
c_skip += c_cond if inject_effnet[i] else 0
up_block.append(get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i]))
if i > 0:
up_block.append(
nn.Sequential(
WuerstchenLayerNorm(c_hidden[i], elementwise_affine=False, eps=1e-6),
nn.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2),
)
)
self.up_blocks.append(up_block)
# OUTPUT
self.clf = nn.Sequential(
WuerstchenLayerNorm(c_hidden[0], elementwise_affine=False, eps=1e-6),
nn.Conv2d(c_hidden[0], 2 * c_out * (patch_size**2), kernel_size=1),
nn.PixelShuffle(patch_size),
)
# --- WEIGHT INIT ---
self.apply(self._init_weights)
def _init_weights(self, m):
# General init
if isinstance(m, (nn.Conv2d, nn.Linear)):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
for mapper in self.effnet_mappers:
if mapper is not None:
nn.init.normal_(mapper.weight, std=0.02) # conditionings
nn.init.normal_(self.clip_mapper.weight, std=0.02) # conditionings
nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
nn.init.constant_(self.clf[1].weight, 0) # outputs
# blocks
for level_block in self.down_blocks + self.up_blocks:
for block in level_block:
if isinstance(block, ResBlockStageB):
block.channelwise[-1].weight.data *= np.sqrt(1 / sum(self.config.blocks))
elif isinstance(block, TimestepBlock):
nn.init.constant_(block.mapper.weight, 0)
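# Sinusoidal embedding of the continuous noise ratio r (same construction as standard timestep embeddings).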
def gen_r_embedding(self, r, max_positions=10000):
r = r * max_positions
half_dim = self.c_r // 2
emb = math.log(max_positions) / (half_dim - 1)
emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
emb = r[:, None] * emb[None, :]
emb = torch.cat([emb.sin(), emb.cos()], dim=1)
if self.c_r % 2 == 1: # zero pad
emb = nn.functional.pad(emb, (0, 1), mode="constant")
return emb.to(dtype=r.dtype)
def gen_c_embeddings(self, clip):
clip = self.clip_mapper(clip)
clip = self.seq_norm(clip)
return clip
def _down_encode(self, x, r_embed, effnet, clip=None):
level_outputs = []
for i, down_block in enumerate(self.down_blocks):
effnet_c = None
for block in down_block:
if isinstance(block, ResBlockStageB):
if effnet_c is None and self.effnet_mappers[i] is not None:
dtype = effnet.dtype
effnet_c = self.effnet_mappers[i](
nn.functional.interpolate(
effnet.float(), size=x.shape[-2:], mode="bicubic", antialias=True, align_corners=True
).to(dtype)
)
skip = effnet_c if self.effnet_mappers[i] is not None else None
x = block(x, skip)
elif isinstance(block, AttnBlock):
x = block(x, clip)
elif isinstance(block, TimestepBlock):
x = block(x, r_embed)
else:
x = block(x)
level_outputs.insert(0, x)
return level_outputs
def _up_decode(self, level_outputs, r_embed, effnet, clip=None):
x = level_outputs[0]
for i, up_block in enumerate(self.up_blocks):
effnet_c = None
for j, block in enumerate(up_block):
if isinstance(block, ResBlockStageB):
if effnet_c is None and self.effnet_mappers[len(self.down_blocks) + i] is not None:
dtype = effnet.dtype
effnet_c = self.effnet_mappers[len(self.down_blocks) + i](
nn.functional.interpolate(
effnet.float(), size=x.shape[-2:], mode="bicubic", antialias=True, align_corners=True
).to(dtype)
)
skip = level_outputs[i] if j == 0 and i > 0 else None
if effnet_c is not None:
if skip is not None:
skip = torch.cat([skip, effnet_c], dim=1)
else:
skip = effnet_c
x = block(x, skip)
elif isinstance(block, AttnBlock):
x = block(x, clip)
elif isinstance(block, TimestepBlock):
x = block(x, r_embed)
else:
x = block(x)
return x
def forward(self, x, r, effnet, clip=None, x_cat=None, eps=1e-3, return_noise=True):
if x_cat is not None:
x = torch.cat([x, x_cat], dim=1)
# Process the conditioning embeddings
r_embed = self.gen_r_embedding(r)
if clip is not None:
clip = self.gen_c_embeddings(clip)
# Model Blocks
x_in = x
x = self.embedding(x)
level_outputs = self._down_encode(x, r_embed, effnet, clip)
x = self._up_decode(level_outputs, r_embed, effnet, clip)
a, b = self.clf(x).chunk(2, dim=1)
b = b.sigmoid() * (1 - eps * 2) + eps
if return_noise:
return (x_in - a) / b
else:
return a, b
class ResBlockStageB(nn.Module):
def __init__(self, c, c_skip=None, kernel_size=3, dropout=0.0):
super().__init__()
self.depthwise = nn.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
self.channelwise = nn.Sequential(
nn.Linear(c + c_skip, c * 4),
nn.GELU(),
GlobalResponseNorm(c * 4),
nn.Dropout(dropout),
nn.Linear(c * 4, c),
)
def forward(self, x, x_skip=None):
x_res = x
x = self.norm(self.depthwise(x))
if x_skip is not None:
x = torch.cat([x, x_skip], dim=1)
x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return x + x_res

View File

@@ -0,0 +1,72 @@
# Copyright (c) 2023 Dominic Rampas MIT License
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.modeling_utils import ModelMixin
from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm
class WuerstchenPrior(ModelMixin, ConfigMixin):
@register_to_config
def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1):
super().__init__()
self.c_r = c_r
self.projection = nn.Conv2d(c_in, c, kernel_size=1)
self.cond_mapper = nn.Sequential(
nn.Linear(c_cond, c),
nn.LeakyReLU(0.2),
nn.Linear(c, c),
)
self.blocks = nn.ModuleList()
for _ in range(depth):
self.blocks.append(ResBlock(c, dropout=dropout))
self.blocks.append(TimestepBlock(c, c_r))
self.blocks.append(AttnBlock(c, c, nhead, self_attn=True, dropout=dropout))
self.out = nn.Sequential(
WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6),
nn.Conv2d(c, c_in * 2, kernel_size=1),
)
def gen_r_embedding(self, r, max_positions=10000):
r = r * max_positions
half_dim = self.c_r // 2
emb = math.log(max_positions) / (half_dim - 1)
emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
emb = r[:, None] * emb[None, :]
emb = torch.cat([emb.sin(), emb.cos()], dim=1)
if self.c_r % 2 == 1: # zero pad
emb = nn.functional.pad(emb, (0, 1), mode="constant")
return emb.to(dtype=r.dtype)
def forward(self, x, r, c):
x_in = x
x = self.projection(x)
c_embed = self.cond_mapper(c)
r_embed = self.gen_r_embedding(r)
for block in self.blocks:
if isinstance(block, AttnBlock):
x = block(x, c_embed)
elif isinstance(block, TimestepBlock):
x = block(x, r_embed)
else:
x = block(x)
a, b = self.out(x).chunk(2, dim=1)
return (x_in - a) / ((1 - b).abs() + 1e-5)

View File

@@ -0,0 +1,399 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
import numpy as np
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from ...schedulers import DDPMWuerstchenScheduler
from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .modeling_paella_vq_model import PaellaVQModel
from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import WuerstchenPriorPipeline, WuerstchenDecoderPipeline
>>> prior_pipe = WuerstchenPriorPipeline.from_pretrained(
... "warp-diffusion/wuerstchen-prior", torch_dtype=torch.float16
... ).to("cuda")
>>> gen_pipe = WuerstchenDecoderPipeline.from_pretrained(
... "warp-diffusion/wuerstchen", torch_dtype=torch.float16
... ).to("cuda")
>>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
>>> prior_output = prior_pipe(prompt)
>>> images = gen_pipe(prior_output.image_embeddings, prompt=prompt).images
```
"""
class WuerstchenDecoderPipeline(DiffusionPipeline):
"""
Pipeline for generating images from the Wuerstchen model.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
tokenizer (`CLIPTokenizer`):
The CLIP tokenizer.
text_encoder (`CLIPTextModel`):
The CLIP text encoder.
decoder ([`WuerstchenDiffNeXt`]):
The WuerstchenDiffNeXt unet decoder.
vqgan ([`PaellaVQModel`]):
The VQGAN model.
scheduler ([`DDPMWuerstchenScheduler`]):
A scheduler to be used in combination with `prior` to generate image embedding.
latent_dim_scale (float, `optional`, defaults to 10.67):
Multiplier to determine the VQ latent space size from the image embeddings. If the image embeddings are
height=24 and width=24, the VQ latent shape needs to be height=int(24*10.67)=256 and
width=int(24*10.67)=256 in order to match the training conditions.
"""
def __init__(
self,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModel,
decoder: WuerstchenDiffNeXt,
scheduler: DDPMWuerstchenScheduler,
vqgan: PaellaVQModel,
latent_dim_scale: float = 10.67,
) -> None:
super().__init__()
self.register_modules(
tokenizer=tokenizer,
text_encoder=text_encoder,
decoder=decoder,
scheduler=scheduler,
vqgan=vqgan,
)
self.register_to_config(latent_dim_scale=latent_dim_scale)
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
latents = latents * scheduler.init_noise_sigma
return latents
def enable_model_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `decoder`.
"""
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate import cpu_offload_with_hook
else:
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
device = torch.device(f"cuda:{gpu_id}")
if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
hook = None
for cpu_offloaded_model in [self.text_encoder, self.decoder]:
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
# We'll offload the last model manually.
self.prior_hook = hook
_, hook = cpu_offload_with_hook(self.vqgan, device, prev_module_hook=self.prior_hook)
self.final_offload_hook = hook
def encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
):
batch_size = len(prompt) if isinstance(prompt, list) else 1
# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
attention_mask = attention_mask[:, : self.tokenizer.model_max_length]
text_encoder_output = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask.to(device))
text_encoder_hidden_states = text_encoder_output.last_hidden_state
text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
uncond_text_encoder_hidden_states = None
if do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
negative_prompt_embeds_text_encoder_output = self.text_encoder(
uncond_input.input_ids.to(device), attention_mask=uncond_input.attention_mask.to(device)
)
uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_text_encoder_hidden_states.shape[1]
uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
batch_size * num_images_per_prompt, seq_len, -1
)
# done duplicates
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
return text_encoder_hidden_states, uncond_text_encoder_hidden_states
def check_inputs(
self,
image_embeddings,
prompt,
negative_prompt,
num_inference_steps,
do_classifier_free_guidance,
device,
dtype,
):
if not isinstance(prompt, list):
if isinstance(prompt, str):
prompt = [prompt]
else:
raise TypeError(f"'prompt' must be of type 'list' or 'str', but got {type(prompt)}.")
if do_classifier_free_guidance:
if negative_prompt is not None and not isinstance(negative_prompt, list):
if isinstance(negative_prompt, str):
negative_prompt = [negative_prompt]
else:
raise TypeError(
f"'negative_prompt' must be of type 'list' or 'str', but got {type(negative_prompt)}."
)
if isinstance(image_embeddings, list):
image_embeddings = torch.cat(image_embeddings, dim=0)
if isinstance(image_embeddings, np.ndarray):
image_embeddings = torch.tensor(image_embeddings, device=device, dtype=dtype)
if not isinstance(image_embeddings, torch.Tensor):
raise TypeError(
f"'image_embeddings' must be of type 'torch.Tensor' or 'np.array', but got {type(image_embeddings)}."
)
if not isinstance(num_inference_steps, int):
raise TypeError(
f"'num_inference_steps' must be of type 'int', but got {type(num_inference_steps)}\
In case you want to provide explicit timesteps, please use the 'timesteps' argument."
)
return image_embeddings, prompt, negative_prompt, num_inference_steps
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
image_embeddings: Union[torch.FloatTensor, List[torch.FloatTensor]],
prompt: Union[str, List[str]] = None,
num_inference_steps: int = 12,
timesteps: Optional[List[float]] = None,
guidance_scale: float = 0.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
):
"""
Function invoked when calling the pipeline for generation.
Args:
image_embeddings (`torch.FloatTensor` or `List[torch.FloatTensor]`):
Image embeddings either extracted from an image or generated by a prior model.
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
num_inference_steps (`int`, *optional*, defaults to 12):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps are used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 0.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2 of the [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
`guidance_scale > 1`. A higher guidance scale encourages the model to generate images that are closely
linked to the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
Examples:
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.ImagePipelineOutput`] if `return_dict` is True,
otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
"""
# 0. Define commonly used variables
device = self._execution_device
dtype = self.decoder.dtype
do_classifier_free_guidance = guidance_scale > 1.0
# 1. Check inputs. Raise error if not correct
image_embeddings, prompt, negative_prompt, num_inference_steps = self.check_inputs(
image_embeddings, prompt, negative_prompt, num_inference_steps, do_classifier_free_guidance, device, dtype
)
# 2. Encode caption
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
)
text_encoder_hidden_states = (
torch.cat([prompt_embeds, negative_prompt_embeds]) if negative_prompt_embeds is not None else prompt_embeds
)
# 3. Determine latent shape of latents
latent_height = int(image_embeddings.size(2) * self.config.latent_dim_scale)
latent_width = int(image_embeddings.size(3) * self.config.latent_dim_scale)
latent_features_shape = (image_embeddings.size(0) * num_images_per_prompt, 4, latent_height, latent_width)
# 4. Prepare and set timesteps
if timesteps is not None:
self.scheduler.set_timesteps(timesteps=timesteps, device=device)
timesteps = self.scheduler.timesteps
num_inference_steps = len(timesteps)
else:
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latents
latents = self.prepare_latents(latent_features_shape, dtype, device, generator, latents, self.scheduler)
# 6. Run denoising loop
for t in self.progress_bar(timesteps[:-1]):
ratio = t.expand(latents.size(0)).to(dtype)
effnet = (
torch.cat([image_embeddings, torch.zeros_like(image_embeddings)])
if do_classifier_free_guidance
else image_embeddings
)
# 7. Denoise latents
predicted_latents = self.decoder(
torch.cat([latents] * 2) if do_classifier_free_guidance else latents,
r=torch.cat([ratio] * 2) if do_classifier_free_guidance else ratio,
effnet=effnet,
clip=text_encoder_hidden_states,
)
# 8. Check for classifier free guidance and apply it
if do_classifier_free_guidance:
predicted_latents_text, predicted_latents_uncond = predicted_latents.chunk(2)
predicted_latents = torch.lerp(predicted_latents_uncond, predicted_latents_text, guidance_scale)
# 9. Renoise latents to next timestep
latents = self.scheduler.step(
model_output=predicted_latents,
timestep=ratio,
sample=latents,
generator=generator,
).prev_sample
# 10. Scale and decode the image latents with vq-vae
latents = self.vqgan.config.scale_factor * latents
images = self.vqgan.decode(latents).sample.clamp(0, 1)
if output_type not in ["pt", "np", "pil"]:
raise ValueError(f"Only the output types `pt`, `np` and `pil` are supported not output_type={output_type}")
if output_type == "np":
images = images.permute(0, 2, 3, 1).cpu().numpy()
elif output_type == "pil":
images = images.permute(0, 2, 3, 1).cpu().numpy()
images = self.numpy_to_pil(images)
if not return_dict:
return images
return ImagePipelineOutput(images)

View File

@@ -0,0 +1,248 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, List, Optional, Union
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from ...schedulers import DDPMWuerstchenScheduler
from ...utils import replace_example_docstring
from ..pipeline_utils import DiffusionPipeline
from .modeling_paella_vq_model import PaellaVQModel
from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt
from .modeling_wuerstchen_prior import WuerstchenPrior
from .pipeline_wuerstchen import WuerstchenDecoderPipeline
from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline
TEXT2IMAGE_EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import WuerstchenCombinedPipeline
>>> pipe = WuerstchenCombinedPipeline.from_pretrained(
... "warp-diffusion/Wuerstchen", torch_dtype=torch.float16
... ).to("cuda")
>>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
>>> images = pipe(prompt=prompt)
```
"""
class WuerstchenCombinedPipeline(DiffusionPipeline):
"""
Combined Pipeline for text-to-image generation using Wuerstchen
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
tokenizer (`CLIPTokenizer`):
The decoder tokenizer to be used for text inputs.
text_encoder (`CLIPTextModel`):
The decoder text encoder to be used for text inputs.
decoder (`WuerstchenDiffNeXt`):
The decoder model to be used for decoder image generation pipeline.
scheduler (`DDPMWuerstchenScheduler`):
The scheduler to be used for decoder image generation pipeline.
vqgan (`PaellaVQModel`):
The VQGAN model to be used for decoder image generation pipeline.
prior_tokenizer (`CLIPTokenizer`):
The prior tokenizer to be used for text inputs.
prior_text_encoder (`CLIPTextModel`):
The prior text encoder to be used for text inputs.
prior_prior (`WuerstchenPrior`):
The prior model to be used for the prior pipeline.
prior_scheduler (`DDPMWuerstchenScheduler`):
The scheduler to be used for prior pipeline.
"""
_load_connected_pipes = True
def __init__(
self,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModel,
decoder: WuerstchenDiffNeXt,
scheduler: DDPMWuerstchenScheduler,
vqgan: PaellaVQModel,
prior_tokenizer: CLIPTokenizer,
prior_text_encoder: CLIPTextModel,
prior_prior: WuerstchenPrior,
prior_scheduler: DDPMWuerstchenScheduler,
):
super().__init__()
self.register_modules(
text_encoder=text_encoder,
tokenizer=tokenizer,
decoder=decoder,
scheduler=scheduler,
vqgan=vqgan,
prior_prior=prior_prior,
prior_text_encoder=prior_text_encoder,
prior_tokenizer=prior_tokenizer,
prior_scheduler=prior_scheduler,
)
self.prior_pipe = WuerstchenPriorPipeline(
prior=prior_prior,
text_encoder=prior_text_encoder,
tokenizer=prior_tokenizer,
scheduler=prior_scheduler,
)
self.decoder_pipe = WuerstchenDecoderPipeline(
text_encoder=text_encoder,
tokenizer=tokenizer,
decoder=decoder,
scheduler=scheduler,
vqgan=vqgan,
)
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
def enable_model_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the denoising models.
"""
self.prior_pipe.enable_model_cpu_offload()
self.decoder_pipe.enable_model_cpu_offload()
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all model components (text encoders, prior, decoder, and vqgan state dicts) to CPU using 🤗
Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a
GPU only when their specific submodule's `forward` method is called. Offloading happens on a submodule basis.
Memory savings are higher than using `enable_model_cpu_offload`, but performance is lower.
"""
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
def progress_bar(self, iterable=None, total=None):
self.prior_pipe.progress_bar(iterable=iterable, total=total)
self.decoder_pipe.progress_bar(iterable=iterable, total=total)
def set_progress_bar_config(self, **kwargs):
self.prior_pipe.set_progress_bar_config(**kwargs)
self.decoder_pipe.set_progress_bar_config(**kwargs)
@torch.no_grad()
@replace_example_docstring(TEXT2IMAGE_EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
guidance_scale: float = 4.0,
num_images_per_prompt: int = 1,
height: int = 512,
width: int = 512,
prior_guidance_scale: float = 4.0,
prior_num_inference_steps: int = 60,
num_inference_steps: int = 12,
prior_timesteps: Optional[List[float]] = None,
timesteps: Optional[List[float]] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
):
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
prior_guidance_scale (`float`, *optional*, defaults to 4.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`prior_guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
`prior_guidance_scale > 1`. A higher guidance scale encourages the model to generate images that are
closely linked to the text `prompt`, usually at the expense of lower image quality.
prior_num_inference_steps (`int`, *optional*, defaults to 60):
The number of denoising steps for the prior. More denoising steps usually lead to a higher quality image
at the expense of slower inference. For more specific timestep spacing, you can pass customized
`prior_timesteps`.
num_inference_steps (`int`, *optional*, defaults to 12):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. For more specific timestep spacing, you can pass customized `timesteps`
prior_timesteps (`List[float]`, *optional*):
Custom timesteps to use for the denoising process of the prior. If not defined, equally spaced
`prior_num_inference_steps` timesteps are used. Must be in descending order.
timesteps (`List[float]`, *optional*):
Custom timesteps to use for the denoising process of the decoder. If not defined, equally spaced
`num_inference_steps` timesteps are used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 4.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text
`prompt`, usually at the expense of lower image quality.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
Examples:
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.ImagePipelineOutput`] if `return_dict` is True,
otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
"""
prior_outputs = self.prior_pipe(
prompt=prompt,
negative_prompt=negative_prompt,
width=width,
height=height,
num_images_per_prompt=num_images_per_prompt,
num_inference_steps=prior_num_inference_steps,
timesteps=prior_timesteps,
generator=generator,
latents=latents,
guidance_scale=prior_guidance_scale,
output_type="pt",
return_dict=False,
)
image_embeddings = prior_outputs[0]
outputs = self.decoder_pipe(
prompt=prompt,
image_embeddings=image_embeddings,
num_inference_steps=num_inference_steps,
timesteps=timesteps,
generator=generator,
guidance_scale=guidance_scale,
output_type=output_type,
return_dict=return_dict,
)
return outputs
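The combined pipeline is only a thin wrapper that chains the prior and decoder stages, as the `__call__` body above shows; running the two stand-alone pipelines by hand gives the same result and lets you cache or reuse the image embeddings. A rough sketch, assuming the `warp-diffusion/wuerstchen-prior` repo id from the prior example further below and a decoder checkpoint named `warp-diffusion/wuerstchen` (the decoder repo id is an assumption here):

```py
import torch
from diffusers import WuerstchenDecoderPipeline, WuerstchenPriorPipeline

device = "cuda"
prior_pipe = WuerstchenPriorPipeline.from_pretrained(
    "warp-diffusion/wuerstchen-prior", torch_dtype=torch.float16
).to(device)
decoder_pipe = WuerstchenDecoderPipeline.from_pretrained(
    "warp-diffusion/wuerstchen", torch_dtype=torch.float16  # assumed decoder repo id
).to(device)

prompt = "an image of a shiba inu, donning a spacesuit and helmet"
# Stage C: map the text prompt to compressed image embeddings.
prior_output = prior_pipe(prompt=prompt, height=1024, width=1024, guidance_scale=4.0)
# Stage B + VQGAN: decode the embeddings back into pixel space.
images = decoder_pipe(
    image_embeddings=prior_output.image_embeddings,
    prompt=prompt,
).images
```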

View File

@@ -0,0 +1,402 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from math import ceil
from typing import List, Optional, Union
import numpy as np
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from ...schedulers import DDPMWuerstchenScheduler
from ...utils import (
BaseOutput,
is_accelerate_available,
is_accelerate_version,
logging,
randn_tensor,
replace_example_docstring,
)
from ..pipeline_utils import DiffusionPipeline
from .modeling_wuerstchen_prior import WuerstchenPrior
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import WuerstchenPriorPipeline
>>> prior_pipe = WuerstchenPriorPipeline.from_pretrained(
... "warp-diffusion/wuerstchen-prior", torch_dtype=torch.float16
... ).to("cuda")
>>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
>>> prior_output = prior_pipe(prompt)
```
"""
@dataclass
class WuerstchenPriorPipelineOutput(BaseOutput):
"""
Output class for WuerstchenPriorPipeline.
Args:
image_embeddings (`torch.FloatTensor` or `np.ndarray`):
Prior image embeddings for text prompt
"""
image_embeddings: Union[torch.FloatTensor, np.ndarray]
class WuerstchenPriorPipeline(DiffusionPipeline):
"""
Pipeline for generating image prior for Wuerstchen.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
prior ([`WuerstchenPrior`]):
The canonical Wuerstchen prior to approximate the image embedding from the text embedding.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
scheduler ([`DDPMWuerstchenScheduler`]):
A scheduler to be used in combination with `prior` to generate image embedding.
"""
def __init__(
self,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModel,
prior: WuerstchenPrior,
scheduler: DDPMWuerstchenScheduler,
latent_mean: float = 42.0,
latent_std: float = 1.0,
resolution_multiple: float = 42.67,
) -> None:
super().__init__()
self.register_modules(
tokenizer=tokenizer,
text_encoder=text_encoder,
prior=prior,
scheduler=scheduler,
)
self.register_to_config(
latent_mean=latent_mean, latent_std=latent_std, resolution_multiple=resolution_multiple
)
def enable_model_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `prior`.
"""
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate import cpu_offload_with_hook
else:
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
device = torch.device(f"cuda:{gpu_id}")
if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
hook = None
for cpu_offloaded_model in [self.text_encoder]:
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
# We'll offload the last model manually.
self.prior_hook = hook
_, hook = cpu_offload_with_hook(self.prior, device, prev_module_hook=self.prior_hook)
self.final_offload_hook = hook
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
latents = latents * scheduler.init_noise_sigma
return latents
def encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
):
batch_size = len(prompt) if isinstance(prompt, list) else 1
# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
attention_mask = attention_mask[:, : self.tokenizer.model_max_length]
text_encoder_output = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask.to(device))
text_encoder_hidden_states = text_encoder_output.last_hidden_state
text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
uncond_text_encoder_hidden_states = None
if do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
negative_prompt_embeds_text_encoder_output = self.text_encoder(
uncond_input.input_ids.to(device), attention_mask=uncond_input.attention_mask.to(device)
)
uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_text_encoder_hidden_states.shape[1]
uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
batch_size * num_images_per_prompt, seq_len, -1
)
# done duplicates
return text_encoder_hidden_states, uncond_text_encoder_hidden_states
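`encode_prompt` pads every prompt to the tokenizer's `model_max_length`, warns when tokens are truncated, and forwards the attention mask so the padded positions do not leak into the prompt embedding. A minimal, self-contained sketch of that tokenize-and-encode step, using the same tiny CLIP components as the tests further below (the config values are illustrative, not those of a released checkpoint):

```py
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
config = CLIPTextConfig(
    bos_token_id=0, eos_token_id=2, pad_token_id=1, vocab_size=1000,
    hidden_size=32, intermediate_size=37, num_attention_heads=4, num_hidden_layers=5,
)
text_encoder = CLIPTextModel(config).eval()

text_inputs = tokenizer(
    ["an image of a shiba inu"],
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)
with torch.no_grad():
    # Passing the attention mask keeps the padding tokens out of the prompt embedding.
    hidden_states = text_encoder(
        text_inputs.input_ids, attention_mask=text_inputs.attention_mask
    ).last_hidden_state
print(hidden_states.shape)  # (1, tokenizer.model_max_length, 32)
```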
def check_inputs(
self,
prompt,
negative_prompt,
num_inference_steps,
do_classifier_free_guidance,
batch_size,
):
if not isinstance(prompt, list):
if isinstance(prompt, str):
prompt = [prompt]
else:
raise TypeError(f"'prompt' must be of type 'list' or 'str', but got {type(prompt)}.")
if do_classifier_free_guidance:
if negative_prompt is not None and not isinstance(negative_prompt, list):
if isinstance(negative_prompt, str):
negative_prompt = [negative_prompt]
else:
raise TypeError(
f"'negative_prompt' must be of type 'list' or 'str', but got {type(negative_prompt)}."
)
if not isinstance(num_inference_steps, int):
raise TypeError(
f"'num_inference_steps' must be of type 'int', but got {type(num_inference_steps)}\
In Case you want to provide explicit timesteps, please use the 'timesteps' argument."
)
batch_size = len(prompt) if isinstance(prompt, list) else 1
return prompt, negative_prompt, num_inference_steps, batch_size
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
height: int = 1024,
width: int = 1024,
num_inference_steps: int = 30,
timesteps: List[float] = None,
guidance_scale: float = 8.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pt",
return_dict: bool = True,
):
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to 1024):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 1024):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 30):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps are used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 8.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
`guidance_scale > 1`. A higher guidance scale encourages the model to generate images that are closely
linked to the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pt"`):
The output format of the generated image embeddings. Choose between: `"np"` (`np.ndarray`) or `"pt"`
(`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.WuerstchenPriorPipelineOutput`] instead of a plain tuple.
Examples:
Returns:
[`~pipelines.WuerstchenPriorPipelineOutput`] or `tuple`: [`~pipelines.WuerstchenPriorPipelineOutput`] if
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
generated image embeddings.
"""
# 0. Define commonly used variables
device = self._execution_device
do_classifier_free_guidance = guidance_scale > 1.0
batch_size = len(prompt) if isinstance(prompt, list) else 1
# 1. Check inputs. Raise error if not correct
prompt, negative_prompt, num_inference_steps, batch_size = self.check_inputs(
prompt, negative_prompt, num_inference_steps, do_classifier_free_guidance, batch_size
)
# 2. Encode caption
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_encoder_hidden_states = (
torch.cat([prompt_embeds, negative_prompt_embeds]) if negative_prompt_embeds is not None else prompt_embeds
)
# 3. Determine latent shape of image embeddings
dtype = text_encoder_hidden_states.dtype
latent_height = ceil(height / self.config.resolution_multiple)
latent_width = ceil(width / self.config.resolution_multiple)
num_channels = self.prior.config.c_in
effnet_features_shape = (num_images_per_prompt * batch_size, num_channels, latent_height, latent_width)
# 4. Prepare and set timesteps
if timesteps is not None:
self.scheduler.set_timesteps(timesteps=timesteps, device=device)
timesteps = self.scheduler.timesteps
num_inference_steps = len(timesteps)
else:
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latents
latents = self.prepare_latents(effnet_features_shape, dtype, device, generator, latents, self.scheduler)
# 6. Run denoising loop
for t in self.progress_bar(timesteps[:-1]):
ratio = t.expand(latents.size(0)).to(dtype)
# 7. Denoise image embeddings
predicted_image_embedding = self.prior(
torch.cat([latents] * 2) if do_classifier_free_guidance else latents,
r=torch.cat([ratio] * 2) if do_classifier_free_guidance else ratio,
c=text_encoder_hidden_states,
)
# 8. Check for classifier free guidance and apply it
if do_classifier_free_guidance:
predicted_image_embedding_text, predicted_image_embedding_uncond = predicted_image_embedding.chunk(2)
predicted_image_embedding = torch.lerp(
predicted_image_embedding_uncond, predicted_image_embedding_text, guidance_scale
)
# 9. Renoise latents to next timestep
latents = self.scheduler.step(
model_output=predicted_image_embedding,
timestep=ratio,
sample=latents,
generator=generator,
).prev_sample
# 10. Denormalize the latents
latents = latents * self.config.latent_mean - self.config.latent_std
if output_type == "np":
latents = latents.cpu().numpy()
if not return_dict:
return (latents,)
return WuerstchenPriorPipelineOutput(latents)
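With the default `resolution_multiple` of 42.67, the prior maps a 1024x1024 request to 24x24 image embeddings (step 3 above), which is exactly the `(1, 2, 24, 24)` shape checked in the prior test at the end of this diff. A small worked example of that arithmetic, with the decoder-side scale factor taken from the dummy decoder test below (the released checkpoint may use a different `latent_dim_scale`):

```py
from math import ceil

# Prior (stage C): pixel resolution -> image-embedding resolution.
height = width = 1024
resolution_multiple = 42.67
latent_height = ceil(height / resolution_multiple)  # 24
latent_width = ceil(width / resolution_multiple)    # 24
print(latent_height, latent_width)                  # 24 24

# Decoder (stage B): the embeddings are scaled back up by latent_dim_scale
# before the VQGAN decodes them to pixels; 4.0 is the value used in the tests below.
latent_dim_scale = 4.0
print(int(latent_height * latent_dim_scale), int(latent_width * latent_dim_scale))  # 96 96
```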

View File

@@ -34,6 +34,7 @@ else:
from .scheduling_ddim_parallel import DDIMParallelScheduler
from .scheduling_ddpm import DDPMScheduler
from .scheduling_ddpm_parallel import DDPMParallelScheduler
from .scheduling_ddpm_wuerstchen import DDPMWuerstchenScheduler
from .scheduling_deis_multistep import DEISMultistepScheduler
from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
from .scheduling_dpmsolver_multistep_inverse import DPMSolverMultistepInverseScheduler

View File

@@ -0,0 +1,238 @@
# Copyright (c) 2022 Pablo Pernías MIT License
# Copyright 2023 UC Berkeley Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, randn_tensor
from .scheduling_utils import SchedulerMixin
@dataclass
class DDPMWuerstchenSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's step function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
"""
prev_sample: torch.FloatTensor
def betas_for_alpha_bar(
num_diffusion_timesteps,
max_beta=0.999,
alpha_transform_type="cosine",
):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
Args:
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
Choose from `cosine` or `exp`
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
if alpha_transform_type == "cosine":
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
return math.exp(t * -12.0)
else:
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)
class DDPMWuerstchenScheduler(SchedulerMixin, ConfigMixin):
"""
Denoising diffusion probabilistic models (DDPMs) explore the connections between denoising score matching and
Langevin dynamics sampling.
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
For more details, see the original paper: https://arxiv.org/abs/2006.11239
Args:
scaler (`float`, defaults to 1.0):
Exponent that non-linearly warps the continuous timestep before the cosine schedule is evaluated
(see `_alpha_cumprod`); `1.0` leaves the timestep unchanged.
s (`float`, defaults to 0.008):
Small offset for the cosine noise schedule that keeps the signal-to-noise ratio from degenerating
at `t = 0`.
"""
@register_to_config
def __init__(
self,
scaler: float = 1.0,
s: float = 0.008,
):
self.scaler = scaler
self.s = torch.tensor([s])
self._init_alpha_cumprod = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
def _alpha_cumprod(self, t, device):
if self.scaler > 1:
t = 1 - (1 - t) ** self.scaler
elif self.scaler < 1:
t = t**self.scaler
alpha_cumprod = torch.cos(
(t + self.s.to(device)) / (1 + self.s.to(device)) * torch.pi * 0.5
) ** 2 / self._init_alpha_cumprod.to(device)
return alpha_cumprod.clamp(0.0001, 0.9999)
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.FloatTensor`): input sample
timestep (`int`, optional): current timestep
Returns:
`torch.FloatTensor`: scaled input sample
"""
return sample
def set_timesteps(
self,
num_inference_steps: int = None,
timesteps: Optional[List[float]] = None,
device: Union[str, torch.device] = None,
):
"""
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model. If passed, then
`timesteps` must be `None`.
timesteps (`List[float]`, *optional*):
custom timesteps to use for the denoising process. Must be in descending order.
device (`str` or `torch.device`, *optional*):
the device to which the timesteps are moved.
"""
if timesteps is None:
timesteps = torch.linspace(1.0, 0.0, num_inference_steps + 1, device=device)
if not isinstance(timesteps, torch.Tensor):
timesteps = torch.Tensor(timesteps).to(device)
self.timesteps = timesteps
def step(
self,
model_output: torch.FloatTensor,
timestep: int,
sample: torch.FloatTensor,
generator=None,
return_dict: bool = True,
) -> Union[DDPMWuerstchenSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
generator: random number generator.
return_dict (`bool`): option for returning tuple rather than DDPMWuerstchenSchedulerOutput class
Returns:
[`DDPMWuerstchenSchedulerOutput`] or `tuple`: [`DDPMWuerstchenSchedulerOutput`] if `return_dict` is True,
otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
dtype = model_output.dtype
device = model_output.device
t = timestep
prev_t = self.previous_timestep(t)
alpha_cumprod = self._alpha_cumprod(t, device).view(t.size(0), *[1 for _ in sample.shape[1:]])
alpha_cumprod_prev = self._alpha_cumprod(prev_t, device).view(prev_t.size(0), *[1 for _ in sample.shape[1:]])
alpha = alpha_cumprod / alpha_cumprod_prev
mu = (1.0 / alpha).sqrt() * (sample - (1 - alpha) * model_output / (1 - alpha_cumprod).sqrt())
std_noise = randn_tensor(mu.shape, generator=generator, device=model_output.device, dtype=model_output.dtype)
std = ((1 - alpha) * (1.0 - alpha_cumprod_prev) / (1.0 - alpha_cumprod)).sqrt() * std_noise
pred = mu + std * (prev_t != 0).float().view(prev_t.size(0), *[1 for _ in sample.shape[1:]])
if not return_dict:
return (pred.to(dtype),)
return DDPMWuerstchenSchedulerOutput(prev_sample=pred.to(dtype))
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.FloatTensor,
) -> torch.FloatTensor:
device = original_samples.device
dtype = original_samples.dtype
# Standard forward-diffusion noising computed from this scheduler's own cosine schedule:
# x_t = sqrt(alpha_cumprod) * x_0 + sqrt(1 - alpha_cumprod) * noise
alpha_cumprod = self._alpha_cumprod(timesteps, device).view(
timesteps.size(0), *[1 for _ in original_samples.shape[1:]]
)
noisy_samples = alpha_cumprod.sqrt() * original_samples + (1 - alpha_cumprod).sqrt() * noise
return noisy_samples.to(dtype=dtype)
def __len__(self):
return self.config.num_train_timesteps
def previous_timestep(self, timestep):
index = (self.timesteps - timestep[0]).abs().argmin().item()
prev_t = self.timesteps[index + 1][None].expand(timestep.shape[0])
return prev_t
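In formulas, the scheduler operates on continuous timesteps t in [1, 0]: `_alpha_cumprod` evaluates an (optionally warped) cosine schedule, and `step` performs the usual DDPM ancestral update expressed through the ratio of cumulative alphas at consecutive timesteps. Matching the code above, with x_t the current sample and eps_theta the model output:

```latex
% Cosine schedule (t is first warped by `scaler` when scaler != 1):
\bar{\alpha}(t) = \frac{\cos^2\!\left(\frac{t+s}{1+s}\cdot\frac{\pi}{2}\right)}
                       {\cos^2\!\left(\frac{s}{1+s}\cdot\frac{\pi}{2}\right)},
\qquad
\alpha_t = \frac{\bar{\alpha}(t)}{\bar{\alpha}(t_{\mathrm{prev}})}

% Ancestral step (the noise term is dropped when t_prev = 0):
\mu_t = \frac{1}{\sqrt{\alpha_t}}
        \left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}(t)}}\,\epsilon_\theta\right),
\qquad
x_{t_{\mathrm{prev}}} = \mu_t
    + \sqrt{\frac{(1-\alpha_t)\bigl(1-\bar{\alpha}(t_{\mathrm{prev}})\bigr)}{1-\bar{\alpha}(t)}}\; z,
\quad z \sim \mathcal{N}(0, I)
```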

View File

@@ -615,6 +615,21 @@ class DDPMScheduler(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class DDPMWuerstchenScheduler(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class DEISMultistepScheduler(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -1260,3 +1260,48 @@ class VQDiffusionPipeline(metaclass=DummyObject):
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class WuerstchenCombinedPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class WuerstchenDecoderPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class WuerstchenPriorPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])

View File

@@ -535,6 +535,8 @@ class PipelineTesterMixin:
def test_components_function(self):
init_components = self.get_dummy_components()
init_components = {k: v for k, v in init_components.items() if not isinstance(v, (str, int, float))}
pipe = self.pipeline_class(**init_components)
self.assertTrue(hasattr(pipe, "components"))

View File

@@ -0,0 +1,234 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import DDPMWuerstchenScheduler, WuerstchenCombinedPipeline
from diffusers.pipelines.wuerstchen import PaellaVQModel, WuerstchenDiffNeXt, WuerstchenPrior
from diffusers.utils import torch_device
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class WuerstchenCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = WuerstchenCombinedPipeline
params = ["prompt"]
batch_params = ["prompt", "negative_prompt"]
required_optional_params = [
"generator",
"height",
"width",
"latents",
"guidance_scale",
"negative_prompt",
"num_inference_steps",
"return_dict",
"prior_num_inference_steps",
"output_type",
"return_dict",
]
test_xformers_attention = True
@property
def text_embedder_hidden_size(self):
return 32
@property
def dummy_prior(self):
torch.manual_seed(0)
model_kwargs = {"c_in": 2, "c": 8, "depth": 2, "c_cond": 32, "c_r": 8, "nhead": 2}
model = WuerstchenPrior(**model_kwargs)
return model.eval()
@property
def dummy_tokenizer(self):
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
return tokenizer
@property
def dummy_prior_text_encoder(self):
torch.manual_seed(0)
config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=self.text_embedder_hidden_size,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
)
return CLIPTextModel(config).eval()
@property
def dummy_text_encoder(self):
torch.manual_seed(0)
config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
projection_dim=self.text_embedder_hidden_size,
hidden_size=self.text_embedder_hidden_size,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
)
return CLIPTextModel(config).eval()
@property
def dummy_vqgan(self):
torch.manual_seed(0)
model_kwargs = {
"bottleneck_blocks": 1,
"num_vq_embeddings": 2,
}
model = PaellaVQModel(**model_kwargs)
return model.eval()
@property
def dummy_decoder(self):
torch.manual_seed(0)
model_kwargs = {
"c_cond": self.text_embedder_hidden_size,
"c_hidden": [320],
"nhead": [-1],
"blocks": [4],
"level_config": ["CT"],
"clip_embd": self.text_embedder_hidden_size,
"inject_effnet": [False],
}
model = WuerstchenDiffNeXt(**model_kwargs)
return model.eval()
def get_dummy_components(self):
prior = self.dummy_prior
prior_text_encoder = self.dummy_prior_text_encoder
scheduler = DDPMWuerstchenScheduler()
tokenizer = self.dummy_tokenizer
text_encoder = self.dummy_text_encoder
decoder = self.dummy_decoder
vqgan = self.dummy_vqgan
components = {
"tokenizer": tokenizer,
"text_encoder": text_encoder,
"decoder": decoder,
"vqgan": vqgan,
"scheduler": scheduler,
"prior_prior": prior,
"prior_text_encoder": prior_text_encoder,
"prior_tokenizer": tokenizer,
"prior_scheduler": scheduler,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "horse",
"generator": generator,
"prior_guidance_scale": 4.0,
"guidance_scale": 4.0,
"num_inference_steps": 2,
"prior_num_inference_steps": 2,
"output_type": "np",
"height": 128,
"width": 128,
}
return inputs
def test_wuerstchen(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
output = pipe(**self.get_dummy_inputs(device))
image = output.images
image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False)[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[-3:, -3:, -1]
assert image.shape == (1, 128, 128, 3)
expected_slice = np.array([0.7616304, 0.0, 1.0, 0.0, 1.0, 0.0, 0.05925313, 0.0, 0.951898])
assert (
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
assert (
np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
@require_torch_gpu
def test_offloads(self):
pipes = []
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components).to(torch_device)
pipes.append(sd_pipe)
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
sd_pipe.enable_sequential_cpu_offload()
pipes.append(sd_pipe)
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
sd_pipe.enable_model_cpu_offload()
pipes.append(sd_pipe)
image_slices = []
for pipe in pipes:
inputs = self.get_dummy_inputs(torch_device)
image = pipe(**inputs).images
image_slices.append(image[0, -3:, -3:, -1].flatten())
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=1e-2)
@unittest.skip(reason="flakey and float16 requires CUDA")
def test_float16_inference(self):
super().test_float16_inference()

View File

@@ -0,0 +1,196 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import DDPMWuerstchenScheduler, WuerstchenDecoderPipeline
from diffusers.pipelines.wuerstchen import PaellaVQModel, WuerstchenDiffNeXt
from diffusers.utils import torch_device
from diffusers.utils.testing_utils import enable_full_determinism, skip_mps
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class WuerstchenDecoderPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = WuerstchenDecoderPipeline
params = ["prompt"]
batch_params = ["image_embeddings", "prompt", "negative_prompt"]
required_optional_params = [
"num_images_per_prompt",
"num_inference_steps",
"latents",
"negative_prompt",
"guidance_scale",
"output_type",
"return_dict",
]
test_xformers_attention = False
@property
def text_embedder_hidden_size(self):
return 32
@property
def time_input_dim(self):
return 32
@property
def block_out_channels_0(self):
return self.time_input_dim
@property
def time_embed_dim(self):
return self.time_input_dim * 4
@property
def dummy_tokenizer(self):
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
return tokenizer
@property
def dummy_text_encoder(self):
torch.manual_seed(0)
config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
projection_dim=self.text_embedder_hidden_size,
hidden_size=self.text_embedder_hidden_size,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
)
return CLIPTextModel(config).eval()
@property
def dummy_vqgan(self):
torch.manual_seed(0)
model_kwargs = {
"bottleneck_blocks": 1,
"num_vq_embeddings": 2,
}
model = PaellaVQModel(**model_kwargs)
return model.eval()
@property
def dummy_decoder(self):
torch.manual_seed(0)
model_kwargs = {
"c_cond": self.text_embedder_hidden_size,
"c_hidden": [320],
"nhead": [-1],
"blocks": [4],
"level_config": ["CT"],
"clip_embd": self.text_embedder_hidden_size,
"inject_effnet": [False],
}
model = WuerstchenDiffNeXt(**model_kwargs)
return model.eval()
def get_dummy_components(self):
decoder = self.dummy_decoder
text_encoder = self.dummy_text_encoder
tokenizer = self.dummy_tokenizer
vqgan = self.dummy_vqgan
scheduler = DDPMWuerstchenScheduler()
components = {
"decoder": decoder,
"vqgan": vqgan,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"scheduler": scheduler,
"latent_dim_scale": 4.0,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"image_embeddings": torch.ones((1, 4, 4, 4), device=device),
"prompt": "horse",
"generator": generator,
"guidance_scale": 1.0,
"num_inference_steps": 2,
"output_type": "np",
}
return inputs
def test_wuerstchen_decoder(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
output = pipe(**self.get_dummy_inputs(device))
image = output.images
image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False)
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.0000, 0.0000, 0.0089, 1.0000, 1.0000, 0.3927, 1.0000, 1.0000, 1.0000])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
@skip_mps
def test_inference_batch_single_identical(self):
test_max_difference = torch_device == "cpu"
relax_max_difference = True
test_mean_pixel_difference = False
self._test_inference_batch_single_identical(
test_max_difference=test_max_difference,
relax_max_difference=relax_max_difference,
test_mean_pixel_difference=test_mean_pixel_difference,
)
@skip_mps
def test_attention_slicing_forward_pass(self):
test_max_difference = torch_device == "cpu"
test_mean_pixel_difference = False
self._test_attention_slicing_forward_pass(
test_max_difference=test_max_difference,
test_mean_pixel_difference=test_mean_pixel_difference,
)
@unittest.skip(reason="bf16 not supported and requires CUDA")
def test_float16_inference(self):
super().test_float16_inference()

View File

@@ -0,0 +1,194 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import DDPMWuerstchenScheduler, WuerstchenPriorPipeline
from diffusers.pipelines.wuerstchen import WuerstchenPrior
from diffusers.utils import torch_device
from diffusers.utils.testing_utils import enable_full_determinism, skip_mps
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class WuerstchenPriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = WuerstchenPriorPipeline
params = ["prompt"]
batch_params = ["prompt", "negative_prompt"]
required_optional_params = [
"num_images_per_prompt",
"generator",
"num_inference_steps",
"latents",
"negative_prompt",
"guidance_scale",
"output_type",
"return_dict",
]
test_xformers_attention = False
@property
def text_embedder_hidden_size(self):
return 32
@property
def time_input_dim(self):
return 32
@property
def block_out_channels_0(self):
return self.time_input_dim
@property
def time_embed_dim(self):
return self.time_input_dim * 4
@property
def dummy_tokenizer(self):
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
return tokenizer
@property
def dummy_text_encoder(self):
torch.manual_seed(0)
config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=self.text_embedder_hidden_size,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
)
return CLIPTextModel(config).eval()
@property
def dummy_prior(self):
torch.manual_seed(0)
model_kwargs = {
"c_in": 2,
"c": 8,
"depth": 2,
"c_cond": 32,
"c_r": 8,
"nhead": 2,
}
model = WuerstchenPrior(**model_kwargs)
return model.eval()
def get_dummy_components(self):
prior = self.dummy_prior
text_encoder = self.dummy_text_encoder
tokenizer = self.dummy_tokenizer
scheduler = DDPMWuerstchenScheduler()
components = {
"prior": prior,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"scheduler": scheduler,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "horse",
"generator": generator,
"guidance_scale": 4.0,
"num_inference_steps": 2,
"output_type": "np",
}
return inputs
def test_wuerstchen_prior(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
output = pipe(**self.get_dummy_inputs(device))
image = output.image_embeddings
image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False)[0]
image_slice = image[0, 0, 0, -10:]
image_from_tuple_slice = image_from_tuple[0, 0, 0, -10:]
assert image.shape == (1, 2, 24, 24)
expected_slice = np.array(
[
-7172.837,
-3438.855,
-1093.312,
388.8835,
-7471.467,
-7998.1206,
-5328.259,
218.00089,
-2731.5745,
-8056.734,
],
)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
@skip_mps
def test_inference_batch_single_identical(self):
test_max_difference = torch_device == "cpu"
relax_max_difference = True
test_mean_pixel_difference = False
self._test_inference_batch_single_identical(
test_max_difference=test_max_difference,
relax_max_difference=relax_max_difference,
test_mean_pixel_difference=test_mean_pixel_difference,
expected_max_diff=1e-1,
)
@skip_mps
def test_attention_slicing_forward_pass(self):
test_max_difference = torch_device == "cpu"
test_mean_pixel_difference = False
self._test_attention_slicing_forward_pass(
test_max_difference=test_max_difference,
test_mean_pixel_difference=test_mean_pixel_difference,
)
@unittest.skip(reason="flaky for now")
def test_float16_inference(self):
super().test_float16_inference()