diff --git a/docs/source/en/api/pipelines/hunyuan_video15.md b/docs/source/en/api/pipelines/hunyuan_video15.md
index 2d74088355..9a9bdcb352 100644
--- a/docs/source/en/api/pipelines/hunyuan_video15.md
+++ b/docs/source/en/api/pipelines/hunyuan_video15.md
@@ -12,13 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License. -->
-
# HunyuanVideo-1.5
@@ -59,6 +52,21 @@ video = pipeline(prompt=prompt, num_frames=61, num_inference_steps=30).frames[0]
export_to_video(video, "output.mp4", fps=15)
```
+
+## Notes
+
+- HunyuanVideo-1.5 uses attention masks for variable-length sequences. For best performance, we recommend an attention backend that handles padding efficiently:
+
+ - **H100/H800:** `_flash_3_hub` or `_flash_varlen_3`
+ - **A100/A800/RTX 4090:** `flash` or `flash_varlen`
+ - **Other GPUs:** `sage`
+
+Refer to the [Attention backends](../../optimization/attention_backends) guide for more details about using a different backend.
+
+```py
+pipeline.transformer.set_attention_backend("flash_varlen")  # or your preferred backend
+```
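+
+As a minimal sketch, you can also derive the backend from the detected GPU's compute capability. The mapping below is an assumption based on the recommendations above; adjust it for your hardware:
+
+```py
+import torch
+
+if torch.cuda.is_available():
+    major, _ = torch.cuda.get_device_capability()
+    if major >= 9:  # Hopper (H100/H800)
+        backend = "_flash_3_hub"
+    elif major >= 8:  # Ampere/Ada (A100/A800/RTX 4090)
+        backend = "flash_varlen"
+    else:
+        backend = "sage"
+    # `pipeline` is the HunyuanVideo15Pipeline loaded in the example above
+    pipeline.transformer.set_attention_backend(backend)
+```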
+
## HunyuanVideo15Pipeline
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py
index 7d6a636a24..4b1beb74a3 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py
@@ -215,8 +215,6 @@ class HunyuanVideo15Downsample(nn.Module):
def __init__(self, in_channels: int, out_channels: int, add_temporal_downsample: bool = True):
super().__init__()
factor = 2 * 2 * 2 if add_temporal_downsample else 1 * 2 * 2
- assert out_channels % factor == 0
- # self.conv = Conv3d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
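+ # The conv emits out_channels // factor channels; the downsampling step then folds each
+ # 2x2x2 (or 1x2x2 without temporal downsampling) patch into the channel dim, restoring out_channels.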
self.conv = HunyuanVideo15CausalConv3d(in_channels, out_channels // factor, kernel_size=3)
self.add_temporal_downsample = add_temporal_downsample
@@ -531,7 +529,6 @@ class HunyuanVideo15Encoder3D(nn.Module):
hidden_states = self.mid_block(hidden_states)
- # short_cut = rearrange(hidden_states, "b (c r) f h w -> b c r f h w", r=self.group_size).mean(dim=2)
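+ # Shortcut: average every group of group_size consecutive channels.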
batch_size, _, frame, height, width = hidden_states.shape
short_cut = hidden_states.view(batch_size, -1, self.group_size, frame, height, width).mean(dim=2)
@@ -546,7 +543,7 @@ class HunyuanVideo15Encoder3D(nn.Module):
class HunyuanVideo15Decoder3D(nn.Module):
r"""
- Causal decoder for 3D video-like data used for HunyuanImage-2.1 Refiner.
+ Causal decoder for 3D video-like data used for HunyuanVideo-1.5.
"""
def __init__(
diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video15.py b/src/diffusers/models/transformers/transformer_hunyuan_video15.py
index 8f191e7500..b870b15dad 100644
--- a/src/diffusers/models/transformers/transformer_hunyuan_video15.py
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video15.py
@@ -184,10 +184,7 @@ class HunyuanVideo15TimeEmbedding(nn.Module):
The dimension of the output embedding.
"""
- def __init__(
- self,
- embedding_dim: int,
- ):
+ def __init__(self, embedding_dim: int):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
@@ -362,7 +359,7 @@ class HunyuanVideo15RotaryPosEmbed(nn.Module):
rope_sizes = [num_frames // self.patch_size_t, height // self.patch_size, width // self.patch_size]
axes_grids = []
- for i in range(3):
+ for i in range(len(rope_sizes)):
# Note: The following line diverges from original behaviour. We create the grid on the device, whereas
# original implementation creates it on CPU and then moves it to device. This results in numerical
# differences in layerwise debugging outputs, but visually it is the same.
diff --git a/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py b/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py
index b0c7782352..82817365b6 100644
--- a/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py
+++ b/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py
@@ -34,7 +34,7 @@ def generate_crop_size_list(base_size=256, patch_size=16, max_ratio=4.0):
return crop_size_list
-# copied fromhttps://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/main/hyvideo/utils/data_utils.py#L38
+# copied from https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/main/hyvideo/utils/data_utils.py#L38
def get_closest_ratio(height: float, width: float, ratios: list, buckets: list):
"""
Get the closest ratio in the buckets.