mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Merge branch 'hunyuanvideo15' into hunyuanvideo15-tests

This commit is contained in:
Sayak Paul
2025-12-01 06:09:48 +05:30
committed by GitHub
4 changed files with 19 additions and 17 deletions


@@ -12,13 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License. -->
<div style="float: right;">
<div class="flex flex-wrap space-x-1">
<a href="https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference" target="_blank" rel="noopener">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</a>
</div>
</div>
# HunyuanVideo-1.5
@@ -59,6 +52,21 @@ video = pipeline(prompt=prompt, num_frames=61, num_inference_steps=30).frames[0]
export_to_video(video, "output.mp4", fps=15)
```
## Notes
- HunyuanVideo-1.5 uses attention masks with variable-length sequences. For best performance, we recommend using an attention backend that handles padding efficiently:
  - **H100/H800:** `_flash_3_hub` or `_flash_varlen_3`
  - **A100/A800/RTX 4090:** `flash` or `flash_varlen`
  - **Other GPUs:** `sage`
Refer to the [Attention backends](../../optimization/attention_backends) guide for more details about using a different backend.
```py
pipeline.transformer.set_attention_backend("flash_varlen")  # or your preferred backend
```
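If you want to pick a backend programmatically rather than hardcoding it, a minimal sketch along these lines works; the helper name and the compute-capability thresholds are assumptions, not part of this commit:

```py
import torch

# Hypothetical helper (not in this PR): choose a backend from the
# table above using the CUDA compute capability of the current device.
def pick_attention_backend() -> str:
    major, _ = torch.cuda.get_device_capability()
    if major >= 9:   # Hopper: H100/H800
        return "_flash_3_hub"
    if major >= 8:   # Ampere/Ada: A100/A800/RTX 4090
        return "flash_varlen"
    return "sage"    # other GPUs

pipeline.transformer.set_attention_backend(pick_attention_backend())
```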
## HunyuanVideo15Pipeline


@@ -215,8 +215,6 @@ class HunyuanVideo15Downsample(nn.Module):
def __init__(self, in_channels: int, out_channels: int, add_temporal_downsample: bool = True):
super().__init__()
factor = 2 * 2 * 2 if add_temporal_downsample else 1 * 2 * 2
assert out_channels % factor == 0
# self.conv = Conv3d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
self.conv = HunyuanVideo15CausalConv3d(in_channels, out_channels // factor, kernel_size=3)
self.add_temporal_downsample = add_temporal_downsample
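For context, the `factor` bookkeeping reflects a space(-time)-to-depth step: the conv emits `out_channels // factor` channels, and folding each 2×2×2 (or 1×2×2) block into the channel dim restores `out_channels`. A rough sketch of that folding, assuming this particular channel ordering (the actual module may order the blocks differently):

```py
import torch

# Fold rt x rs x rs neighborhoods into channels, shrinking f/h/w.
def patchify_downsample(x: torch.Tensor, rt: int = 2, rs: int = 2) -> torch.Tensor:
    b, c, f, h, w = x.shape
    x = x.view(b, c, f // rt, rt, h // rs, rs, w // rs, rs)
    x = x.permute(0, 1, 3, 5, 7, 2, 4, 6)  # move block dims next to channels
    return x.reshape(b, c * rt * rs * rs, f // rt, h // rs, w // rs)

x = torch.randn(1, 4, 8, 32, 32)
print(patchify_downsample(x).shape)  # torch.Size([1, 32, 4, 16, 16])
```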
@@ -531,7 +529,6 @@ class HunyuanVideo15Encoder3D(nn.Module):
hidden_states = self.mid_block(hidden_states)
# short_cut = rearrange(hidden_states, "b (c r) f h w -> b c r f h w", r=self.group_size).mean(dim=2)
batch_size, _, frame, height, width = hidden_states.shape
short_cut = hidden_states.view(batch_size, -1, self.group_size, frame, height, width).mean(dim=2)
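The replacement drops the einops dependency; a quick check that `view` + `mean` matches the commented-out `rearrange`, assuming `group_size` divides the channel dim:

```py
import torch
from einops import rearrange

g = 4  # stands in for self.group_size
x = torch.randn(2, 8 * g, 3, 4, 4)  # (b, c*r, f, h, w)

ref = rearrange(x, "b (c r) f h w -> b c r f h w", r=g).mean(dim=2)
new = x.view(2, -1, g, 3, 4, 4).mean(dim=2)
print(torch.allclose(ref, new))  # True
```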
@@ -546,7 +543,7 @@ class HunyuanVideo15Encoder3D(nn.Module):
class HunyuanVideo15Decoder3D(nn.Module):
r"""
Causal decoder for 3D video-like data used for HunyuanImage-2.1 Refiner.
Causal decoder for 3D video-like data used for HunyuanImage-1.5 Refiner.
"""
def __init__(


@@ -184,10 +184,7 @@ class HunyuanVideo15TimeEmbedding(nn.Module):
The dimension of the output embedding.
"""
def __init__(
self,
embedding_dim: int,
):
def __init__(self, embedding_dim: int):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
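For reference, `Timesteps` here is diffusers' sinusoidal timestep projection; a minimal sketch of the computation it performs (that this matches `get_timestep_embedding` with these exact flags is an assumption worth checking):

```py
import torch

# Sinusoidal embedding with flip_sin_to_cos=True (cos half first)
# and downscale_freq_shift=0, per the constructor call above.
def sinusoidal_embedding(t: torch.Tensor, dim: int = 256) -> torch.Tensor:
    half = dim // 2
    freqs = torch.exp(-torch.log(torch.tensor(10000.0)) * torch.arange(half) / half)
    args = t[:, None].float() * freqs[None, :]
    return torch.cat([args.cos(), args.sin()], dim=-1)

print(sinusoidal_embedding(torch.tensor([0, 10, 999])).shape)  # torch.Size([3, 256])
```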
@@ -362,7 +359,7 @@ class HunyuanVideo15RotaryPosEmbed(nn.Module):
rope_sizes = [num_frames // self.patch_size_t, height // self.patch_size, width // self.patch_size]
axes_grids = []
for i in range(3):
for i in range(len(rope_sizes)):
# Note: The following line diverges from original behaviour. We create the grid on the device, whereas
# original implementation creates it on CPU and then moves it to device. This results in numerical
# differences in layerwise debugging outputs, but visually it is the same.
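Illustrating the note above (variable names here are assumptions): the integer range itself is exact either way; the divergence the comment mentions comes from later floating-point ops executing on different devices.

```py
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Original implementation: build the grid on CPU, then move it.
grid_cpu = torch.arange(16, dtype=torch.float32).to(device)
# This port: build it directly on the target device.
grid_dev = torch.arange(16, dtype=torch.float32, device=device)

# Identical values here; small differences appear only in downstream
# floating-point ops (e.g. the frequency products) run per-device.
print(torch.equal(grid_cpu, grid_dev))  # True
```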


@@ -34,7 +34,7 @@ def generate_crop_size_list(base_size=256, patch_size=16, max_ratio=4.0):
return crop_size_list
# copied fromhttps://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/main/hyvideo/utils/data_utils.py#L38
# copied from https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/main/hyvideo/utils/data_utils.py#L38
def get_closest_ratio(height: float, width: float, ratios: list, buckets: list):
"""
Get the closest ratio in the buckets.
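Since the hunk cuts the body off, here is a hypothetical sketch of what a closest-ratio lookup like this does; the real implementation lives at the linked URL and may differ in details:

```py
# Hypothetical sketch: pick the bucket whose aspect ratio is closest
# to the input's height/width ratio.
def get_closest_ratio(height: float, width: float, ratios: list, buckets: list):
    aspect = height / width
    idx = min(range(len(ratios)), key=lambda i: abs(ratios[i] - aspect))
    return buckets[idx], ratios[idx]

buckets = [(512, 512), (768, 512), (512, 768)]
ratios = [h / w for h, w in buckets]
print(get_closest_ratio(720.0, 480.0, ratios, buckets))  # ((768, 512), 1.5)
```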