mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
* Bugfix: NPU does not support float64 * Use is_mps / is_npu checks --------- Co-authored-by: 白超 <baichao19@huawei.com> Co-authored-by: hlky <hlky@hlky.ac>
387 lines
19 KiB
Python
387 lines
19 KiB
Python
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
from typing import Any, Dict, Optional, Tuple, Union
|
|
|
|
import torch
|
|
|
|
from diffusers.configuration_utils import register_to_config
|
|
from diffusers.models.controlnet import (
|
|
ControlNetConditioningEmbedding,
|
|
ControlNetModel,
|
|
ControlNetOutput,
|
|
)
|
|
from diffusers.utils import logging
|
|
|
|
|
|
# Module-level logger obtained from diffusers' logging utilities.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
|
|
|
|
|
|
class PromptDiffusionControlNetModel(ControlNetModel):
    """
    A PromptDiffusionControlNet model.

    This is a [`ControlNetModel`] with one extra conditioning branch, `controlnet_query_cond_embedding`. In `forward`,
    its output is added to the pre-processed sample alongside the regular `controlnet_cond` embedding.

    Args:
        in_channels (`int`, defaults to 4):
            The number of channels in the input sample.
        conditioning_channels (`int`, defaults to 3):
            Number of input channels of `controlnet_cond`; forwarded to [`ControlNetModel`]. The query branch
            (`controlnet_query_cond_embedding`) always expects 3 input channels regardless of this value.
        flip_sin_to_cos (`bool`, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, defaults to 0):
            The frequency shift to apply to the time embedding.
        down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            The tuple of downsample blocks to use.
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
            Type of the mid block; forwarded to [`ControlNetModel`].
        only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
            Forwarded to [`ControlNetModel`].
        block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block. `block_out_channels[0]` is also the output channel count
            of the query conditioning embedding.
        layers_per_block (`int`, defaults to 2):
            The number of layers per block.
        downsample_padding (`int`, defaults to 1):
            The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, defaults to 1):
            The scale factor to use for the mid block.
        act_fn (`str`, defaults to "silu"):
            The activation function to use.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups to use for the normalization. If None, normalization and activation layers are
            skipped in post-processing.
        norm_eps (`float`, defaults to 1e-5):
            The epsilon to use for the normalization.
        cross_attention_dim (`int`, defaults to 1280):
            The dimension of the cross attention features.
        transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
        encoder_hid_dim (`int`, *optional*, defaults to None):
            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
            dimension to `cross_attention_dim`.
        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
            embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
        attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
            The dimension of the attention heads.
        num_attention_heads (`Union[int, Tuple[int]]`, *optional*, defaults to `None`):
            Forwarded to [`ControlNetModel`].
        use_linear_projection (`bool`, defaults to `False`):
            Forwarded to [`ControlNetModel`].
        class_embed_type (`str`, *optional*, defaults to `None`):
            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
            `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
        addition_embed_type (`str`, *optional*, defaults to `None`):
            Configures an optional embedding which will be summed with the time embeddings. `forward` handles
            `None`, `"text"` (embeds `encoder_hidden_states`) and `"text_time"` (requires `text_embeds` and
            `time_ids` in `added_cond_kwargs`).
        addition_time_embed_dim (`int`, *optional*, defaults to `None`):
            Forwarded to [`ControlNetModel`].
        num_class_embeds (`int`, *optional*, defaults to 0):
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
            class conditioning with `class_embed_type` equal to `None`.
        upcast_attention (`bool`, defaults to `False`):
            Forwarded to [`ControlNetModel`].
        resnet_time_scale_shift (`str`, defaults to `"default"`):
            Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
        projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
            The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
            `class_embed_type="projection"`.
        controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
            Channel order of `controlnet_cond`. When `"bgr"`, `forward` flips `controlnet_cond` along dim 1; any
            other value than `"rgb"`/`"bgr"` raises `ValueError` in `forward`.
        conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
            The tuple of output channel for each block in the `conditioning_embedding` layer (also used by the
            query conditioning embedding).
        global_pool_conditions (`bool`, defaults to `False`):
            If `True`, `forward` mean-pools every ControlNet residual over dims `(2, 3)` (keeping those dims) and
            disables the guess-mode logspace scaling.
        addition_embed_type_num_heads (`int`, defaults to 64):
            The number of heads to use for the `TextTimeEmbedding` layer.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 4,
        conditioning_channels: int = 3,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str, ...] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1280,
        transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
        encoder_hid_dim: Optional[int] = None,
        encoder_hid_dim_type: Optional[str] = None,
        attention_head_dim: Union[int, Tuple[int, ...]] = 8,
        num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        addition_embed_type: Optional[str] = None,
        addition_time_embed_dim: Optional[int] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        projection_class_embeddings_input_dim: Optional[int] = None,
        controlnet_conditioning_channel_order: str = "rgb",
        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
        global_pool_conditions: bool = False,
        addition_embed_type_num_heads: int = 64,
    ):
        # Pass everything by keyword: if the parent signature ever gains or reorders a parameter, a positional
        # call would silently bind values to the wrong names, whereas a keyword call fails loudly.
        super().__init__(
            in_channels=in_channels,
            conditioning_channels=conditioning_channels,
            flip_sin_to_cos=flip_sin_to_cos,
            freq_shift=freq_shift,
            down_block_types=down_block_types,
            mid_block_type=mid_block_type,
            only_cross_attention=only_cross_attention,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            downsample_padding=downsample_padding,
            mid_block_scale_factor=mid_block_scale_factor,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            cross_attention_dim=cross_attention_dim,
            transformer_layers_per_block=transformer_layers_per_block,
            encoder_hid_dim=encoder_hid_dim,
            encoder_hid_dim_type=encoder_hid_dim_type,
            attention_head_dim=attention_head_dim,
            num_attention_heads=num_attention_heads,
            use_linear_projection=use_linear_projection,
            class_embed_type=class_embed_type,
            addition_embed_type=addition_embed_type,
            addition_time_embed_dim=addition_time_embed_dim,
            num_class_embeds=num_class_embeds,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
            controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
            conditioning_embedding_out_channels=conditioning_embedding_out_channels,
            global_pool_conditions=global_pool_conditions,
            addition_embed_type_num_heads=addition_embed_type_num_heads,
        )
        # Second conditioning branch for the query image. Its input channel count is fixed at 3 and is
        # independent of `conditioning_channels`, which only applies to `controlnet_cond`.
        self.controlnet_query_cond_embedding = ControlNetConditioningEmbedding(
            conditioning_embedding_channels=block_out_channels[0],
            block_out_channels=conditioning_embedding_out_channels,
            conditioning_channels=3,
        )

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: torch.Tensor,
        controlnet_query_cond: torch.Tensor,
        conditioning_scale: float = 1.0,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guess_mode: bool = False,
        return_dict: bool = True,
    ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
        """
        The [`~PromptDiffusionControlNetModel`] forward method.

        Args:
            sample (`torch.Tensor`):
                The noisy input tensor. Its dim 0 is used as the batch size for timestep broadcasting.
            timestep (`Union[torch.Tensor, float, int]`):
                The current denoising timestep. Python scalars are converted to a 1-element tensor on
                `sample.device` (float32/int32 on MPS and NPU, float64/int64 otherwise), 0-d tensors are given a
                batch dimension, and the result is expanded to `sample.shape[0]`.
            encoder_hidden_states (`torch.Tensor`):
                The encoder hidden states.
            controlnet_cond (`torch.Tensor`):
                Conditioning image tensor, embedded by `self.controlnet_cond_embedding`. Dim 1 is treated as the
                channel dimension (it is flipped when the channel order is `"bgr"`); presumably shaped
                `(batch_size, conditioning_channels, height, width)`.
            controlnet_query_cond (`torch.Tensor`):
                Query image tensor, embedded by `self.controlnet_query_cond_embedding` (3 input channels). Its
                embedding must be shape-compatible with the embedded `sample` and `controlnet_cond`, since the three
                are summed.
            conditioning_scale (`float`, defaults to `1.0`):
                The scale factor for ControlNet outputs.
            class_labels (`torch.Tensor`, *optional*, defaults to `None`):
                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
                Required when the model has a class embedding.
            timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
                Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
                timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
                embeddings.
            attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
                is kept, otherwise if `0` it is discarded. The mask is converted into a bias
                (`(1 - mask) * -10000.0`) with an extra dimension inserted at dim 1.
            added_cond_kwargs (`dict`, *optional*):
                Additional conditions for the Stable Diffusion XL UNet. When `addition_embed_type == "text_time"`, it
                must contain `text_embeds` and `time_ids`; otherwise a `ValueError` is raised.
            cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
                A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
            guess_mode (`bool`, defaults to `False`):
                In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
                you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. Unless
                `global_pool_conditions` is set, residuals are scaled by a logspace ramp from 0.1 (shallowest) to 1.0
                (mid block), multiplied by `conditioning_scale`.
            return_dict (`bool`, defaults to `True`):
                Whether or not to return a [`~models.controlnets.controlnet.ControlNetOutput`] instead of a plain tuple.

        Returns:
            [`~models.controlnets.controlnet.ControlNetOutput`] **or** `tuple`:
                If `return_dict` is `True`, a [`~models.controlnets.controlnet.ControlNetOutput`] is returned,
                otherwise a tuple `(down_block_res_samples, mid_block_res_sample)`, where `down_block_res_samples` is
                a list of per-resolution residual tensors and `mid_block_res_sample` is the mid-block residual.

        Raises:
            ValueError: If the configured channel order is unknown, if `class_labels` is missing while the model has
                a class embedding, or if required `added_cond_kwargs` entries are missing for `"text_time"`.
        """
        # check channel order
        channel_order = self.config.controlnet_conditioning_channel_order

        if channel_order == "rgb":
            # in rgb order by default
            ...
        elif channel_order == "bgr":
            controlnet_cond = torch.flip(controlnet_cond, dims=[1])
            # NOTE(review): `controlnet_query_cond` is not flipped, so in "bgr" mode the two conditioning inputs
            # reach their embeddings in different channel orders. Left unchanged to stay compatible with weights
            # trained against this behavior; confirm whether it is intended.
        else:
            raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")

        # prepare attention_mask: 1 -> keep (bias 0), 0 -> discard (large negative bias)
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            # MPS and NPU do not support 64-bit dtypes here, so fall back to 32-bit on those devices.
            is_mps = sample.device.type == "mps"
            is_npu = sample.device.type == "npu"
            if isinstance(timestep, float):
                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
            else:
                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)
        aug_emb = None

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb

        if self.config.addition_embed_type is not None:
            if self.config.addition_embed_type == "text":
                aug_emb = self.add_embedding(encoder_hidden_states)

            elif self.config.addition_embed_type == "text_time":
                # A missing dict is treated as empty so the informative ValueError below is raised
                # instead of a TypeError from a membership test on None.
                if added_cond_kwargs is None:
                    added_cond_kwargs = {}
                if "text_embeds" not in added_cond_kwargs:
                    raise ValueError(
                        f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                    )
                text_embeds = added_cond_kwargs.get("text_embeds")
                if "time_ids" not in added_cond_kwargs:
                    raise ValueError(
                        f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                    )
                time_ids = added_cond_kwargs.get("time_ids")
                time_embeds = self.add_time_proj(time_ids.flatten())
                time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

                add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
                add_embeds = add_embeds.to(emb.dtype)
                aug_emb = self.add_embedding(add_embeds)

        emb = emb + aug_emb if aug_emb is not None else emb

        # 2. pre-process: both conditioning embeddings are added to the embedded sample
        sample = self.conv_in(sample)

        controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
        controlnet_query_cond = self.controlnet_query_cond_embedding(controlnet_query_cond)
        sample = sample + controlnet_cond + controlnet_query_cond

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        if self.mid_block is not None:
            if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
                sample = self.mid_block(
                    sample,
                    emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample = self.mid_block(sample, emb)

        # 5. Control net blocks: project every down-path residual through its paired ControlNet block
        down_block_res_samples = tuple(
            controlnet_block(down_block_res_sample)
            for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks)
        )

        mid_block_res_sample = self.controlnet_mid_block(sample)

        # 6. scaling
        if guess_mode and not self.config.global_pool_conditions:
            scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)  # 0.1 to 1.0
            scales = scales * conditioning_scale
            down_block_res_samples = [res * scale for res, scale in zip(down_block_res_samples, scales)]
            mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one
        else:
            down_block_res_samples = [res * conditioning_scale for res in down_block_res_samples]
            mid_block_res_sample = mid_block_res_sample * conditioning_scale

        if self.config.global_pool_conditions:
            down_block_res_samples = [torch.mean(res, dim=(2, 3), keepdim=True) for res in down_block_res_samples]
            mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)

        if not return_dict:
            return (down_block_res_samples, mid_block_res_sample)

        return ControlNetOutput(
            down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
        )