
text unet end to end

anton-l
2022-11-21 23:56:35 +01:00
parent 8c989ebe40
commit f706729d3c
13 changed files with 728 additions and 100 deletions

View File

@@ -31,13 +31,14 @@ from diffusers import (
UNet2DConditionModel,
VersatileDiffusionPipeline,
)
from diffusers.pipelines.versatile_diffusion.modeling_text_unet import UNetFlatConditionModel
from diffusers.pipelines.versatile_diffusion.modeling_gpt2_optimus import GPT2OptimusForLatentConnector
from transformers import (
CLIPFeatureExtractor,
CLIPTextModelWithProjection,
CLIPTokenizer,
CLIPVisionModelWithProjection,
)
from diffusers.pipelines.versatile_diffusion.modeling_text_unet import UNetMultiDimConditionModel
SCHEDULER_CONFIG = Namespace(
@@ -241,7 +242,7 @@ def assign_to_checkpoint(
# proj_attn.weight has to be converted from conv 1D to linear
if "proj_attn.weight" in new_path:
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
else:
elif path["old"] in old_checkpoint:
checkpoint[new_path] = old_checkpoint[path["old"]]
@@ -306,14 +307,14 @@ def create_text_unet_diffusers_config(unet_params):
down_block_types = []
resolution = 1
for i in range(len(block_out_channels)):
block_type = "CrossAttnDownBlockMultiDim" if unet_params.with_attn[i] else "DownBlockMultiDim"
block_type = "CrossAttnDownBlockFlat" if unet_params.with_attn[i] else "DownBlockFlat"
down_block_types.append(block_type)
if i != len(block_out_channels) - 1:
resolution *= 2
up_block_types = []
for i in range(len(block_out_channels)):
block_type = "CrossAttnUpBlockMultiDim" if unet_params.with_attn[-i - 1] else "UpBlockMultiDim"
block_type = "CrossAttnUpBlockFlat" if unet_params.with_attn[-i - 1] else "UpBlockFlat"
up_block_types.append(block_type)
resolution //= 2
@@ -322,8 +323,8 @@ def create_text_unet_diffusers_config(unet_params):
config = dict(
sample_size=None,
in_channels=unet_params.input_channels,
out_channels=unet_params.output_channels,
in_channels=(unet_params.input_channels, 1, 1),
out_channels=(unet_params.output_channels, 1, 1),
down_block_types=tuple(down_block_types),
up_block_types=tuple(up_block_types),
block_out_channels=tuple(block_out_channels),
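For orientation, a minimal sketch of the kind of config this function produces, assuming a hypothetical `unet_params` with 768 input/output channels and cross-attention on the first three blocks (the real values are read from the checkpoint):

# Illustrative only; actual values come from the original checkpoint's unet_params.
text_unet_config = dict(
    sample_size=None,
    in_channels=(768, 1, 1),  # flat latents carry an explicit (channels, 1, 1) shape
    out_channels=(768, 1, 1),
    down_block_types=("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat"),
    up_block_types=("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat"),
    block_out_channels=(320, 640, 1280, 1280),
)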
@@ -450,6 +451,17 @@ def convert_vd_unet_checkpoint(checkpoint, config, unet_key, extract_ema=False):
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
f"input_blocks.{i}.0.op.bias"
)
elif f"input_blocks.{i}.0.weight" in unet_state_dict:
# text_unet uses linear layers in place of downsamplers
shape = unet_state_dict[f"input_blocks.{i}.0.weight"].shape
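# heuristic: the linear downsamplers preserve the feature size, so only square weight matrices are treated as downsamplers here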
if shape[0] != shape[1]:
continue
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.weight"] = unet_state_dict.pop(
f"input_blocks.{i}.0.weight"
)
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.bias"] = unet_state_dict.pop(
f"input_blocks.{i}.0.bias"
)
paths = renew_resnet_paths(resnets)
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
@@ -512,10 +524,34 @@ def convert_vd_unet_checkpoint(checkpoint, config, unet_key, extract_ema=False):
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
f"output_blocks.{i}.{index}.conv.bias"
]
# Clear attentions as they have been attributed above.
if len(attentions) == 2:
attentions = []
elif f"output_blocks.{i}.1.weight" in unet_state_dict:
# text_unet uses linear layers in place of upsamplers
shape = unet_state_dict[f"output_blocks.{i}.1.weight"].shape
if shape[0] != shape[1]:
continue
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.weight"] = unet_state_dict.pop(
f"output_blocks.{i}.1.weight"
)
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.bias"] = unet_state_dict.pop(
f"output_blocks.{i}.1.bias"
)
# Clear attentions as they have been attributed above.
if len(attentions) == 2:
attentions = []
elif f"output_blocks.{i}.2.weight" in unet_state_dict:
# text_unet uses linear layers in place of upsamplers
shape = unet_state_dict[f"output_blocks.{i}.2.weight"].shape
if shape[0] != shape[1]:
continue
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.weight"] = unet_state_dict.pop(
f"output_blocks.{i}.2.weight"
)
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.bias"] = unet_state_dict.pop(
f"output_blocks.{i}.2.bias"
)
if len(attentions):
paths = renew_attention_paths(attentions)
@@ -727,7 +763,7 @@ if __name__ == "__main__":
converted_text_unet_checkpoint = convert_vd_unet_checkpoint(
checkpoint, text_unet_config, unet_key="model.diffusion_model.unet_text.", extract_ema=args.extract_ema
)
text_unet = UNetMultiDimConditionModel(**text_unet_config)
text_unet = UNetFlatConditionModel(**text_unet_config)
text_unet.load_state_dict(converted_text_unet_checkpoint)
# Convert the VAE model.

View File

@@ -76,6 +76,7 @@ if is_torch_available() and is_transformers_available():
VersatileDiffusionImageVariationPipeline,
VersatileDiffusionPipeline,
VersatileDiffusionTextToImagePipeline,
VersatileDiffusionImageToTextPipeline,
VQDiffusionPipeline,
)
else:

View File

@@ -175,7 +175,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
def forward(
self,

View File

@@ -201,7 +201,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
def set_attention_slice(self, slice_size):
if slice_size is not None and self.config.attention_head_dim % slice_size != 0:

View File

@@ -28,6 +28,7 @@ if is_torch_available() and is_transformers_available():
VersatileDiffusionImageVariationPipeline,
VersatileDiffusionPipeline,
VersatileDiffusionTextToImagePipeline,
VersatileDiffusionImageToTextPipeline,
)
from .vq_diffusion import VQDiffusionPipeline

View File

@@ -3,6 +3,9 @@ from ...utils import is_torch_available, is_transformers_available
if is_transformers_available() and is_torch_available():
from .modeling_gpt2_optimus import GPT2OptimusForLatentConnector
from .modeling_text_unet import UNetFlatConditionModel
from .pipeline_versatile_diffusion import VersatileDiffusionPipeline
from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline
from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline
from .pipeline_versatile_diffusion_image_to_text import VersatileDiffusionImageToTextPipeline

View File

@@ -6,8 +6,8 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...modeling_utils import ModelMixin
from ...models.embeddings import TimestepEmbedding, Timesteps
from ...models.attention import Transformer2DModel
from ...models.embeddings import TimestepEmbedding, Timesteps
from ...models.unet_2d_condition import UNet2DConditionOutput
from ...utils import logging
@@ -15,7 +15,7 @@ from ...utils import logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def get_down_block_multi_dim(
def get_down_block(
down_block_type,
num_layers,
in_channels,
@@ -23,38 +23,45 @@ def get_down_block_multi_dim(
temb_channels,
add_downsample,
resnet_eps,
resnet_act_fn,
attn_num_head_channels,
resnet_groups=None,
cross_attention_dim=None,
downsample_padding=None,
):
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
if down_block_type == "DownBlockMultiDim":
return DownBlockMultiDim(
if down_block_type == "DownBlockFlat":
return DownBlockFlat(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
)
elif down_block_type == "CrossAttnDownBlockMultiDim":
elif down_block_type == "CrossAttnDownBlockFlat":
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMultiDim")
return CrossAttnDownBlockMultiDim(
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockFlat")
return CrossAttnDownBlockFlat(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
)
raise ValueError(f"{down_block_type} is not supported.")
def get_up_block_multi_dim(
def get_up_block(
up_block_type,
num_layers,
in_channels,
@@ -63,13 +70,14 @@ def get_up_block_multi_dim(
temb_channels,
add_upsample,
resnet_eps,
resnet_act_fn,
attn_num_head_channels,
resnet_groups=None,
cross_attention_dim=None,
):
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpBlockMultiDim":
return UpBlockMultiDim(
if up_block_type == "UpBlockFlat":
return UpBlockFlat(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
@@ -77,12 +85,13 @@ def get_up_block_multi_dim(
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
)
elif up_block_type == "CrossAttnUpBlockMultiDim":
elif up_block_type == "CrossAttnUpBlockFlat":
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMultiDim")
return CrossAttnUpBlockMultiDim(
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockFlat")
return CrossAttnUpBlockFlat(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
@@ -90,6 +99,7 @@ def get_up_block_multi_dim(
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
@@ -97,11 +107,11 @@ def get_up_block_multi_dim(
raise ValueError(f"{up_block_type} is not supported.")
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->LMSDiscrete
class UNetMultiDimConditionModel(ModelMixin, ConfigMixin):
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel with UNet2DConditionModel->UNetFlatConditionModel, nn.Conv2d->LinearMultiDim, Block2D->BlockFlat
class UNetFlatConditionModel(ModelMixin, ConfigMixin):
r"""
UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
and returns sample shaped output.
UNetFlatConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a
timestep and returns sample shaped output.
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
implements for all the models (such as downloading or saving, etc.)
@@ -114,9 +124,9 @@ class UNetMultiDimConditionModel(ModelMixin, ConfigMixin):
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`):
The tuple of downsample blocks to use.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat",)`):
The tuple of upsample blocks to use.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
@@ -142,19 +152,18 @@ class UNetMultiDimConditionModel(ModelMixin, ConfigMixin):
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
"CrossAttnDownBlockMultiDim",
"CrossAttnDownBlockMultiDim",
"CrossAttnDownBlockMultiDim",
"DownBlockMultiDim",
"CrossAttnDownBlockFlat",
"CrossAttnDownBlockFlat",
"CrossAttnDownBlockFlat",
"DownBlockFlat",
),
up_block_types: Tuple[str] = (
"UpBlockMultiDim",
"CrossAttnUpBlockMultiDim",
"CrossAttnUpBlockMultiDim",
"CrossAttnUpBlockMultiDim",
"UpBlockFlat",
"CrossAttnUpBlockFlat",
"CrossAttnUpBlockFlat",
"CrossAttnUpBlockFlat",
),
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
block_second_dim: Tuple[int] = (4, 4, 4, 4),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
@@ -170,7 +179,7 @@ class UNetMultiDimConditionModel(ModelMixin, ConfigMixin):
time_embed_dim = block_out_channels[0] * 4
# input
self.conv_in = LinearMultiDim([in_channels, 1, 1], [block_out_channels[0], block_second_dim[0], 1])
self.conv_in = LinearMultiDim(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
# time
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
@@ -187,25 +196,26 @@ class UNetMultiDimConditionModel(ModelMixin, ConfigMixin):
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
second_dim = block_second_dim[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block_multi_dim(
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=[output_channel, second_dim, 1],
out_channels=output_channel,
temb_channels=time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim,
downsample_padding=downsample_padding,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = UNetMidBlockMultiDimCrossAttn(
self.mid_block = UNetMidBlockFlatCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
@@ -237,7 +247,7 @@ class UNetMultiDimConditionModel(ModelMixin, ConfigMixin):
else:
add_upsample = False
up_block = get_up_block_multi_dim(
up_block = get_up_block(
up_block_type,
num_layers=layers_per_block + 1,
in_channels=input_channel,
@@ -246,6 +256,7 @@ class UNetMultiDimConditionModel(ModelMixin, ConfigMixin):
temb_channels=time_embed_dim,
add_upsample=add_upsample,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim,
@@ -256,7 +267,7 @@ class UNetMultiDimConditionModel(ModelMixin, ConfigMixin):
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
self.conv_act = nn.SiLU()
self.conv_out = LinearMultiDim(block_out_channels[0], [out_channels, 1, 1])
self.conv_out = LinearMultiDim(block_out_channels[0], out_channels, kernel_size=3, padding=1)
def set_attention_slice(self, slice_size):
if slice_size is not None and self.config.attention_head_dim % slice_size != 0:
@@ -292,9 +303,7 @@ class UNetMultiDimConditionModel(ModelMixin, ConfigMixin):
block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(
module, (CrossAttnDownBlockMultiDim, DownBlockMultiDim, CrossAttnUpBlockMultiDim, UpBlockMultiDim)
):
if isinstance(module, (CrossAttnDownBlockFlat, DownBlockFlat, CrossAttnUpBlockFlat, UpBlockFlat)):
module.gradient_checkpointing = value
def forward(
@@ -308,7 +317,8 @@ class UNetMultiDimConditionModel(ModelMixin, ConfigMixin):
Args:
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states
encoder_hidden_states (`torch.FloatTensor`):
(batch_size, sequence_length, hidden_size) encoder hidden states
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
@@ -410,23 +420,25 @@ class UNetMultiDimConditionModel(ModelMixin, ConfigMixin):
class LinearMultiDim(nn.Linear):
def __init__(self, in_features, out_features, second_dim=4, *args, **kwargs):
def __init__(self, in_features, out_features=None, second_dim=4, *args, **kwargs):
in_features = [in_features, second_dim, 1] if isinstance(in_features, int) else list(in_features)
if out_features is None:
out_features = in_features
out_features = [out_features, second_dim, 1] if isinstance(out_features, int) else list(out_features)
self.in_features_multidim = in_features
self.out_features_multidim = out_features
super().__init__(np.array(in_features).prod(), np.array(out_features).prod())
def forward(self, x):
shape = x.shape
n = len(self.in_features_multidim)
x = x.view(*shape[0:-n], self.in_features)
y = super().forward(x)
y = y.view(*shape[0:-n], *self.out_features_multidim)
return y
def forward(self, input_tensor, *args, **kwargs):
shape = input_tensor.shape
n_dim = len(self.in_features_multidim)
input_tensor = input_tensor.reshape(*shape[0:-n_dim], self.in_features)
output_tensor = super().forward(input_tensor)
output_tensor = output_tensor.view(*shape[0:-n_dim], *self.out_features_multidim)
return output_tensor
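A quick shape sketch (sizes hypothetical): LinearMultiDim flattens the trailing `in_features_multidim` dims, applies a plain `nn.Linear`, and restores the multi-dim view:

layer = LinearMultiDim([320, 4, 1], [640, 4, 1])  # wraps nn.Linear(1280, 2560)
x = torch.randn(2, 320, 4, 1)                     # (batch, *in_features_multidim)
y = layer(x)                                      # y.shape == (2, 640, 4, 1)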
class ResnetBlockMultiDim(nn.Module):
class ResnetBlockFlat(nn.Module):
def __init__(
self,
*,
@@ -440,29 +452,31 @@ class ResnetBlockMultiDim(nn.Module):
eps=1e-6,
time_embedding_norm="default",
use_in_shortcut=None,
second_dim=4,
**kwargs,
):
super().__init__()
self.pre_norm = pre_norm
self.pre_norm = True
in_channels = [in_channels] if isinstance(in_channels, int) else list(in_channels)
in_channels_prod = np.array(in_channels).prod()
in_channels = [in_channels, second_dim, 1] if isinstance(in_channels, int) else list(in_channels)
self.in_channels_prod = np.array(in_channels).prod()
self.channels_multidim = in_channels
if out_channels is not None:
out_channels = [out_channels] if isinstance(out_channels, int) else list(out_channels)
out_channels = [out_channels, second_dim, 1] if isinstance(out_channels, int) else list(out_channels)
out_channels_prod = np.array(out_channels).prod()
self.out_channels_multidim = out_channels
else:
out_channels_prod = in_channels_prod
out_channels_prod = self.in_channels_prod
self.out_channels_multidim = self.channels_multidim
self.time_embedding_norm = time_embedding_norm
if groups_out is None:
groups_out = groups
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels_prod, eps=eps, affine=True)
self.conv1 = torch.nn.Conv2d(in_channels_prod, out_channels_prod, kernel_size=1, padding=0)
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=self.in_channels_prod, eps=eps, affine=True)
self.conv1 = torch.nn.Conv2d(self.in_channels_prod, out_channels_prod, kernel_size=1, padding=0)
if temb_channels is not None:
self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels_prod)
@@ -475,15 +489,20 @@ class ResnetBlockMultiDim(nn.Module):
self.nonlinearity = nn.SiLU()
self.use_in_shortcut = in_channels_prod != out_channels_prod if use_in_shortcut is None else use_in_shortcut
self.use_in_shortcut = self.in_channels_prod != out_channels_prod if use_in_shortcut is None else use_in_shortcut
self.conv_shortcut = None
if self.use_in_shortcut:
self.conv_shortcut = torch.nn.Conv2d(
in_channels_prod, out_channels_prod, kernel_size=1, stride=1, padding=0
self.in_channels_prod, out_channels_prod, kernel_size=1, stride=1, padding=0
)
def forward(self, input_tensor, temb):
shape = input_tensor.shape
n_dim = len(self.channels_multidim)
input_tensor = input_tensor.reshape(*shape[0:-n_dim], self.in_channels_prod, 1, 1)
input_tensor = input_tensor.view(-1, self.in_channels_prod, 1, 1)
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
@@ -505,10 +524,15 @@ class ResnetBlockMultiDim(nn.Module):
output_tensor = input_tensor + hidden_states
output_tensor = output_tensor.view(*shape[0:-n_dim], -1)
output_tensor = output_tensor.view(*shape[0:-n_dim], *self.out_channels_multidim)
print("resblock.output_tensor", output_tensor.abs().sum())
return output_tensor
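A matching sketch for ResnetBlockFlat (sizes hypothetical; the default `second_dim=4` expands an integer channel count to `[C, 4, 1]`):

block = ResnetBlockFlat(in_channels=320, out_channels=640, temb_channels=1280)
x = torch.randn(2, 320, 4, 1)  # (batch, *channels_multidim)
temb = torch.randn(2, 1280)    # one time embedding per batch element
y = block(x, temb)             # y.shape == (2, 640, 4, 1)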
class DownBlockMultiDim(nn.Module):
# Copied from diffusers.models.unet_2d_blocks.DownBlock2D with DownBlock2D->DownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim
class DownBlockFlat(nn.Module):
def __init__(
self,
in_channels: int,
@@ -518,9 +542,12 @@ class DownBlockMultiDim(nn.Module):
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor=1.0,
add_downsample=True,
downsample_padding=1,
):
super().__init__()
resnets = []
@@ -528,7 +555,7 @@ class DownBlockMultiDim(nn.Module):
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlockMultiDim(
ResnetBlockFlat(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
@@ -536,6 +563,8 @@ class DownBlockMultiDim(nn.Module):
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
@@ -543,7 +572,13 @@ class DownBlockMultiDim(nn.Module):
self.resnets = nn.ModuleList(resnets)
if add_downsample:
self.downsamplers = nn.ModuleList([LinearMultiDim(out_channels, out_channels)])
self.downsamplers = nn.ModuleList(
[
LinearMultiDim(
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
else:
self.downsamplers = None
@@ -576,7 +611,8 @@ class DownBlockMultiDim(nn.Module):
return hidden_states, output_states
class CrossAttnDownBlockMultiDim(nn.Module):
# Copied from diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D with CrossAttnDownBlock2D->CrossAttnDownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim
class CrossAttnDownBlockFlat(nn.Module):
def __init__(
self,
in_channels: int,
@@ -586,11 +622,14 @@ class CrossAttnDownBlockMultiDim(nn.Module):
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
cross_attention_dim=1280,
attention_type="default",
output_scale_factor=1.0,
downsample_padding=1,
add_downsample=True,
):
super().__init__()
@@ -603,7 +642,7 @@ class CrossAttnDownBlockMultiDim(nn.Module):
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlockMultiDim(
ResnetBlockFlat(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
@@ -611,14 +650,16 @@ class CrossAttnDownBlockMultiDim(nn.Module):
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
attentions.append(
Transformer2DModel(
attn_num_head_channels,
out_channels[0] // attn_num_head_channels,
in_channels=out_channels[0],
out_channels // attn_num_head_channels,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
@@ -628,7 +669,13 @@ class CrossAttnDownBlockMultiDim(nn.Module):
self.resnets = nn.ModuleList(resnets)
if add_downsample:
self.downsamplers = nn.ModuleList([LinearMultiDim(out_channels, out_channels)])
self.downsamplers = nn.ModuleList(
[
LinearMultiDim(
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
else:
self.downsamplers = None
@@ -687,7 +734,8 @@ class CrossAttnDownBlockMultiDim(nn.Module):
return hidden_states, output_states
class UpBlockMultiDim(nn.Module):
# Copied from diffusers.models.unet_2d_blocks.UpBlock2D with UpBlock2D->UpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim
class UpBlockFlat(nn.Module):
def __init__(
self,
in_channels: int,
@@ -698,8 +746,10 @@ class UpBlockMultiDim(nn.Module):
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor=1.0,
add_upsample=True,
):
super().__init__()
@@ -710,7 +760,7 @@ class UpBlockMultiDim(nn.Module):
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlockMultiDim(
ResnetBlockFlat(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
@@ -718,6 +768,8 @@ class UpBlockMultiDim(nn.Module):
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
@@ -725,7 +777,7 @@ class UpBlockMultiDim(nn.Module):
self.resnets = nn.ModuleList(resnets)
if add_upsample:
self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, out_channels)])
self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)])
else:
self.upsamplers = None
@@ -757,7 +809,8 @@ class UpBlockMultiDim(nn.Module):
return hidden_states
class CrossAttnUpBlockMultiDim(nn.Module):
# Copied from diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D with CrossAttnUpBlock2D->CrossAttnUpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim
class CrossAttnUpBlockFlat(nn.Module):
def __init__(
self,
in_channels: int,
@@ -768,11 +821,13 @@ class CrossAttnUpBlockMultiDim(nn.Module):
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
cross_attention_dim=1280,
attention_type="default",
output_scale_factor=1.0,
add_upsample=True,
):
super().__init__()
@@ -787,7 +842,7 @@ class CrossAttnUpBlockMultiDim(nn.Module):
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlockMultiDim(
ResnetBlockFlat(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
@@ -795,6 +850,8 @@ class CrossAttnUpBlockMultiDim(nn.Module):
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
@@ -812,7 +869,7 @@ class CrossAttnUpBlockMultiDim(nn.Module):
self.resnets = nn.ModuleList(resnets)
if add_upsample:
self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, out_channels)])
self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)])
else:
self.upsamplers = None
@@ -879,7 +936,8 @@ class CrossAttnUpBlockMultiDim(nn.Module):
return hidden_states
class UNetMidBlockMultiDimCrossAttn(nn.Module):
# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat
class UNetMidBlockFlatCrossAttn(nn.Module):
def __init__(
self,
in_channels: int,
@@ -888,10 +946,12 @@ class UNetMidBlockMultiDimCrossAttn(nn.Module):
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
attention_type="default",
output_scale_factor=1.0,
cross_attention_dim=1280,
**kwargs,
):
@@ -903,7 +963,7 @@ class UNetMidBlockMultiDimCrossAttn(nn.Module):
# there is always at least one resnet
resnets = [
ResnetBlockMultiDim(
ResnetBlockFlat(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
@@ -911,6 +971,8 @@ class UNetMidBlockMultiDimCrossAttn(nn.Module):
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
]
@@ -928,7 +990,7 @@ class UNetMidBlockMultiDimCrossAttn(nn.Module):
)
)
resnets.append(
ResnetBlockMultiDim(
ResnetBlockFlat(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
@@ -936,6 +998,8 @@ class UNetMidBlockMultiDimCrossAttn(nn.Module):
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)

View File

@@ -9,7 +9,8 @@ from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import logging
from . import VersatileDiffusionImageVariationPipeline, VersatileDiffusionTextToImagePipeline
from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline
from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -0,0 +1,461 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from dataclasses import dataclass
from typing import Callable, List, Optional, Union
import numpy as np
import torch
import torch.utils.checkpoint
import PIL
from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection, GPT2Tokenizer
from .modeling_text_unet import UNetFlatConditionModel
from .modeling_gpt2_optimus import GPT2OptimusForLatentConnector
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention import Transformer2DModel
from ...pipeline_utils import DiffusionPipeline, BaseOutput
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import is_accelerate_available, logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class TextPipelineOutput(BaseOutput):
"""
Output class for text generation pipelines.
Args:
text (`List[str]` or `np.ndarray`):
List of generated text of length `batch_size` or a numpy array of tokens of shape `(batch_size, num_tokens)`.
"""
text: Union[List[str], np.ndarray]
class VersatileDiffusionImageToTextPipeline(DiffusionPipeline):
r"""
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Parameters:
image_feature_extractor ([`CLIPFeatureExtractor`]):
Feature extractor that preprocesses the input image for `image_encoder`.
image_encoder ([`CLIPVisionModelWithProjection`]):
Frozen CLIP vision encoder used to encode the image prompt.
image_unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
text_unet ([`UNetFlatConditionModel`]): Flat conditional U-Net architecture to denoise the encoded text latents.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `text_unet` to denoise the encoded text latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
"""
image_feature_extractor: CLIPFeatureExtractor
image_encoder: CLIPVisionModelWithProjection
image_unet: UNet2DConditionModel
text_unet: UNetFlatConditionModel
vae: AutoencoderKL
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
def __init__(
self,
image_feature_extractor: CLIPFeatureExtractor,
image_encoder: CLIPVisionModelWithProjection,
image_unet: UNet2DConditionModel,
text_unet: UNetFlatConditionModel,
vae: AutoencoderKL,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
):
super().__init__()
self.register_modules(
image_feature_extractor=image_feature_extractor,
image_encoder=image_encoder,
image_unet=image_unet,
text_unet=text_unet,
vae=vae,
scheduler=scheduler,
)
self.text_vae_decoder = GPT2OptimusForLatentConnector.from_pretrained("fusing/gpt2_optimus")
self.text_vae_tokenizer = GPT2Tokenizer.from_pretrained("fusing/gpt2_optimus")
def swap_unet_attention_blocks(self):
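# Exchange every Transformer2DModel (cross-attention block) between image_unet and text_unet;
# called once before the denoising loop and once after, so the UNets can be reused elsewhere.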
for name, module in self.image_unet.named_modules():
if isinstance(module, Transformer2DModel):
parent_name, index = name.rsplit(".", 1)
index = int(index)
self.image_unet.get_submodule(parent_name)[index], self.text_unet.get_submodule(parent_name)[index] = (
self.text_unet.get_submodule(parent_name)[index],
self.image_unet.get_submodule(parent_name)[index],
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention with unet->image_unet
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.image_unet.set_use_memory_efficient_attention_xformers(True)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention with unet->image_unet
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.image_unet.set_use_memory_efficient_attention_xformers(False)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing with unet->image_unet
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.image_unet.config.attention_head_dim // 2
self.image_unet.set_attention_slice(slice_size)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the UNets,
image encoder and VAE have their state dicts saved to CPU and are then moved to `torch.device('meta')`,
loaded to GPU only when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device(f"cuda:{gpu_id}")
for cpu_offloaded_model in [self.image_unet, self.text_unet, self.image_encoder, self.vae]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device with unet->image_unet
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
if self.device != torch.device("meta") or not hasattr(self.image_unet, "_hf_hook"):
return self.device
for module in self.image_unet.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return self.device
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
r"""
Encodes the image prompt into image encoder hidden states.
Args:
prompt (`PIL.Image.Image` or `List[PIL.Image.Image]`):
image prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
"""
def normalize_embeddings(encoder_output):
embeds = self.image_encoder.vision_model.post_layernorm(encoder_output.last_hidden_state)
embeds = self.image_encoder.visual_projection(embeds)
embeds_pooled = embeds[:, 0:1]
embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True)
return embeds
batch_size = len(prompt) if isinstance(prompt, list) else 1
# get prompt text embeddings
image_input = self.image_feature_extractor(images=prompt, return_tensors="pt")
image_embeddings = self.image_encoder(image_input.pixel_values.to(self.device))
image_embeddings = normalize_embeddings(image_embeddings)
# duplicate image embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = image_embeddings.shape
image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
uncond_images: List[str]
if negative_prompt is None:
uncond_images = [np.zeros((512, 512, 3))] * batch_size
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, PIL.Image.Image):
uncond_images = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_images = negative_prompt
uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt")
uncond_embeddings = self.image_encoder(uncond_images.pixel_values.to(self.device))
uncond_embeddings = normalize_embeddings(uncond_embeddings)
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and conditional embeddings into a single batch
# to avoid doing two forward passes
image_embeddings = torch.cat([uncond_embeddings, image_embeddings])
return image_embeddings
def decode_latents(self, latents):
# drop the trailing (1, 1) dims so the latents have shape (batch_size, latent_dim)
latents = latents.reshape(latents.shape[:-2])
self.text_vae_decoder = self.text_vae_decoder.to(self._execution_device)
bos_token = self.text_vae_tokenizer.bos_token_id
output = self.text_vae_decoder.generate(bos_token_id=bos_token, past=latents)
return output
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(self, image, callback_steps):
if not isinstance(image, PIL.Image.Image) and not isinstance(image, torch.Tensor):
raise ValueError(f"`image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(image)}")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
def prepare_latents(self, batch_size, num_channels_latents, dtype, device, generator, latents=None):
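# text latents are a single flat feature vector per sample, kept as (batch, channels, 1, 1)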
shape = (batch_size, num_channels_latents, 1, 1)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
else:
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
@torch.no_grad()
def __call__(
self,
image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor],
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[PIL.Image.Image, List[PIL.Image.Image]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "str",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`):
The image prompt or prompts to guide the text generation.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`PIL.Image.Image` or `List[PIL.Image.Image]`, *optional*):
The image or images not to guide the generation. Ignored when not using guidance (i.e., ignored if
`guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of text samples to generate per image prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for text
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"str"`):
The output format of the generated text. Choose between `"str"` for decoded strings and `"numpy"` for
raw token arrays.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`TextPipelineOutput`] instead of a plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
[`TextPipelineOutput`] or `tuple`:
[`TextPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first
element is the generated text.
"""
# 1. Check inputs. Raise error if not correct
self.check_inputs(image, callback_steps)
# 2. Define call parameters
batch_size = 1 if isinstance(image, PIL.Image.Image) else len(image)
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
image_embeddings = self._encode_prompt(
image, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.text_unet.in_channels[0]
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
image_embeddings.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs.
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Swap the attention blocks between the image and text UNets
self.swap_unet_attention_blocks()
# 8. Denoising loop
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.text_unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 9. Swap the attention blocks backs in case the UNets are reused in another pipeline
self.swap_unet_attention_blocks()
# 10. Post-processing
text = self.decode_latents(latents)
# 11. Convert to strings
if output_type == "str":
text = self.text_vae_tokenizer.batch_decode(text)
if not return_dict:
return (text,)
return TextPipelineOutput(text=text)
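A minimal usage sketch for the new pipeline, assuming a locally converted checkpoint (the `./vd_official` path is hypothetical):

import torch
from diffusers import VersatileDiffusionImageToTextPipeline
from diffusers.utils.testing_utils import load_image

pipe = VersatileDiffusionImageToTextPipeline.from_pretrained("./vd_official")  # hypothetical local path
pipe = pipe.to("cuda")

image = load_image("https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg")
generator = torch.Generator(device="cuda").manual_seed(0)
text = pipe(image=image, generator=generator, guidance_scale=7.5, num_inference_steps=50).text
print(text)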

View File

@@ -22,8 +22,7 @@ import torch.utils.checkpoint
import PIL
from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection
from ...models import AutoencoderKL, UNet2DConditionModel, VQModel
from ...models.attention import Transformer2DModel
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import is_accelerate_available, logging
@@ -73,16 +72,6 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
scheduler=scheduler,
)
def swap_unet_attention_blocks(self):
for name, module in self.image_unet.named_modules():
if isinstance(module, Transformer2DModel):
parent_name, index = name.rsplit(".", 1)
index = int(index)
self.image_unet.get_submodule(parent_name)[index], self.text_unet.get_submodule(parent_name)[index] = (
self.text_unet.get_submodule(parent_name)[index],
self.image_unet.get_submodule(parent_name)[index],
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention with unet->image_unet
def enable_xformers_memory_efficient_attention(self):
r"""

View File

@@ -20,7 +20,8 @@ import torch.utils.checkpoint
from transformers import CLIPFeatureExtractor, CLIPTextModelWithProjection, CLIPTokenizer
from ...models import AutoencoderKL, UNet2DConditionModel, VQModel
from .modeling_text_unet import UNetFlatConditionModel
from ...models import UNet2DConditionModel, AutoencoderKL
from ...models.attention import Transformer2DModel
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
@@ -52,7 +53,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
image_feature_extractor: CLIPFeatureExtractor
text_encoder: CLIPTextModelWithProjection
image_unet: UNet2DConditionModel
text_unet: UNet2DConditionModel
text_unet: UNetFlatConditionModel
vae: AutoencoderKL
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
@@ -61,8 +62,8 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModelWithProjection,
image_unet: UNet2DConditionModel,
text_unet: UNet2DConditionModel,
vae: Union[VQModel, AutoencoderKL],
text_unet: UNetFlatConditionModel,
vae: AutoencoderKL,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
):
super().__init__()

View File

@@ -139,6 +139,21 @@ class VersatileDiffusionImageVariationPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class VersatileDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class VersatileDiffusionTextToImagePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -0,0 +1,56 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import torch
from diffusers import VersatileDiffusionImageToTextPipeline
from diffusers.utils.testing_utils import load_image, require_torch_gpu, slow, torch_device
from ...test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False
class VersatileDiffusionImageToTextPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pass
@slow
@require_torch_gpu
class VersatileDiffusionImageToTextPipelineIntegrationTests(unittest.TestCase):
def test_inference_image_to_text(self):
pipe = VersatileDiffusionImageToTextPipeline.from_pretrained("scripts/vd_official")
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
image_prompt = load_image(
"https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg"
)
generator = torch.Generator(device=torch_device).manual_seed(0)
tokens = pipe(
image=image_prompt,
generator=generator,
guidance_scale=7.5,
num_inference_steps=50,
output_type="numpy",
).text
assert tokens.shape == (1, 30)
expected_tokens = np.array([0, 1, 2, 3, 4, 5, 6, 7])
self.assertCountEqual(tokens[0], expected_tokens)