From 491a933a1bf79d1f9cd3bc5903fc609ae6d6a9ac Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Thu, 8 Feb 2024 12:30:14 +0530
Subject: [PATCH] [I2VGenXL] `attention_head_dim` in the UNet (#6872)

* attention_head_dim

* debug

* print more info

* correct num_attention_heads behaviour

* down_block_num_attention_heads -> num_attention_heads.

* correct the image link in doc.

* add: deprecation for num_attention_head

* fix: test argument to use attention_head_dim

* more fixes.

* quality

* address comments.

* remove deprecation.
---
 src/diffusers/models/attention.py             |  1 +
 src/diffusers/models/unets/unet_i2vgen_xl.py  | 12 +++++++++++-
 .../pipelines/i2vgen_xl/pipeline_i2vgen_xl.py |  2 +-
 tests/pipelines/i2vgen_xl/test_i2vgenxl.py    |  3 ++-
 4 files changed, 15 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index d4d611250a..f9d83afbd2 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -158,6 +158,7 @@ class BasicTransformerBlock(nn.Module):
         super().__init__()
         self.only_cross_attention = only_cross_attention
 
+        # We keep these boolean flags for backward-compatibility.
         self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
         self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
         self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
diff --git a/src/diffusers/models/unets/unet_i2vgen_xl.py b/src/diffusers/models/unets/unet_i2vgen_xl.py
index de4acb7e0d..eb8c0b50a6 100644
--- a/src/diffusers/models/unets/unet_i2vgen_xl.py
+++ b/src/diffusers/models/unets/unet_i2vgen_xl.py
@@ -120,6 +120,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
             If `None`, normalization and activation layers is skipped in post-processing.
         cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
+        attention_head_dim (`int`, *optional*, defaults to 64): Attention head dim.
         num_attention_heads (`int`, *optional*): The number of attention heads.
     """
 
@@ -147,10 +148,19 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         layers_per_block: int = 2,
         norm_num_groups: Optional[int] = 32,
         cross_attention_dim: int = 1024,
-        num_attention_heads: Optional[Union[int, Tuple[int]]] = 64,
+        attention_head_dim: Union[int, Tuple[int]] = 64,
+        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
     ):
         super().__init__()
 
+        # When we first integrated the UNet into the library, we didn't have `attention_head_dim`. As a consequence
+        # of that, we used `num_attention_heads` for arguments that actually denote attention head dimension. This
+        # is why we ignore `num_attention_heads` and calculate it from `attention_head_dim` below.
+        # This is still an incorrect way of calculating `num_attention_heads` but we need to stick to it
+        # without running proper deprecation cycles for the {down,mid,up} blocks which are a
+        # part of the public API.
+        num_attention_heads = attention_head_dim
+
         # Check inputs
         if len(down_block_types) != len(up_block_types):
             raise ValueError(
diff --git a/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py b/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py
index 5988957cb1..4f6ce85aaa 100644
--- a/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py
+++ b/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py
@@ -46,7 +46,7 @@ EXAMPLE_DOC_STRING = """
         >>> pipeline = I2VGenXLPipeline.from_pretrained("ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16")
         >>> pipeline.enable_model_cpu_offload()
 
-        >>> image_url = "https://github.com/ali-vilab/i2vgen-xl/blob/main/data/test_images/img_0009.png?raw=true"
+        >>> image_url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/i2vgen_xl_images/img_0009.png"
         >>> image = load_image(image_url).convert("RGB")
 
         >>> prompt = "Papers were floating in the air on a table in the library"
diff --git a/tests/pipelines/i2vgen_xl/test_i2vgenxl.py b/tests/pipelines/i2vgen_xl/test_i2vgenxl.py
index acd9f9140d..de8e2e3310 100644
--- a/tests/pipelines/i2vgen_xl/test_i2vgenxl.py
+++ b/tests/pipelines/i2vgen_xl/test_i2vgenxl.py
@@ -80,7 +80,8 @@ class I2VGenXLPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             down_block_types=("CrossAttnDownBlock3D", "DownBlock3D"),
             up_block_types=("UpBlock3D", "CrossAttnUpBlock3D"),
             cross_attention_dim=4,
-            num_attention_heads=4,
+            attention_head_dim=4,
+            num_attention_heads=None,
             norm_num_groups=2,
         )
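
Usage sketch of the new keyword, relying only on names that appear in this patch (`I2VGenXLUNet`, `attention_head_dim`, `num_attention_heads`); it builds the UNet with its full-sized default configuration, so treat it as an illustration of the argument change rather than a fast test.

from diffusers.models.unets.unet_i2vgen_xl import I2VGenXLUNet

# After this patch, the per-head dimension is passed as `attention_head_dim`.
# `num_attention_heads` is still accepted (defaulting to None) but is ignored:
# the constructor internally sets num_attention_heads = attention_head_dim.
unet = I2VGenXLUNet(attention_head_dim=64)

# The argument is registered to the model config, so it round-trips through save/load.
print(unet.config.attention_head_dim)  # 64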