mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
* add vae * Initial commit for Flux 2 Transformer implementation * add pipeline part * small edits to the pipeline and conversion * update conversion script * fix * up up * finish pipeline * Remove Flux IP Adapter logic for now * Remove deprecated 3D id logic * Remove ControlNet logic for now * Add link to ViT-22B paper as reference for parallel transformer blocks such as the Flux 2 single stream block * update pipeline * Don't use biases for input projs and output AdaNorm * up * Remove bias for double stream block text QKV projections * Add script to convert Flux 2 transformer to diffusers * make style and make quality * fix a few things. * allow sft files to go. * fix image processor * fix batch * style a bit * Fix some bugs in Flux 2 transformer implementation * Fix dummy input preparation and fix some test bugs * fix dtype casting in timestep guidance module. * resolve conflicts., * remove ip adapter stuff. * Fix Flux 2 transformer consistency test * Fix bug in Flux2TransformerBlock (double stream block) * Get remaining Flux 2 transformer tests passing * make style; make quality; make fix-copies * remove stuff. * fix type annotaton. * remove unneeded stuff from tests * tests * up * up * add sf support * Remove unused IP Adapter and ControlNet logic from transformer (#9) * copied from * Apply suggestions from code review Co-authored-by: YiYi Xu <yixu310@gmail.com> Co-authored-by: apolinário <joaopaulo.passos@gmail.com> * up * up * up * up * up * Refactor Flux2Attention into separate classes for double stream and single stream attention * Add _supports_qkv_fusion to AttentionModuleMixin to allow subclasses to disable QKV fusion * Have Flux2ParallelSelfAttention inherit from AttentionModuleMixin with _supports_qkv_fusion=False * Log debug message when calling fuse_projections on a AttentionModuleMixin subclass that does not support QKV fusion * Address review comments * Update src/diffusers/pipelines/flux2/pipeline_flux2.py Co-authored-by: YiYi Xu <yixu310@gmail.com> * up * Remove maybe_allow_in_graph decorators for Flux 2 transformer blocks (#12) * up * support ostris loras. (#13) * up * update schdule * up * up (#17) * add training scripts (#16) * add training scripts Co-authored-by: Linoy Tsaban <linoytsaban@gmail.com> * model cpu offload in validation. * add flux.2 readme * add img2img and tests * cpu offload in log validation * Apply suggestions from code review * fix * up * fixes * remove i2i training tests for now. --------- Co-authored-by: Linoy Tsaban <linoytsaban@gmail.com> Co-authored-by: linoytsaban <linoy@huggingface.co> * up --------- Co-authored-by: yiyixuxu <yixu310@gmail.com> Co-authored-by: Daniel Gu <dgu8957@gmail.com> Co-authored-by: yiyi@huggingface.co <yiyi@ip-10-53-87-203.ec2.internal> Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com> Co-authored-by: apolinário <joaopaulo.passos@gmail.com> Co-authored-by: yiyi@huggingface.co <yiyi@ip-26-0-160-103.ec2.internal> Co-authored-by: Linoy Tsaban <linoytsaban@gmail.com> Co-authored-by: linoytsaban <linoy@huggingface.co>
163 lines
6.0 KiB
Python
163 lines
6.0 KiB
Python
# coding=utf-8
|
|
# Copyright 2025 HuggingFace Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import unittest
|
|
|
|
import torch
|
|
|
|
from diffusers import Flux2Transformer2DModel, attention_backend
|
|
|
|
from ...testing_utils import enable_full_determinism, torch_device
|
|
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
|
|
|
|
|
|
enable_full_determinism()
|
|
|
|
|
|
class Flux2TransformerTests(ModelTesterMixin, unittest.TestCase):
|
|
model_class = Flux2Transformer2DModel
|
|
main_input_name = "hidden_states"
|
|
# We override the items here because the transformer under consideration is small.
|
|
model_split_percents = [0.7, 0.6, 0.6]
|
|
|
|
# Skip setting testing with default: AttnProcessor
|
|
uses_custom_attn_processor = True
|
|
|
|
@property
|
|
def dummy_input(self):
|
|
return self.prepare_dummy_input()
|
|
|
|
@property
|
|
def input_shape(self):
|
|
return (16, 4)
|
|
|
|
@property
|
|
def output_shape(self):
|
|
return (16, 4)
|
|
|
|
def prepare_dummy_input(self, height=4, width=4):
|
|
batch_size = 1
|
|
num_latent_channels = 4
|
|
sequence_length = 48
|
|
embedding_dim = 32
|
|
|
|
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
|
|
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
|
|
|
t_coords = torch.arange(1)
|
|
h_coords = torch.arange(height)
|
|
w_coords = torch.arange(width)
|
|
l_coords = torch.arange(1)
|
|
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords) # [height * width, 4]
|
|
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
|
|
|
text_t_coords = torch.arange(1)
|
|
text_h_coords = torch.arange(1)
|
|
text_w_coords = torch.arange(1)
|
|
text_l_coords = torch.arange(sequence_length)
|
|
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
|
|
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
|
|
|
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
|
guidance = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
|
|
|
return {
|
|
"hidden_states": hidden_states,
|
|
"encoder_hidden_states": encoder_hidden_states,
|
|
"img_ids": image_ids,
|
|
"txt_ids": text_ids,
|
|
"timestep": timestep,
|
|
"guidance": guidance,
|
|
}
|
|
|
|
def prepare_init_args_and_inputs_for_common(self):
|
|
init_dict = {
|
|
"patch_size": 1,
|
|
"in_channels": 4,
|
|
"num_layers": 1,
|
|
"num_single_layers": 1,
|
|
"attention_head_dim": 16,
|
|
"num_attention_heads": 2,
|
|
"joint_attention_dim": 32,
|
|
"timestep_guidance_channels": 256, # Hardcoded in original code
|
|
"axes_dims_rope": [4, 4, 4, 4],
|
|
}
|
|
|
|
inputs_dict = self.dummy_input
|
|
return init_dict, inputs_dict
|
|
|
|
# TODO (Daniel, Sayak): We can remove this test.
|
|
def test_flux2_consistency(self, seed=0):
|
|
torch.manual_seed(seed)
|
|
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
|
|
|
torch.manual_seed(seed)
|
|
model = self.model_class(**init_dict)
|
|
# state_dict = model.state_dict()
|
|
# for key, param in state_dict.items():
|
|
# print(f"{key} | {param.shape}")
|
|
# torch.save(state_dict, "/raid/daniel_gu/test_flux2_params/diffusers.pt")
|
|
model.to(torch_device)
|
|
model.eval()
|
|
|
|
with attention_backend("native"):
|
|
with torch.no_grad():
|
|
output = model(**inputs_dict)
|
|
|
|
if isinstance(output, dict):
|
|
output = output.to_tuple()[0]
|
|
|
|
self.assertIsNotNone(output)
|
|
|
|
# input & output have to have the same shape
|
|
input_tensor = inputs_dict[self.main_input_name]
|
|
expected_shape = input_tensor.shape
|
|
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
|
|
|
# Check against expected slice
|
|
# fmt: off
|
|
expected_slice = torch.tensor([-0.3662, 0.4844, 0.6334, -0.3497, 0.2162, 0.0188, 0.0521, -0.2061, -0.2041, -0.0342, -0.7107, 0.4797, -0.3280, 0.7059, -0.0849, 0.4416])
|
|
# fmt: on
|
|
|
|
flat_output = output.cpu().flatten()
|
|
generated_slice = torch.cat([flat_output[:8], flat_output[-8:]])
|
|
self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-4))
|
|
|
|
def test_gradient_checkpointing_is_applied(self):
|
|
expected_set = {"Flux2Transformer2DModel"}
|
|
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
|
|
|
|
|
class Flux2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
|
model_class = Flux2Transformer2DModel
|
|
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
|
|
|
|
def prepare_init_args_and_inputs_for_common(self):
|
|
return Flux2TransformerTests().prepare_init_args_and_inputs_for_common()
|
|
|
|
def prepare_dummy_input(self, height, width):
|
|
return Flux2TransformerTests().prepare_dummy_input(height=height, width=width)
|
|
|
|
|
|
class Flux2TransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
|
|
model_class = Flux2Transformer2DModel
|
|
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
|
|
|
|
def prepare_init_args_and_inputs_for_common(self):
|
|
return Flux2TransformerTests().prepare_init_args_and_inputs_for_common()
|
|
|
|
def prepare_dummy_input(self, height, width):
|
|
return Flux2TransformerTests().prepare_dummy_input(height=height, width=width)
|