commit 5efb5efbf2 (parent 5f9a215178)
Author: Sayak Paul
Date:   2025-12-01 01:39:05 +00:00

@@ -15,13 +15,13 @@
 import unittest
 
 import torch
-from transformers import CLIPTokenizer, T5Config, T5EncoderModel
+from transformers import T5EncoderModel, Qwen2_5_VLTextConfig, Qwen2_5_VLTextModel, Qwen2Tokenizer, ByT5Tokenizer
 
 from diffusers import (
     AutoencoderKLHunyuanVideo15,
     FlowMatchEulerDiscreteScheduler,
     HunyuanVideo15Pipeline,
-    HunyuanVideo15Transformer3DModel,
+    HunyuanVideo15Transformer3DModel
 )
 from diffusers.guiders import ClassifierFreeGuidance
@@ -41,7 +41,6 @@ class HunyuanVideo15PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             "negative_prompt",
             "height",
             "width",
-            "num_frames",
             "prompt_embeds",
             "prompt_embeds_mask",
             "negative_prompt_embeds",
@@ -50,11 +49,19 @@ class HunyuanVideo15PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             "prompt_embeds_mask_2",
             "negative_prompt_embeds_2",
             "negative_prompt_embeds_mask_2",
-            "output_type",
         ]
     )
-    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
-    callback_cfg_params = frozenset()
+    batch_params = ["prompt", "negative_prompt"]
+    required_optional_params = frozenset(
+        [
+            "num_inference_steps",
+            "generator",
+            "latents",
+            "return_dict",
+            "callback_on_step_end",
+            "callback_on_step_end_tensor_inputs",
+        ]
+    )
     test_attention_slicing = False
     test_xformers_attention = False
     test_layerwise_casting = True
@@ -64,7 +71,7 @@ class HunyuanVideo15PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     def get_dummy_components(self, num_layers: int = 1):
         torch.manual_seed(0)
         transformer = HunyuanVideo15Transformer3DModel(
-            in_channels=4,
+            in_channels=9,
             out_channels=4,
             num_attention_heads=2,
             attention_head_dim=8,
@@ -74,9 +81,9 @@ class HunyuanVideo15PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             patch_size=1,
             patch_size_t=1,
             text_embed_dim=16,
-            text_embed_2_dim=8,
+            text_embed_2_dim=32,
             image_embed_dim=12,
-            rope_axes_dim=(2, 4, 4),
+            rope_axes_dim=(2, 2, 4),
             target_size=16,
             task_type="t2v",
         )
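
A note on the transformer hunk above: in HunyuanVideo-style DiTs the per-axis rotary dims tile the attention head dim, so the move from rope_axes_dim=(2, 4, 4) (sum 10) to (2, 2, 4) (sum 8) presumably lines the RoPE axes up with attention_head_dim=8, just as text_embed_2_dim=32 presumably matches the hidden size of the tiny-random-t5 secondary encoder introduced below. A minimal sanity check; the temporal/height/width reading of the axes is an assumption, not stated in the commit:

    # rope axes must cover the full per-head rotary dim
    attention_head_dim = 8
    rope_axes_dim = (2, 2, 4)  # assumed order: (time, height, width)
    assert sum(rope_axes_dim) == attention_head_dim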
@@ -86,62 +93,49 @@ class HunyuanVideo15PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             in_channels=3,
             out_channels=3,
             latent_channels=4,
-            block_out_channels=(16, 16, 16, 16),
+            block_out_channels=(16, 16),
             layers_per_block=1,
             spatial_compression_ratio=4,
             temporal_compression_ratio=2,
             scaling_factor=0.476986,
+            downsample_match_channel=False,
+            upsample_match_channel=False,
         )
 
         torch.manual_seed(0)
         scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
 
         torch.manual_seed(0)
-        main_text_config = T5Config(
-            d_model=16,
-            d_kv=4,
-            d_ff=64,
-            num_layers=2,
-            num_heads=4,
-            relative_attention_num_buckets=8,
-            relative_attention_max_distance=32,
-            vocab_size=64,
-            feed_forward_proj="gated-gelu",
-            dense_act_fn="gelu_new",
-            is_encoder_decoder=False,
-            use_cache=False,
-            tie_word_embeddings=False,
+        qwen_config = Qwen2_5_VLTextConfig(
+            **{
+                "hidden_size": 16,
+                "intermediate_size": 16,
+                "num_hidden_layers": 2,
+                "num_attention_heads": 2,
+                "num_key_value_heads": 2,
+                "rope_scaling": {
+                    "mrope_section": [1, 1, 2],
+                    "rope_type": "default",
+                    "type": "default",
+                },
+                "rope_theta": 1000000.0,
+            }
         )
-        text_encoder = T5EncoderModel(main_text_config)
-        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+        text_encoder = Qwen2_5_VLTextModel(qwen_config)
+        tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
 
         torch.manual_seed(0)
-        secondary_text_config = T5Config(
-            d_model=8,
-            d_kv=4,
-            d_ff=32,
-            num_layers=2,
-            num_heads=2,
-            relative_attention_num_buckets=8,
-            relative_attention_max_distance=32,
-            vocab_size=32,
-            feed_forward_proj="gated-gelu",
-            dense_act_fn="gelu_new",
-            is_encoder_decoder=False,
-            use_cache=False,
-            tie_word_embeddings=False,
-        )
-        text_encoder_2 = T5EncoderModel(secondary_text_config)
-        tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+        text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+        tokenizer_2 = ByT5Tokenizer()
 
         guider = ClassifierFreeGuidance(guidance_scale=1.0)
 
         components = {
-            "transformer": transformer,
-            "vae": vae,
+            "transformer": transformer.eval(),
+            "vae": vae.eval(),
             "scheduler": scheduler,
-            "text_encoder": text_encoder,
-            "text_encoder_2": text_encoder_2,
+            "text_encoder": text_encoder.eval(),
+            "text_encoder_2": text_encoder_2.eval(),
             "tokenizer": tokenizer,
             "tokenizer_2": tokenizer_2,
             "guider": guider,
@@ -154,29 +148,14 @@ class HunyuanVideo15PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         else:
             generator = torch.Generator(device=device).manual_seed(seed)
         torch.manual_seed(seed)
 
-        batch_size = 1
-        seq_len = 4
-        seq_len_2 = 3
-        text_embed_dim = 16
-        text_embed_2_dim = 8
-
-        prompt_embeds = torch.randn((batch_size, seq_len, text_embed_dim), device=device)
-        prompt_embeds_mask = torch.ones((batch_size, seq_len), device=device)
-        prompt_embeds_2 = torch.randn((batch_size, seq_len_2, text_embed_2_dim), device=device)
-        prompt_embeds_mask_2 = torch.ones((batch_size, seq_len_2), device=device)
-
         inputs = {
             "prompt": "monkey",
             "generator": generator,
             "num_inference_steps": 2,
-            "num_frames": 5,
             "height": 16,
             "width": 16,
+            "num_frames": 9,
             "output_type": "pt",
-            "prompt_embeds": prompt_embeds,
-            "prompt_embeds_mask": prompt_embeds_mask,
-            "prompt_embeds_2": prompt_embeds_2,
-            "prompt_embeds_mask_2": prompt_embeds_mask_2,
         }
         return inputs
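
A note on the inputs above: dropping the precomputed prompt_embeds_* tensors routes the test through the pipeline's real prompt-encoding path (using the dummy encoders), and num_frames moves from 5 to 9. Both 9 frames and the 16x16 resolution divide cleanly under the dummy VAE's strides; a back-of-envelope check, assuming the usual causal video-VAE frame formula:

    num_frames, height, width = 9, 16, 16
    stride_t, stride_s = 2, 4  # temporal_compression_ratio, spatial_compression_ratio
    latent_frames = (num_frames - 1) // stride_t + 1  # -> 5
    latent_height, latent_width = height // stride_s, width // stride_s  # -> 4, 4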
@@ -193,5 +172,20 @@ class HunyuanVideo15PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         video = result.frames
         generated_video = video[0]
 
-        self.assertEqual(generated_video.shape, (inputs["num_frames"], 3, inputs["height"], inputs["width"]))
         self.assertFalse(torch.isnan(generated_video).any())
+        self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+        generated_slice = generated_video.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+
+        # fmt: off
+        expected_slice = torch.tensor([0.4296, 0.5549, 0.3088, 0.9115, 0.5049, 0.7926, 0.5549, 0.8618, 0.5091, 0.5075, 0.7117, 0.5292, 0.7053, 0.4864, 0.5206, 0.3878])
+        # fmt: on
+
+        self.assertTrue(
+            torch.abs(generated_slice - expected_slice).max() < 1e-3,
+            f"output_slice: {generated_slice}, expected_slice: {expected_slice}",
+        )
+
+    @unittest.skip("TODO: Test not supported for now because needs to be adjusted to work with guiders.")
+    def test_encode_prompt_works_in_isolation(self):
+        pass
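
For anyone re-baselining this test: a slice like expected_slice above is typically captured by running the dummy pipeline once and printing the first and last eight flattened output values; here pipe and inputs stand for the pipeline and dummy inputs this test class builds:

    video = pipe(**inputs).frames[0]
    flat = video.flatten()
    print([round(v, 4) for v in torch.cat([flat[:8], flat[-8:]]).tolist()])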