1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Cogvideox-5B Model adapter change (#9203)

* draft of embedding

---------

Co-authored-by: Aryan <aryan@huggingface.co>
This commit is contained in:
zR
2024-08-23 10:03:29 +08:00
committed by GitHub
parent dc07fc29da
commit 960c149c77
8 changed files with 536 additions and 32 deletions

View File

@@ -29,6 +29,10 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).
There are two models available that can be used with the CogVideoX pipeline:
- [`THUDM/CogVideoX-2b`](https://huggingface.co/THUDM/CogVideoX-2b)
- [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b)
## Inference
Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.
@@ -68,7 +72,7 @@ With torch.compile(): Average inference time: 76.27 seconds.
### Memory optimization
CogVideoX requires about 19 GB of GPU memory to decode 49 frames (6 seconds of video at 8 FPS) with output resolution 720x480 (W x H), which makes it not possible to run on consumer GPUs or free-tier T4 Colab. The following memory optimizations could be used to reduce the memory footprint. For replication, you can refer to [this](https://gist.github.com/a-r-r-o-w/3959a03f15be5c9bd1fe545b09dfcc93) script.
CogVideoX-2b requires about 19 GB of GPU memory to decode 49 frames (6 seconds of video at 8 FPS) with output resolution 720x480 (W x H), which makes it not possible to run on consumer GPUs or free-tier T4 Colab. The following memory optimizations could be used to reduce the memory footprint. For replication, you can refer to [this](https://gist.github.com/a-r-r-o-w/3959a03f15be5c9bd1fe545b09dfcc93) script.
- `pipe.enable_model_cpu_offload()`:
- Without enabling cpu offloading, memory usage is `33 GB`

View File

@@ -86,6 +86,9 @@ TRANSFORMER_SPECIAL_KEYS_REMAP = {
"key_layernorm_list": reassign_query_key_layernorm_inplace,
"adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace,
"embed_tokens": remove_keys_inplace,
"freqs_sin": remove_keys_inplace,
"freqs_cos": remove_keys_inplace,
"position_embedding": remove_keys_inplace,
}
VAE_KEYS_RENAME_DICT = {
@@ -123,11 +126,21 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
state_dict[new_key] = state_dict.pop(old_key)
def convert_transformer(ckpt_path: str):
def convert_transformer(
ckpt_path: str,
num_layers: int,
num_attention_heads: int,
use_rotary_positional_embeddings: bool,
dtype: torch.dtype,
):
PREFIX_KEY = "model.diffusion_model."
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
transformer = CogVideoXTransformer3DModel()
transformer = CogVideoXTransformer3DModel(
num_layers=num_layers,
num_attention_heads=num_attention_heads,
use_rotary_positional_embeddings=use_rotary_positional_embeddings,
).to(dtype=dtype)
for key in list(original_state_dict.keys()):
new_key = key[len(PREFIX_KEY) :]
@@ -145,9 +158,9 @@ def convert_transformer(ckpt_path: str):
return transformer
def convert_vae(ckpt_path: str):
def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
vae = AutoencoderKLCogVideoX()
vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype)
for key in list(original_state_dict.keys()):
new_key = key[:]
@@ -172,13 +185,26 @@ def get_args():
)
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--fp16", action="store_true", default=True, help="Whether to save the model weights in fp16")
parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
parser.add_argument("--bf16", action="store_true", default=False, help="Whether to save the model weights in bf16")
parser.add_argument(
"--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving"
)
parser.add_argument(
"--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
)
# For CogVideoX-2B, num_layers is 30. For 5B, it is 42
parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
# For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
parser.add_argument("--num_attention_heads", type=int, default=30, help="Number of attention heads")
# For CogVideoX-2B, use_rotary_positional_embeddings is False. For 5B, it is True
parser.add_argument(
"--use_rotary_positional_embeddings", action="store_true", default=False, help="Whether to use RoPE or not"
)
# For CogVideoX-2B, scaling_factor is 1.15258426. For 5B, it is 0.7
parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
# For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE")
return parser.parse_args()
@@ -188,18 +214,33 @@ if __name__ == "__main__":
transformer = None
vae = None
if args.fp16 and args.bf16:
raise ValueError("You cannot pass both --fp16 and --bf16 at the same time.")
dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32
if args.transformer_ckpt_path is not None:
transformer = convert_transformer(args.transformer_ckpt_path)
transformer = convert_transformer(
args.transformer_ckpt_path,
args.num_layers,
args.num_attention_heads,
args.use_rotary_positional_embeddings,
dtype,
)
if args.vae_ckpt_path is not None:
vae = convert_vae(args.vae_ckpt_path)
vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)
text_encoder_id = "google/t5-v1_1-xxl"
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
# Apparently, the conversion does not work any more without this :shrug:
for param in text_encoder.parameters():
param.data = param.data.contiguous()
scheduler = CogVideoXDDIMScheduler.from_config(
{
"snr_shift_scale": 3.0,
"snr_shift_scale": args.snr_shift_scale,
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
@@ -208,7 +249,7 @@ if __name__ == "__main__":
"prediction_type": "v_prediction",
"rescale_betas_zero_snr": True,
"set_alpha_to_one": True,
"timestep_spacing": "linspace",
"timestep_spacing": "trailing",
}
)
@@ -218,5 +259,10 @@ if __name__ == "__main__":
if args.fp16:
pipe = pipe.to(dtype=torch.float16)
if args.bf16:
pipe = pipe.to(dtype=torch.bfloat16)
# We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
# for users to specify variant when the default is not fp32 and they want to run with the correct default (which
# is either fp16/bf16 here).
pipe.save_pretrained(args.output_path, safe_serialization=True, push_to_hub=args.push_to_hub)

View File

@@ -1783,6 +1783,148 @@ class FluxAttnProcessor2_0:
return hidden_states
class CogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
if not attn.is_cross_attention:
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
return hidden_states, encoder_hidden_states
class FusedCogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
qkv = attn.to_qkv(hidden_states)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
if not attn.is_cross_attention:
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
return hidden_states, encoder_hidden_states
class XFormersAttnAddedKVProcessor:
r"""
Processor for implementing memory efficient attention using xFormers.

View File

@@ -902,7 +902,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
Tuple of block output channels.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
scaling_factor (`float`, *optional*, defaults to 0.18215):
scaling_factor (`float`, *optional*, defaults to `1.15258426`):
The component-wise standard deviation of the trained latent space computed using the first batch of the
training set. This is used to scale the latent space to have unit variance when training the diffusion
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the

View File

@@ -374,6 +374,90 @@ class CogVideoXPatchEmbed(nn.Module):
return embeds
def get_3d_rotary_pos_embed(
embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
RoPE for video tokens with 3D structure.
Args:
embed_dim: (`int`):
The embedding dimension size, corresponding to hidden_size_head.
crops_coords (`Tuple[int]`):
The top-left and bottom-right coordinates of the crop.
grid_size (`Tuple[int]`):
The grid size of the spatial positional embedding (height, width).
temporal_size (`int`):
The size of the temporal dimension.
theta (`float`):
Scaling factor for frequency computation.
use_real (`bool`):
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
Returns:
`torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
"""
start, stop = crops_coords
grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
# Compute dimensions for each axis
dim_t = embed_dim // 4
dim_h = embed_dim // 8 * 3
dim_w = embed_dim // 8 * 3
# Temporal frequencies
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
grid_t = torch.from_numpy(grid_t).float()
freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
freqs_t = freqs_t.repeat_interleave(2, dim=-1)
# Spatial frequencies for height and width
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
grid_h = torch.from_numpy(grid_h).float()
grid_w = torch.from_numpy(grid_w).float()
freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
freqs_h = freqs_h.repeat_interleave(2, dim=-1)
freqs_w = freqs_w.repeat_interleave(2, dim=-1)
# Broadcast and concatenate tensors along specified dimension
def broadcast(tensors, dim=-1):
num_tensors = len(tensors)
shape_lens = {len(t.shape) for t in tensors}
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
shape_len = list(shape_lens)[0]
dim = (dim + shape_len) if dim < 0 else dim
dims = list(zip(*(list(t.shape) for t in tensors)))
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
assert all(
[*(len(set(t[1])) <= 2 for t in expandable_dims)]
), "invalid dimensions for broadcastable concatenation"
max_dims = [(t[0], max(t[1])) for t in expandable_dims]
expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
expanded_dims.insert(dim, (dim, dims[dim]))
expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
return torch.cat(tensors, dim=dim)
freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
t, h, w, d = freqs.shape
freqs = freqs.view(t * h * w, d)
# Generate sine and cosine components
sin = freqs.sin()
cos = freqs.cos()
if use_real:
return cos, sin
else:
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
"""
RoPE for image tokens with 2d structure.

View File

@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, Tuple, Union
import torch
from torch import nn
@@ -22,6 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torch_version, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, FeedForward
from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
@@ -97,6 +98,7 @@ class CogVideoXBlock(nn.Module):
eps=1e-6,
bias=attention_bias,
out_bias=attention_out_bias,
processor=CogVideoXAttnProcessor2_0(),
)
# 2. Feed Forward
@@ -116,24 +118,24 @@ class CogVideoXBlock(nn.Module):
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
# norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
hidden_states, encoder_hidden_states, temb
)
# attention
text_length = norm_encoder_hidden_states.size(1)
# CogVideoX uses concatenated text + video embeddings with self-attention instead of using
# them in cross-attention individually
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
attn_output = self.attn1(
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=None,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + gate_msa * attn_output[:, text_length:]
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_output[:, :text_length]
hidden_states = hidden_states + gate_msa * attn_hidden_states
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
# norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
@@ -144,8 +146,9 @@ class CogVideoXBlock(nn.Module):
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
ff_output = self.ff(norm_hidden_states)
hidden_states = hidden_states + gate_ff * ff_output[:, text_length:]
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_length]
hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
return hidden_states, encoder_hidden_states
@@ -231,6 +234,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
norm_eps: float = 1e-5,
spatial_interpolation_scale: float = 1.875,
temporal_interpolation_scale: float = 1.0,
use_rotary_positional_embeddings: bool = False,
):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
@@ -295,12 +299,113 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: Union[int, float, torch.LongTensor],
timestep_cond: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
return_dict: bool = True,
):
batch_size, num_frames, channels, height, width = hidden_states.shape
@@ -319,14 +424,16 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
# 3. Position embedding
seq_length = height * width * num_frames // (self.config.patch_size**2)
text_seq_length = encoder_hidden_states.shape[1]
if not self.config.use_rotary_positional_embeddings:
seq_length = height * width * num_frames // (self.config.patch_size**2)
pos_embeds = self.pos_embedding[:, : self.config.max_text_seq_length + seq_length]
hidden_states = hidden_states + pos_embeds
hidden_states = self.embedding_dropout(hidden_states)
pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
hidden_states = hidden_states + pos_embeds
hidden_states = self.embedding_dropout(hidden_states)
encoder_hidden_states = hidden_states[:, : self.config.max_text_seq_length]
hidden_states = hidden_states[:, self.config.max_text_seq_length :]
encoder_hidden_states = hidden_states[:, :text_seq_length]
hidden_states = hidden_states[:, text_seq_length:]
# 4. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
@@ -344,6 +451,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
hidden_states,
encoder_hidden_states,
emb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
@@ -351,9 +459,17 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=emb,
image_rotary_emb=image_rotary_emb,
)
hidden_states = self.norm_final(hidden_states)
if not self.config.use_rotary_positional_embeddings:
# CogVideoX-2B
hidden_states = self.norm_final(hidden_states)
else:
# CogVideoX-5B
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
hidden_states = self.norm_final(hidden_states)
hidden_states = hidden_states[:, text_seq_length:]
# 5. Final block
hidden_states = self.norm_out(hidden_states, temb=emb)

View File

@@ -23,6 +23,7 @@ from transformers import T5EncoderModel, T5Tokenizer
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
from ...models.embeddings import get_3d_rotary_pos_embed
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from ...utils import BaseOutput, logging, replace_example_docstring
@@ -40,6 +41,7 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers import CogVideoXPipeline
>>> from diffusers.utils import export_to_video
>>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
>>> prompt = (
... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
@@ -55,6 +57,25 @@ EXAMPLE_DOC_STRING = """
"""
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
tw = tgt_width
th = tgt_height
h, w = src
r = h / w
if r > (th / tw):
resize_height = th
resize_width = int(round(th / h * w))
else:
resize_width = tw
resize_height = int(round(tw / w * h))
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
@@ -409,6 +430,46 @@ class CogVideoXPipeline(DiffusionPipeline):
f" {negative_prompt_embeds.shape}."
)
def fuse_qkv_projections(self) -> None:
r"""Enables fused QKV projections."""
self.fusing_transformer = True
self.transformer.fuse_qkv_projections()
def unfuse_qkv_projections(self) -> None:
r"""Disable QKV projection fusion if enabled."""
if not self.fusing_transformer:
logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
else:
self.transformer.unfuse_qkv_projections()
self.fusing_transformer = False
def _prepare_rotary_positional_embeddings(
self,
height: int,
width: int,
num_frames: int,
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
grid_crops_coords = get_resize_crop_region_for_grid(
(grid_height, grid_width), base_size_width, base_size_height
)
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=self.transformer.config.attention_head_dim,
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=num_frames,
use_real=True,
)
freqs_cos = freqs_cos.to(device=device)
freqs_sin = freqs_sin.to(device=device)
return freqs_cos, freqs_sin
@property
def guidance_scale(self):
return self._guidance_scale
@@ -599,7 +660,14 @@ class CogVideoXPipeline(DiffusionPipeline):
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Denoising loop
# 7. Create rotary embeds if required
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
if self.transformer.config.use_rotary_positional_embeddings
else None
)
# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -620,6 +688,7 @@ class CogVideoXPipeline(DiffusionPipeline):
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
return_dict=False,
)[0]
noise_pred = noise_pred.float()

View File

@@ -30,7 +30,12 @@ from diffusers.utils.testing_utils import (
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
from ..test_pipelines_common import (
PipelineTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
check_qkv_fusion_processors_exist,
to_np,
)
enable_full_determinism()
@@ -279,6 +284,44 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def test_xformers_attention_forwardGenerator_pass(self):
pass
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
frames = pipe(**inputs).frames # [B, F, C, H, W]
original_image_slice = frames[0, -2:, -1, -3:, -3:]
pipe.fuse_qkv_projections()
assert check_qkv_fusion_processors_exist(
pipe.transformer
), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
inputs = self.get_dummy_inputs(device)
frames = pipe(**inputs).frames
image_slice_fused = frames[0, -2:, -1, -3:, -3:]
pipe.transformer.unfuse_qkv_projections()
inputs = self.get_dummy_inputs(device)
frames = pipe(**inputs).frames
image_slice_disabled = frames[0, -2:, -1, -3:, -3:]
assert np.allclose(
original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
), "Fusion of QKV projections shouldn't affect the outputs."
assert np.allclose(
image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
assert np.allclose(
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Original outputs should match when fused QKV projections are disabled."
@slow
@require_torch_gpu