
Merge branch 'add_versatile_diffusers' of https://github.com/huggingface/diffusers into add_versatile_diffusers

Patrick von Platen
2022-11-23 12:08:50 +00:00
4 changed files with 191 additions and 39 deletions

View File

@@ -22,7 +22,7 @@ from torch import nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..models.embeddings import ImagePositionalEmbeddings
from ..utils import BaseOutput
from ..utils import CONFIG_NAME, BaseOutput
from ..utils.import_utils import is_xformers_available
@@ -666,3 +666,120 @@ class AdaLayerNorm(nn.Module):
scale, shift = torch.chunk(emb, 2)
x = self.norm(x) * (1 + scale) + shift
return x
class DualTransformer2DModel(nn.Module):
"""
Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
Pass if the input is continuous. The number of channels in the input and output.
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of context dimensions to use.
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
`ImagePositionalEmbeddings`.
num_vector_embeds (`int`, *optional*):
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
Includes the class for the masked latent pixel.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
up to but not more than `num_embeds_ada_norm` steps.
attention_bias (`bool`, *optional*):
Configure if the TransformerBlocks' attention should contain a bias parameter.
"""
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
sample_size: Optional[int] = None,
num_vector_embeds: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
):
super().__init__()
self.transformers = nn.ModuleList(
[
Transformer2DModel(
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
in_channels=in_channels,
num_layers=num_layers,
dropout=dropout,
norm_num_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attention_bias=attention_bias,
sample_size=sample_size,
num_vector_embeds=num_vector_embeds,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
)
for _ in range(2)
]
)
# The weight given to the first transformer's output when the two outputs are mixed during inference; the second gets (1 - mix_ratio)
self.mix_ratio = 0.5
# The shape of `encoder_hidden_states` is expected to be
# `(batch_size, num_condition_tokens[0]+num_condition_tokens[1], num_features)`
self.num_condition_tokens = (77, 257)
def forward(self, hidden_states, encoder_hidden_states, timestep=None, return_dict: bool = True):
"""
Args:
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if the input is discrete, or
`torch.FloatTensor` of shape `(batch size, channel, height, width)` if it is continuous): Input
hidden states.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence length, context dim)`):
Conditional embeddings for the cross-attention layers. Expected to contain the condition tokens of both
transformers concatenated along the sequence dimension (see `num_condition_tokens`).
timestep (`torch.long`, *optional*):
Optional timestep to be applied as an embedding in `AdaLayerNorm` layers. Used to indicate the denoising step.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.attention.Transformer2DModelOutput`] instead of a plain tuple.
Returns:
[`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
tensor.
"""
input_states = hidden_states
encoded_states = []
tokens_start = 0
for i in range(2):
# for each of the two transformers, pass the corresponding condition tokens
condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.num_condition_tokens[i]]
encoded_state = self.transformers[i](input_states, condition_state, timestep, return_dict)[0]
encoded_states.append(encoded_state - input_states)
tokens_start += self.num_condition_tokens[i]
output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
output_states = output_states + input_states
if not return_dict:
return (output_states,)
return Transformer2DModelOutput(sample=output_states)
def _set_attention_slice(self, slice_size):
for transformer in self.transformers:
transformer._set_attention_slice(slice_size)
def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for transformer in self.transformers:
transformer._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
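Not part of the diff: a minimal sketch of how this class is meant to be driven, assuming the branch above (so `DualTransformer2DModel` is importable from `diffusers.models.attention`, as the unet_blocks import below does) and purely illustrative tensor sizes. Text conditioning occupies the first 77 tokens of `encoder_hidden_states`, image conditioning the next 257, and the two outputs are blended as `x + mix_ratio * (T1(x) - x) + (1 - mix_ratio) * (T2(x) - x)`.

import torch
from diffusers.models.attention import DualTransformer2DModel

# Small illustrative dimensions; the real UNet blocks use much larger ones.
model = DualTransformer2DModel(
    num_attention_heads=2,
    attention_head_dim=8,
    in_channels=16,
    norm_num_groups=4,
    cross_attention_dim=32,
)
model.mix_ratio = 0.75  # weight the first transformer's output more heavily

hidden_states = torch.randn(1, 16, 8, 8)  # (batch, channels, height, width)
text_tokens = torch.randn(1, 77, 32)      # stand-in for CLIP text embeddings
image_tokens = torch.randn(1, 257, 32)    # stand-in for CLIP vision embeddings
encoder_hidden_states = torch.cat([text_tokens, image_tokens], dim=1)

sample = model(hidden_states, encoder_hidden_states).sample
print(sample.shape)  # torch.Size([1, 16, 8, 8])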

View File

@@ -15,7 +15,7 @@ import numpy as np
import torch
from torch import nn
from .attention import AttentionBlock, Transformer2DModel
from .attention import AttentionBlock, Transformer2DModel, DualTransformer2DModel
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D

View File

@@ -17,7 +17,6 @@ from typing import Callable, List, Optional, Union
import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint
import PIL
@@ -29,7 +28,7 @@ from transformers import (
)
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention import Transformer2DModel, Transformer2DModelOutput
from ...models.attention import DualTransformer2DModel, Transformer2DModel
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import is_accelerate_available, logging
@@ -94,12 +93,32 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
if isinstance(module, Transformer2DModel):
parent_name, index = name.rsplit(".", 1)
index = int(index)
image_transformer = self.image_unet.get_submodule(parent_name)[index]
text_transformer = self.text_unet.get_submodule(parent_name)[index]
config = image_transformer.config
dual_transformer = DualTransformer2DModel(
image_transformer, text_transformer, mix_ratio=mix_ratio, condition_types=condition_types
num_attention_heads=config.num_attention_heads,
attention_head_dim=config.attention_head_dim,
in_channels=config.in_channels,
num_layers=config.num_layers,
dropout=config.dropout,
norm_num_groups=config.norm_num_groups,
cross_attention_dim=config.cross_attention_dim,
attention_bias=config.attention_bias,
sample_size=config.sample_size,
num_vector_embeds=config.num_vector_embeds,
activation_fn=config.activation_fn,
num_embeds_ada_norm=config.num_embeds_ada_norm,
)
for i, type in enumerate(condition_types):
if type == "image":
dual_transformer.transformers[i] = image_transformer
else:
dual_transformer.transformers[i] = text_transformer
dual_transformer.mix_ratio = mix_ratio
self.image_unet.get_submodule(parent_name)[index] = dual_transformer
def remove_dual_attention(self):
@@ -107,7 +126,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
if isinstance(module, DualTransformer2DModel):
parent_name, index = name.rsplit(".", 1)
index = int(index)
self.image_unet.get_submodule(parent_name)[index] = module.image_transformer
self.image_unet.get_submodule(parent_name)[index] = module.transformers[0]
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention with unet->image_unet
def enable_xformers_memory_efficient_attention(self):
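Not part of the diff: `convert_to_dual_attention` and `remove_dual_attention` both rely on the standard PyTorch pattern of locating a submodule by its dotted name and swapping the entry inside its parent container. A self-contained sketch of just that pattern, with toy modules standing in for the UNet's attention blocks:

import torch.nn as nn

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.attentions = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4)])

net = Block()
for name, module in net.named_modules():
    if isinstance(module, nn.Linear):
        # e.g. name == "attentions.0" -> parent "attentions", index 0
        parent_name, index = name.rsplit(".", 1)
        index = int(index)
        # Indexing the parent ModuleList replaces the child in place, which is
        # how the pipeline swaps Transformer2DModel for DualTransformer2DModel
        # (and back again in remove_dual_attention).
        net.get_submodule(parent_name)[index] = nn.Linear(4, 4, bias=False)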
@@ -412,6 +431,11 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
latents = latents * self.scheduler.init_noise_sigma
return latents
def set_mix_ratio(self, mix_ratio):
for name, module in self.image_unet.named_modules():
if isinstance(module, DualTransformer2DModel):
module.mix_ratio = mix_ratio
@torch.no_grad()
def __call__(
self,
@@ -539,6 +563,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
# 7. Combine the attention blocks of the image and text UNets
self.convert_to_dual_attention(prompt_mix_ratio, prompt_types)
self.set_mix_ratio(prompt_mix_ratio)
# 8. Denoising loop
for i, t in enumerate(self.progress_bar(timesteps)):
@@ -575,36 +600,3 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
return (image,)
return ImagePipelineOutput(images=image)
class DualTransformer2DModel(nn.Module):
def __init__(self, image_transformer, text_transformer, mix_ratio=0.5, condition_types=("text", "image")):
super().__init__()
self.image_transformer = image_transformer
self.text_transformer = text_transformer
self.mix_ratio = mix_ratio
self.condition_types = condition_types
def forward(self, input_states, encoder_hidden_states, timestep=None, return_dict: bool = True):
if self.condition_types[0] == "text":
condition_states = [encoder_hidden_states[:, :77], encoder_hidden_states[:, 77:]]
else:
condition_states = [encoder_hidden_states[:, :257], encoder_hidden_states[:, 257:]]
encoded_states = []
for i in range(2):
if self.condition_types[i] == "text":
text_output = self.text_transformer(input_states, condition_states[i], timestep, return_dict)
encoded_states.append(text_output[0])
else:
image_output = self.image_transformer(input_states, condition_states[i], timestep, return_dict)
encoded_states.append(image_output[0])
encoded_states[i] = encoded_states[i] - input_states
output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
output_states = output_states + input_states
if not return_dict:
return (output_states,)
return Transformer2DModelOutput(sample=output_states)
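Not part of the diff: the pipeline-local DualTransformer2DModel removed above is superseded by the shared implementation in models/attention.py. End-to-end, the dual-guided pipeline is driven the way the tests below exercise it; a hedged sketch, assuming the pipeline is exported from the top-level diffusers namespace and using the same illustrative checkpoint name and arguments as the tests:

import torch
from diffusers import VersatileDiffusionDualGuidedPipeline

pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained("diffusers/vd-official-test")
pipe.to("cuda")

generator = torch.Generator(device="cuda").manual_seed(0)
image = pipe(
    first_prompt="first prompt",
    second_prompt="second prompt",
    prompt_mix_ratio=0.75,  # pushed into every DualTransformer2DModel via set_mix_ratio
    guidance_scale=7.5,
    num_inference_steps=50,
    generator=generator,
).images[0]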

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import tempfile
import unittest
import numpy as np
@@ -34,6 +36,47 @@ class VersatileDiffusionDualGuidedPipelineFastTests(PipelineTesterMixin, unittes
@slow
@require_torch_gpu
class VersatileDiffusionDualGuidedPipelineIntegrationTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_from_pretrained_save_pretrained(self):
pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained("diffusers/vd-official-test")
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device=torch_device).manual_seed(0)
image = pipe(
first_prompt="first prompt",
second_prompt="second prompt",
prompt_mix_ratio=0.75,
generator=generator,
guidance_scale=7.5,
num_inference_steps=2,
output_type="numpy",
).images
with tempfile.TemporaryDirectory() as tmpdirname:
pipe.save_pretrained(tmpdirname)
pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained(tmpdirname)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator = generator.manual_seed(0)
new_image = pipe(
first_prompt="first prompt",
second_prompt="second prompt",
prompt_mix_ratio=0.75,
generator=generator,
guidance_scale=7.5,
num_inference_steps=2,
output_type="numpy",
).images
assert np.abs(image - new_image).sum() < 1e-5, "Models don't have the same forward pass"
def test_inference_image_variations(self):
pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained("diffusers/vd-official-test")
pipe.to(torch_device)