Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-29 07:22:12 +03:00
Merge branch 'add_versatile_diffusers' of https://github.com/huggingface/diffusers into add_versatile_diffusers
@@ -22,7 +22,7 @@ from torch import nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..models.embeddings import ImagePositionalEmbeddings
from ..utils import BaseOutput
from ..utils import CONFIG_NAME, BaseOutput
from ..utils.import_utils import is_xformers_available


@@ -666,3 +666,120 @@ class AdaLayerNorm(nn.Module):
        scale, shift = torch.chunk(emb, 2)
        x = self.norm(x) * (1 + scale) + shift
        return x


class DualTransformer2DModel(nn.Module):
    """
    Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            Pass if the input is continuous. The number of channels in the input and output.
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The number of context dimensions to use.
        sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
            Note that this is fixed at training time as it is used for learning a number of position embeddings. See
            `ImagePositionalEmbeddings`.
        num_vector_embeds (`int`, *optional*):
            Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
            Includes the class for the masked latent pixel.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (`int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
            The number of diffusion steps used during training. Note that this is fixed at training time as it is used
            to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
            up to, but not more than, `num_embeds_ada_norm` steps.
        attention_bias (`bool`, *optional*):
            Configure if the TransformerBlocks' attention should contain a bias parameter.
    """

    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        num_vector_embeds: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
    ):
        super().__init__()
        self.transformers = nn.ModuleList(
            [
                Transformer2DModel(
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    in_channels=in_channels,
                    num_layers=num_layers,
                    dropout=dropout,
                    norm_num_groups=norm_num_groups,
                    cross_attention_dim=cross_attention_dim,
                    attention_bias=attention_bias,
                    sample_size=sample_size,
                    num_vector_embeds=num_vector_embeds,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                )
                for _ in range(2)
            ]
        )

        # The ratio at which transformer1's and transformer2's output states are mixed during inference
        self.mix_ratio = 0.5

        # The shape of `encoder_hidden_states` is expected to be
        # `(batch_size, num_condition_tokens[0] + num_condition_tokens[1], num_features)`
        self.num_condition_tokens = (77, 257)

    def forward(self, hidden_states, encoder_hidden_states, timestep=None, return_dict: bool = True):
        """
        Args:
            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` when discrete,
                `torch.FloatTensor` of shape `(batch size, channel, height, width)` when continuous):
                Input hidden states.
            encoder_hidden_states (`torch.LongTensor` of shape `(batch size, context dim)`, *optional*):
                Conditional embeddings for the cross-attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep (`torch.long`, *optional*):
                Optional timestep to be applied as an embedding in the `AdaLayerNorm` layers. Used to indicate the
                denoising step.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`models.attention.Transformer2DModelOutput`] instead of a plain tuple.

        Returns:
            [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
            if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
            tensor.
        """
        input_states = hidden_states

        encoded_states = []
        tokens_start = 0
        for i in range(2):
            # for each of the two transformers, pass the corresponding condition tokens
            condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.num_condition_tokens[i]]
            encoded_state = self.transformers[i](input_states, condition_state, timestep, return_dict)[0]
            encoded_states.append(encoded_state - input_states)
            tokens_start += self.num_condition_tokens[i]

        output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
        output_states = output_states + input_states

        if not return_dict:
            return (output_states,)

        return Transformer2DModelOutput(sample=output_states)

    def _set_attention_slice(self, slice_size):
        for transformer in self.transformers:
            transformer._set_attention_slice(slice_size)

    def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
        for transformer in self.transformers:
            transformer._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
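
As a quick reference for what the new `DualTransformer2DModel.forward` above computes, here is a minimal standalone sketch of the token split and residual mixing. The tensor shapes and random values are illustrative placeholders, not values from this diff; only the 77/257 split and the mixing formula mirror the code.

import torch

batch_size, num_features = 2, 768
hidden_states = torch.randn(batch_size, 4, 16, 16)         # placeholder latent input
text_tokens = torch.randn(batch_size, 77, num_features)    # e.g. a CLIP text encoder output
image_tokens = torch.randn(batch_size, 257, num_features)  # e.g. a CLIP vision encoder output

# the pipeline concatenates both condition streams along the token axis;
# forward() slices them apart again using num_condition_tokens = (77, 257)
encoder_hidden_states = torch.cat([text_tokens, image_tokens], dim=1)

mix_ratio = 0.5
encoded_0 = torch.randn_like(hidden_states)  # stand-in for transformers[0] output (tokens [:77])
encoded_1 = torch.randn_like(hidden_states)  # stand-in for transformers[1] output (tokens [77:])

# each output is reduced to a residual, the residuals are mixed, and the input is added back
output_states = (encoded_0 - hidden_states) * mix_ratio + (encoded_1 - hidden_states) * (1 - mix_ratio) + hidden_states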
@@ -15,7 +15,7 @@ import numpy as np
import torch
from torch import nn

from .attention import AttentionBlock, Transformer2DModel
from .attention import AttentionBlock, Transformer2DModel, DualTransformer2DModel
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D


@@ -17,7 +17,6 @@ from typing import Callable, List, Optional, Union

import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint

import PIL
@@ -29,7 +28,7 @@ from transformers import (
)

from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention import Transformer2DModel, Transformer2DModelOutput
from ...models.attention import DualTransformer2DModel, Transformer2DModel
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import is_accelerate_available, logging
@@ -94,12 +93,32 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
            if isinstance(module, Transformer2DModel):
                parent_name, index = name.rsplit(".", 1)
                index = int(index)

                image_transformer = self.image_unet.get_submodule(parent_name)[index]
                text_transformer = self.text_unet.get_submodule(parent_name)[index]

                config = image_transformer.config
                dual_transformer = DualTransformer2DModel(
                    image_transformer, text_transformer, mix_ratio=mix_ratio, condition_types=condition_types
                    num_attention_heads=config.num_attention_heads,
                    attention_head_dim=config.attention_head_dim,
                    in_channels=config.in_channels,
                    num_layers=config.num_layers,
                    dropout=config.dropout,
                    norm_num_groups=config.norm_num_groups,
                    cross_attention_dim=config.cross_attention_dim,
                    attention_bias=config.attention_bias,
                    sample_size=config.sample_size,
                    num_vector_embeds=config.num_vector_embeds,
                    activation_fn=config.activation_fn,
                    num_embeds_ada_norm=config.num_embeds_ada_norm,
                )
                for i, type in enumerate(condition_types):
                    if type == "image":
                        dual_transformer.transformers[i] = image_transformer
                    else:
                        dual_transformer.transformers[i] = text_transformer

                dual_transformer.mix_ratio = mix_ratio
                self.image_unet.get_submodule(parent_name)[index] = dual_transformer

    def remove_dual_attention(self):
@@ -107,7 +126,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
            if isinstance(module, DualTransformer2DModel):
                parent_name, index = name.rsplit(".", 1)
                index = int(index)
                self.image_unet.get_submodule(parent_name)[index] = module.image_transformer
                self.image_unet.get_submodule(parent_name)[index] = module.transformers[0]

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention with unet->image_unet
    def enable_xformers_memory_efficient_attention(self):
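
The `convert_to_dual_attention` and `remove_dual_attention` methods above both work by locating modules through their dotted names and swapping them inside the parent container. A minimal sketch of that pattern on a toy module follows; the `ToyBlock` class and layer names are hypothetical, and only the `named_modules` / `rsplit` / `get_submodule` mechanics mirror the pipeline code.

import torch
from torch import nn


class ToyBlock(nn.Module):
    # hypothetical container standing in for a UNet block holding attention modules
    def __init__(self):
        super().__init__()
        self.attentions = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4)])


model = ToyBlock()

# collect target names first, then replace each module in its parent container by index
targets = [name for name, module in model.named_modules() if isinstance(module, nn.Linear)]
for name in targets:
    parent_name, index = name.rsplit(".", 1)  # e.g. ("attentions", "0")
    model.get_submodule(parent_name)[int(index)] = nn.Identity()

print(model)  # both Linear layers are now Identity placeholders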
@@ -412,6 +431,11 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def set_mix_ratio(self, mix_ratio):
        for name, module in self.image_unet.named_modules():
            if isinstance(module, DualTransformer2DModel):
                module.mix_ratio = mix_ratio

    @torch.no_grad()
    def __call__(
        self,
@@ -539,6 +563,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):

        # 7. Combine the attention blocks of the image and text UNets
        self.convert_to_dual_attention(prompt_mix_ratio, prompt_types)
        self.set_mix_ratio(prompt_mix_ratio)

        # 8. Denoising loop
        for i, t in enumerate(self.progress_bar(timesteps)):
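
For context on the `set_mix_ratio(prompt_mix_ratio)` call added above, here is a hedged usage sketch of sweeping that ratio at inference time. The argument names and the `diffusers/vd-official-test` checkpoint id are taken from the integration test further down in this diff; the top-level import, the CUDA device, and the step count of 50 are assumptions, not part of the change.

import torch
from diffusers import VersatileDiffusionDualGuidedPipeline  # assumed top-level export

pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained("diffusers/vd-official-test")
pipe.to("cuda")

images = []
for ratio in (0.25, 0.5, 0.75):
    generator = torch.Generator(device="cuda").manual_seed(0)
    # prompt_mix_ratio is forwarded to set_mix_ratio before the denoising loop,
    # weighting the first condition's output against the second's
    out = pipe(
        first_prompt="first prompt",
        second_prompt="second prompt",
        prompt_mix_ratio=ratio,
        generator=generator,
        guidance_scale=7.5,
        num_inference_steps=50,
        output_type="numpy",
    )
    images.append(out.images[0])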
@@ -575,36 +600,3 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
            return (image,)

        return ImagePipelineOutput(images=image)


class DualTransformer2DModel(nn.Module):
    def __init__(self, image_transformer, text_transformer, mix_ratio=0.5, condition_types=("text", "image")):
        super().__init__()
        self.image_transformer = image_transformer
        self.text_transformer = text_transformer
        self.mix_ratio = mix_ratio
        self.condition_types = condition_types

    def forward(self, input_states, encoder_hidden_states, timestep=None, return_dict: bool = True):
        if self.condition_types[0] == "text":
            condition_states = [encoder_hidden_states[:, :77], encoder_hidden_states[:, 77:]]
        else:
            condition_states = [encoder_hidden_states[:, :257], encoder_hidden_states[:, 257:]]

        encoded_states = []
        for i in range(2):
            if self.condition_types[i] == "text":
                text_output = self.text_transformer(input_states, condition_states[i], timestep, return_dict)
                encoded_states.append(text_output[0])
            else:
                image_output = self.image_transformer(input_states, condition_states[i], timestep, return_dict)
                encoded_states.append(image_output[0])
            encoded_states[i] = encoded_states[i] - input_states

        output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
        output_states = output_states + input_states

        if not return_dict:
            return (output_states,)

        return Transformer2DModelOutput(sample=output_states)

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import tempfile
import unittest

import numpy as np
@@ -34,6 +36,47 @@ class VersatileDiffusionDualGuidedPipelineFastTests(PipelineTesterMixin, unittes
@slow
@require_torch_gpu
class VersatileDiffusionDualGuidedPipelineIntegrationTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_from_pretrained_save_pretrained(self):
        pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained("diffusers/vd-official-test")
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        generator = torch.Generator(device=torch_device).manual_seed(0)
        image = pipe(
            first_prompt="first prompt",
            second_prompt="second prompt",
            prompt_mix_ratio=0.75,
            generator=generator,
            guidance_scale=7.5,
            num_inference_steps=2,
            output_type="numpy",
        ).images

        with tempfile.TemporaryDirectory() as tmpdirname:
            pipe.save_pretrained(tmpdirname)
            pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained(tmpdirname)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        generator = generator.manual_seed(0)
        new_image = pipe(
            first_prompt="first prompt",
            second_prompt="second prompt",
            prompt_mix_ratio=0.75,
            generator=generator,
            guidance_scale=7.5,
            num_inference_steps=2,
            output_type="numpy",
        ).images

        assert np.abs(image - new_image).sum() < 1e-5, "Models don't have the same forward pass"

    def test_inference_image_variations(self):
        pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained("diffusers/vd-official-test")
        pipe.to(torch_device)