mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
add t2v + vae2.2
This commit is contained in:
@@ -34,6 +34,104 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
CACHE_T = 2
|
||||
|
||||
|
||||
class AvgDown3D(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
factor_t,
|
||||
factor_s=1,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.factor_t = factor_t
|
||||
self.factor_s = factor_s
|
||||
self.factor = self.factor_t * self.factor_s * self.factor_s
|
||||
|
||||
assert in_channels * self.factor % out_channels == 0
|
||||
self.group_size = in_channels * self.factor // out_channels
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
|
||||
pad = (0, 0, 0, 0, pad_t, 0)
|
||||
x = F.pad(x, pad)
|
||||
B, C, T, H, W = x.shape
|
||||
x = x.view(
|
||||
B,
|
||||
C,
|
||||
T // self.factor_t,
|
||||
self.factor_t,
|
||||
H // self.factor_s,
|
||||
self.factor_s,
|
||||
W // self.factor_s,
|
||||
self.factor_s,
|
||||
)
|
||||
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
|
||||
x = x.view(
|
||||
B,
|
||||
C * self.factor,
|
||||
T // self.factor_t,
|
||||
H // self.factor_s,
|
||||
W // self.factor_s,
|
||||
)
|
||||
x = x.view(
|
||||
B,
|
||||
self.out_channels,
|
||||
self.group_size,
|
||||
T // self.factor_t,
|
||||
H // self.factor_s,
|
||||
W // self.factor_s,
|
||||
)
|
||||
x = x.mean(dim=2)
|
||||
return x
|
||||
|
||||
|
||||
class DupUp3D(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
factor_t,
|
||||
factor_s=1,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.factor_t = factor_t
|
||||
self.factor_s = factor_s
|
||||
self.factor = self.factor_t * self.factor_s * self.factor_s
|
||||
|
||||
assert out_channels * self.factor % in_channels == 0
|
||||
self.repeats = out_channels * self.factor // in_channels
|
||||
|
||||
def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
|
||||
x = x.repeat_interleave(self.repeats, dim=1)
|
||||
x = x.view(
|
||||
x.size(0),
|
||||
self.out_channels,
|
||||
self.factor_t,
|
||||
self.factor_s,
|
||||
self.factor_s,
|
||||
x.size(2),
|
||||
x.size(3),
|
||||
x.size(4),
|
||||
)
|
||||
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
|
||||
x = x.view(
|
||||
x.size(0),
|
||||
self.out_channels,
|
||||
x.size(2) * self.factor_t,
|
||||
x.size(4) * self.factor_s,
|
||||
x.size(6) * self.factor_s,
|
||||
)
|
||||
if first_chunk:
|
||||
x = x[:, :, self.factor_t - 1:, :, :]
|
||||
return x
|
||||
|
||||
class WanCausalConv3d(nn.Conv3d):
|
||||
r"""
|
||||
A custom 3D causal convolution layer with feature caching support.
|
||||
@@ -134,19 +232,23 @@ class WanResample(nn.Module):
|
||||
- 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, mode: str) -> None:
|
||||
def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None:
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.mode = mode
|
||||
|
||||
# default to dim //2
|
||||
if upsample_out_dim is None:
|
||||
upsample_out_dim = dim // 2
|
||||
|
||||
# layers
|
||||
if mode == "upsample2d":
|
||||
self.resample = nn.Sequential(
|
||||
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
|
||||
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, upsample_out_dim, 3, padding=1)
|
||||
)
|
||||
elif mode == "upsample3d":
|
||||
self.resample = nn.Sequential(
|
||||
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
|
||||
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, upsample_out_dim, 3, padding=1)
|
||||
)
|
||||
self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
||||
|
||||
@@ -363,6 +465,48 @@ class WanMidBlock(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class WanResidualDownBlock(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_dim,
|
||||
out_dim,
|
||||
dropout,
|
||||
num_res_blocks,
|
||||
temperal_downsample=False,
|
||||
down_flag=False):
|
||||
super().__init__()
|
||||
|
||||
# Shortcut path with downsample
|
||||
self.avg_shortcut = AvgDown3D(
|
||||
in_dim,
|
||||
out_dim,
|
||||
factor_t=2 if temperal_downsample else 1,
|
||||
factor_s=2 if down_flag else 1,
|
||||
)
|
||||
|
||||
# Main path with residual blocks and downsample
|
||||
resnets = []
|
||||
for _ in range(num_res_blocks):
|
||||
resnets.append(WanResidualBlock(in_dim, out_dim, dropout))
|
||||
in_dim = out_dim
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
# Add the final downsample block
|
||||
if down_flag:
|
||||
mode = "downsample3d" if temperal_downsample else "downsample2d"
|
||||
self.downsampler = WanResample(out_dim, mode=mode)
|
||||
else:
|
||||
self.downsampler = None
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
x_copy = x.clone()
|
||||
for resnet in self.resnets:
|
||||
x = resnet(x, feat_cache, feat_idx)
|
||||
if self.downsampler is not None:
|
||||
x = self.downsampler(x, feat_cache, feat_idx)
|
||||
|
||||
return x + self.avg_shortcut(x_copy)
|
||||
|
||||
class WanEncoder3d(nn.Module):
|
||||
r"""
|
||||
A 3D encoder module.
|
||||
@@ -380,6 +524,7 @@ class WanEncoder3d(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
@@ -388,6 +533,7 @@ class WanEncoder3d(nn.Module):
|
||||
temperal_downsample=[True, True, False],
|
||||
dropout=0.0,
|
||||
non_linearity: str = "silu",
|
||||
is_residual: bool = False, # wan 2.2 vae use a residual downblock
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
@@ -403,23 +549,35 @@ class WanEncoder3d(nn.Module):
|
||||
scale = 1.0
|
||||
|
||||
# init block
|
||||
self.conv_in = WanCausalConv3d(3, dims[0], 3, padding=1)
|
||||
self.conv_in = WanCausalConv3d(in_channels, dims[0], 3, padding=1)
|
||||
|
||||
# downsample blocks
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
# residual (+attention) blocks
|
||||
for _ in range(num_res_blocks):
|
||||
self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
|
||||
if scale in attn_scales:
|
||||
self.down_blocks.append(WanAttentionBlock(out_dim))
|
||||
in_dim = out_dim
|
||||
if is_residual:
|
||||
self.down_blocks.append(
|
||||
WanResidualDownBlock(
|
||||
in_dim,
|
||||
out_dim,
|
||||
dropout,
|
||||
num_res_blocks,
|
||||
temperal_downsample=temperal_downsample[i] if i != len(dim_mult) - 1 else False,
|
||||
down_flag=i != len(dim_mult) - 1,
|
||||
)
|
||||
)
|
||||
else:
|
||||
for _ in range(num_res_blocks):
|
||||
self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
|
||||
if scale in attn_scales:
|
||||
self.down_blocks.append(WanAttentionBlock(out_dim))
|
||||
in_dim = out_dim
|
||||
|
||||
# downsample block
|
||||
if i != len(dim_mult) - 1:
|
||||
mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
|
||||
self.down_blocks.append(WanResample(out_dim, mode=mode))
|
||||
scale /= 2.0
|
||||
# downsample block
|
||||
if i != len(dim_mult) - 1:
|
||||
mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
|
||||
self.down_blocks.append(WanResample(out_dim, mode=mode))
|
||||
scale /= 2.0
|
||||
|
||||
# middle blocks
|
||||
self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)
|
||||
@@ -469,6 +627,92 @@ class WanEncoder3d(nn.Module):
|
||||
x = self.conv_out(x)
|
||||
return x
|
||||
|
||||
class WanResidualUpBlock(nn.Module):
|
||||
"""
|
||||
A block that handles upsampling for the WanVAE decoder.
|
||||
|
||||
Args:
|
||||
in_dim (int): Input dimension
|
||||
out_dim (int): Output dimension
|
||||
num_res_blocks (int): Number of residual blocks
|
||||
dropout (float): Dropout rate
|
||||
temperal_upsample (bool): Whether to upsample on temporal dimension
|
||||
up_flag (bool): Whether to upsample or not
|
||||
non_linearity (str): Type of non-linearity to use
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_dim: int,
|
||||
out_dim: int,
|
||||
num_res_blocks: int,
|
||||
dropout: float = 0.0,
|
||||
temperal_upsample: bool = False,
|
||||
up_flag: bool = False,
|
||||
non_linearity: str = "silu",
|
||||
):
|
||||
super().__init__()
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
|
||||
if up_flag:
|
||||
self.avg_shortcut = DupUp3D(
|
||||
in_dim,
|
||||
out_dim,
|
||||
factor_t=2 if temperal_upsample else 1,
|
||||
factor_s=2,
|
||||
)
|
||||
else:
|
||||
self.avg_shortcut = None
|
||||
|
||||
# create residual blocks
|
||||
resnets = []
|
||||
current_dim = in_dim
|
||||
for _ in range(num_res_blocks + 1):
|
||||
resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
|
||||
current_dim = out_dim
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
# Add upsampling layer if needed
|
||||
if up_flag:
|
||||
upsample_mode = "upsample3d" if temperal_upsample else "upsample2d"
|
||||
self.upsampler = WanResample(out_dim, mode=upsample_mode, upsample_out_dim=out_dim)
|
||||
else:
|
||||
self.upsampler = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
|
||||
"""
|
||||
Forward pass through the upsampling block.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor
|
||||
feat_cache (list, optional): Feature cache for causal convolutions
|
||||
feat_idx (list, optional): Feature index for cache management
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor
|
||||
"""
|
||||
x_copy = x.clone()
|
||||
|
||||
for resnet in self.resnets:
|
||||
if feat_cache is not None:
|
||||
x = resnet(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = resnet(x)
|
||||
|
||||
if self.upsampler is not None:
|
||||
if feat_cache is not None:
|
||||
x = self.upsampler(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = self.upsampler(x)
|
||||
|
||||
if self.avg_shortcut is not None:
|
||||
x = x + self.avg_shortcut(x_copy, first_chunk=first_chunk)
|
||||
|
||||
return x
|
||||
|
||||
class WanUpBlock(nn.Module):
|
||||
"""
|
||||
@@ -513,7 +757,7 @@ class WanUpBlock(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=None):
|
||||
"""
|
||||
Forward pass through the upsampling block.
|
||||
|
||||
@@ -564,6 +808,8 @@ class WanDecoder3d(nn.Module):
|
||||
temperal_upsample=[False, True, True],
|
||||
dropout=0.0,
|
||||
non_linearity: str = "silu",
|
||||
out_channels: int = 3,
|
||||
is_residual: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
@@ -577,7 +823,6 @@ class WanDecoder3d(nn.Module):
|
||||
|
||||
# dimensions
|
||||
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
||||
scale = 1.0 / 2 ** (len(dim_mult) - 2)
|
||||
|
||||
# init block
|
||||
self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
|
||||
@@ -589,36 +834,47 @@ class WanDecoder3d(nn.Module):
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
# residual (+attention) blocks
|
||||
if i > 0:
|
||||
if i > 0 and not is_residual:
|
||||
# wan vae 2.1
|
||||
in_dim = in_dim // 2
|
||||
|
||||
# Determine if we need upsampling
|
||||
# determine if we need upsampling
|
||||
up_flag = i != len(dim_mult) - 1
|
||||
# determine upsampling mode, if not upsampling, set to None
|
||||
upsample_mode = None
|
||||
if i != len(dim_mult) - 1:
|
||||
upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
|
||||
|
||||
if up_flag and temperal_upsample[i]:
|
||||
upsample_mode = "upsample3d"
|
||||
elif up_flag:
|
||||
upsample_mode = "upsample2d"
|
||||
# Create and add the upsampling block
|
||||
up_block = WanUpBlock(
|
||||
in_dim=in_dim,
|
||||
out_dim=out_dim,
|
||||
num_res_blocks=num_res_blocks,
|
||||
dropout=dropout,
|
||||
upsample_mode=upsample_mode,
|
||||
non_linearity=non_linearity,
|
||||
)
|
||||
if is_residual:
|
||||
up_block = WanResidualUpBlock(
|
||||
in_dim=in_dim,
|
||||
out_dim=out_dim,
|
||||
num_res_blocks=num_res_blocks,
|
||||
dropout=dropout,
|
||||
temperal_upsample=temperal_upsample[i] if up_flag else False,
|
||||
up_flag= up_flag,
|
||||
non_linearity=non_linearity,
|
||||
)
|
||||
else:
|
||||
up_block = WanUpBlock(
|
||||
in_dim=in_dim,
|
||||
out_dim=out_dim,
|
||||
num_res_blocks=num_res_blocks,
|
||||
dropout=dropout,
|
||||
upsample_mode=upsample_mode,
|
||||
non_linearity=non_linearity,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
|
||||
# Update scale for next iteration
|
||||
if upsample_mode is not None:
|
||||
scale *= 2.0
|
||||
|
||||
# output blocks
|
||||
self.norm_out = WanRMS_norm(out_dim, images=False)
|
||||
self.conv_out = WanCausalConv3d(out_dim, 3, 3, padding=1)
|
||||
self.conv_out = WanCausalConv3d(out_dim, out_channels, 3, padding=1)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
|
||||
## conv1
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
@@ -637,7 +893,7 @@ class WanDecoder3d(nn.Module):
|
||||
|
||||
## upsamples
|
||||
for up_block in self.up_blocks:
|
||||
x = up_block(x, feat_cache, feat_idx)
|
||||
x = up_block(x, feat_cache, feat_idx, first_chunk = first_chunk)
|
||||
|
||||
## head
|
||||
x = self.norm_out(x)
|
||||
@@ -656,6 +912,44 @@ class WanDecoder3d(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
# YiYi TODO: refactor this
|
||||
from einops import rearrange
|
||||
|
||||
def patchify(x, patch_size):
|
||||
if patch_size == 1:
|
||||
return x
|
||||
if x.dim() == 4:
|
||||
x = rearrange(
|
||||
x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
|
||||
elif x.dim() == 5:
|
||||
x = rearrange(
|
||||
x,
|
||||
"b c f (h q) (w r) -> b (c r q) f h w",
|
||||
q=patch_size,
|
||||
r=patch_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid input shape: {x.shape}")
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def unpatchify(x, patch_size):
|
||||
if patch_size == 1:
|
||||
return x
|
||||
|
||||
if x.dim() == 4:
|
||||
x = rearrange(
|
||||
x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
|
||||
elif x.dim() == 5:
|
||||
x = rearrange(
|
||||
x,
|
||||
"b (c r q) f h w -> b c f (h q) (w r)",
|
||||
q=patch_size,
|
||||
r=patch_size,
|
||||
)
|
||||
return x
|
||||
|
||||
class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
r"""
|
||||
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
|
||||
@@ -671,6 +965,7 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
def __init__(
|
||||
self,
|
||||
base_dim: int = 96,
|
||||
decoder_base_dim: Optional[int] = None,
|
||||
z_dim: int = 16,
|
||||
dim_mult: Tuple[int] = [1, 2, 4, 4],
|
||||
num_res_blocks: int = 2,
|
||||
@@ -713,6 +1008,10 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
2.8251,
|
||||
1.9160,
|
||||
],
|
||||
is_residual: bool = False,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
patch_size: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -720,14 +1019,17 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
self.temperal_downsample = temperal_downsample
|
||||
self.temperal_upsample = temperal_downsample[::-1]
|
||||
|
||||
if decoder_base_dim is None:
|
||||
decoder_base_dim = base_dim
|
||||
|
||||
self.encoder = WanEncoder3d(
|
||||
base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
|
||||
in_channels=in_channels, dim=base_dim, z_dim=z_dim * 2, dim_mult=dim_mult, num_res_blocks=num_res_blocks, attn_scales=attn_scales, temperal_downsample=temperal_downsample, dropout=dropout, is_residual=is_residual
|
||||
)
|
||||
self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||
self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1)
|
||||
|
||||
self.decoder = WanDecoder3d(
|
||||
base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
|
||||
dim=decoder_base_dim, z_dim=z_dim, dim_mult=dim_mult, num_res_blocks=num_res_blocks, attn_scales=attn_scales, temperal_upsample=self.temperal_upsample, dropout=dropout, out_channels=out_channels, is_residual=is_residual
|
||||
)
|
||||
|
||||
self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
|
||||
@@ -827,6 +1129,8 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
return self.tiled_encode(x)
|
||||
|
||||
self.clear_cache()
|
||||
if self.config.patch_size is not None:
|
||||
x = patchify(x, patch_size=self.config.patch_size)
|
||||
iter_ = 1 + (num_frame - 1) // 4
|
||||
for i in range(iter_):
|
||||
self._enc_conv_idx = [0]
|
||||
@@ -884,12 +1188,14 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
for i in range(num_frame):
|
||||
self._conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
|
||||
out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True)
|
||||
else:
|
||||
out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
|
||||
out = torch.clamp(out, min=-1.0, max=1.0)
|
||||
if self.config.patch_size is not None:
|
||||
out = unpatchify(out, patch_size=self.config.patch_size)
|
||||
self.clear_cache()
|
||||
if not return_dict:
|
||||
return (out,)
|
||||
|
||||
@@ -112,10 +112,21 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLWan`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
transformer_2 ([`WanTransformer3DModel`], *optional*):
|
||||
Conditional Transformer to denoise the input latents during the low-noise stage.
|
||||
If provided, enables two-stage denoising where `transformer` handles high-noise stages
|
||||
and `transformer_2` handles low-noise stages. If not provided, only `transformer` is used.
|
||||
boundary_ratio (`float`, *optional*, defaults to `None`):
|
||||
Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
|
||||
The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`.
|
||||
When provided, `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps < boundary_timestep.
|
||||
If `None`, only `transformer` is used for the entire denoising process.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
_optional_components = ["transformer_2"]
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -124,6 +135,8 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
transformer: WanTransformer3DModel,
|
||||
vae: AutoencoderKLWan,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
transformer_2: Optional[WanTransformer3DModel] = None,
|
||||
boundary_ratio: Optional[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -133,8 +146,9 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
transformer_2=transformer_2,
|
||||
)
|
||||
|
||||
self.register_to_config(boundary_ratio=boundary_ratio)
|
||||
self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
|
||||
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
@@ -270,6 +284,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
guidance_scale_2=None,
|
||||
):
|
||||
if height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
||||
@@ -301,6 +316,9 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
|
||||
):
|
||||
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
|
||||
|
||||
if self.config.boundary_ratio is None and guidance_scale_2 is not None:
|
||||
raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
@@ -369,6 +387,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
num_frames: int = 81,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 5.0,
|
||||
guidance_scale_2: Optional[float] = None,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
@@ -407,6 +426,9 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
||||
the text `prompt`, usually at the expense of lower image quality.
|
||||
guidance_scale_2 (`float`, *optional*, defaults to `None`):
|
||||
Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's `boundary_ratio` is not None,
|
||||
uses the same value as `guidance_scale`. Only used when `transformer_2` and the pipeline's `boundary_ratio` are not None.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
@@ -461,6 +483,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
guidance_scale_2,
|
||||
)
|
||||
|
||||
if num_frames % self.vae_scale_factor_temporal != 1:
|
||||
@@ -470,7 +493,11 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
|
||||
num_frames = max(num_frames, 1)
|
||||
|
||||
if self.config.boundary_ratio is not None and guidance_scale_2 is None:
|
||||
guidance_scale_2 = guidance_scale
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._guidance_scale_2 = guidance_scale_2
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
@@ -524,34 +551,49 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
if self.config.boundary_ratio is not None:
|
||||
boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
|
||||
else:
|
||||
boundary_timestep = None
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
|
||||
if boundary_timestep is None or t >= boundary_timestep:
|
||||
# wan2.1 or high-noise stage in wan2.2
|
||||
current_model = self.transformer
|
||||
current_guidance_scale = guidance_scale
|
||||
else:
|
||||
# low-noise stage in wan2.2
|
||||
current_model = self.transformer_2
|
||||
current_guidance_scale = guidance_scale_2
|
||||
|
||||
latent_model_input = latents.to(transformer_dtype)
|
||||
timestep = t.expand(latents.shape[0])
|
||||
|
||||
with self.transformer.cache_context("cond"):
|
||||
noise_pred = self.transformer(
|
||||
#with current_model.cache_context("cond"):
|
||||
noise_pred = current_model(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
#with current_model.cache_context("uncond"):
|
||||
noise_uncond = current_model(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
with self.transformer.cache_context("uncond"):
|
||||
noise_uncond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
|
||||
noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
Reference in New Issue
Block a user