mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Merge branch 'hunyuanvideo15' into hunyuanvideo15-tests
This commit is contained in:
@@ -12,13 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<a href="https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference" target="_blank" rel="noopener">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
# HunyuanVideo-1.5
|
||||
|
||||
@@ -59,6 +52,21 @@ video = pipeline(prompt=prompt, num_frames=61, num_inference_steps=30).frames[0]
|
||||
export_to_video(video, "output.mp4", fps=15)
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- HunyuanVideo1.5 use attention masks with avariable-length sequences. For best performance, we recommend using an attention backend that handles padding efficiently.
|
||||
|
||||
- **H100/H800:** `_flash_3_hub` or `_flash_varlen_3`
|
||||
- **A100/A800/RTX 4090:** `flash` or `flash_varlen`
|
||||
- **Other GPUs:** `sage`
|
||||
|
||||
Refer to the [Attention backends](../../optimization/attention_backends) guide for more details about using a different backend.
|
||||
|
||||
|
||||
```py
|
||||
pipe.transformer.set_attention_backend("flash_varlen") # or your preferred backend
|
||||
```
|
||||
|
||||
|
||||
## HunyuanVideo15Pipeline
|
||||
|
||||
|
||||
@@ -215,8 +215,6 @@ class HunyuanVideo15Downsample(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, add_temporal_downsample: bool = True):
|
||||
super().__init__()
|
||||
factor = 2 * 2 * 2 if add_temporal_downsample else 1 * 2 * 2
|
||||
assert out_channels % factor == 0
|
||||
# self.conv = Conv3d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
|
||||
self.conv = HunyuanVideo15CausalConv3d(in_channels, out_channels // factor, kernel_size=3)
|
||||
|
||||
self.add_temporal_downsample = add_temporal_downsample
|
||||
@@ -531,7 +529,6 @@ class HunyuanVideo15Encoder3D(nn.Module):
|
||||
|
||||
hidden_states = self.mid_block(hidden_states)
|
||||
|
||||
# short_cut = rearrange(hidden_states, "b (c r) f h w -> b c r f h w", r=self.group_size).mean(dim=2)
|
||||
batch_size, _, frame, height, width = hidden_states.shape
|
||||
short_cut = hidden_states.view(batch_size, -1, self.group_size, frame, height, width).mean(dim=2)
|
||||
|
||||
@@ -546,7 +543,7 @@ class HunyuanVideo15Encoder3D(nn.Module):
|
||||
|
||||
class HunyuanVideo15Decoder3D(nn.Module):
|
||||
r"""
|
||||
Causal decoder for 3D video-like data used for HunyuanImage-2.1 Refiner.
|
||||
Causal decoder for 3D video-like data used for HunyuanImage-1.5 Refiner.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -184,10 +184,7 @@ class HunyuanVideo15TimeEmbedding(nn.Module):
|
||||
The dimension of the output embedding.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
):
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
@@ -362,7 +359,7 @@ class HunyuanVideo15RotaryPosEmbed(nn.Module):
|
||||
rope_sizes = [num_frames // self.patch_size_t, height // self.patch_size, width // self.patch_size]
|
||||
|
||||
axes_grids = []
|
||||
for i in range(3):
|
||||
for i in range(len(rope_sizes)):
|
||||
# Note: The following line diverges from original behaviour. We create the grid on the device, whereas
|
||||
# original implementation creates it on CPU and then moves it to device. This results in numerical
|
||||
# differences in layerwise debugging outputs, but visually it is the same.
|
||||
|
||||
@@ -34,7 +34,7 @@ def generate_crop_size_list(base_size=256, patch_size=16, max_ratio=4.0):
|
||||
return crop_size_list
|
||||
|
||||
|
||||
# copied fromhttps://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/main/hyvideo/utils/data_utils.py#L38
|
||||
# copied from https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/main/hyvideo/utils/data_utils.py#L38
|
||||
def get_closest_ratio(height: float, width: float, ratios: list, buckets: list):
|
||||
"""
|
||||
Get the closest ratio in the buckets.
|
||||
|
||||
Reference in New Issue
Block a user