From f1b52327306d5c521258b624dfbb07779bfb420b Mon Sep 17 00:00:00 2001
From: galbria
Date: Mon, 27 Oct 2025 15:54:29 +0000
Subject: [PATCH] Refactor BriaFibo classes and update pipeline parameters

- Updated BriaFiboAttnProcessor and BriaFiboAttention classes to reflect
  changes from Flux equivalents.
- Modified the _unpack_latents method in BriaFiboPipeline to improve clarity.
- Increased the default max_sequence_length to 3000 and added a new optional
  parameter do_patching.
- Cleaned up test_pipeline_bria_fibo.py by removing unused imports and
  skipping unsupported tests.
---
 .../transformers/transformer_bria_fibo.py     |  4 +-
 .../pipelines/bria_fibo/pipeline_bria_fibo.py |  5 +-
 .../bria_fibo/test_pipeline_bria_fibo.py      | 65 +------------------
 3 files changed, 8 insertions(+), 66 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_bria_fibo.py b/src/diffusers/models/transformers/transformer_bria_fibo.py
index e1bfde9555..68a0765536 100644
--- a/src/diffusers/models/transformers/transformer_bria_fibo.py
+++ b/src/diffusers/models/transformers/transformer_bria_fibo.py
@@ -66,8 +66,8 @@ def _get_qkv_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidde
     return _get_projections(attn, hidden_states, encoder_hidden_states)
 
 
+# Copied from diffusers.models.transformers.transformer_flux.FluxAttnProcessor FluxAttnProcessor->BriaFiboAttnProcessor, FluxAttention-> BriaFiboAttention
 class BriaFiboAttnProcessor:
-    # Copied from diffusers.models.transformers.transformer_flux.FluxAttnProcessor
     _attention_backend = None
     _parallel_config = None
 
@@ -134,8 +134,8 @@ class BriaFiboAttnProcessor:
         return hidden_states
 
 
+# Copied from diffusers.models.transformers.transformer_flux.FluxAttention -> BriaFiboAttention
 class BriaFiboAttention(torch.nn.Module, AttentionModuleMixin):
-    # Copied from diffusers.models.transformers.transformer_flux.FluxAttention
     _default_processor_cls = BriaFiboAttnProcessor
     _available_processors = [
         BriaFiboAttnProcessor,
diff --git a/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py
index ef86997155..8581939963 100644
--- a/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py
+++ b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py
@@ -353,8 +353,8 @@ class BriaFiboPipeline(DiffusionPipeline):
         return self._interrupt
 
     @staticmethod
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
     def _unpack_latents(latents, height, width, vae_scale_factor):
-        # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline
         batch_size, num_patches, channels = latents.shape
 
         height = height // vae_scale_factor
@@ -542,7 +542,8 @@ class BriaFiboPipeline(DiffusionPipeline):
                 The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeline class.
-            max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
+            max_sequence_length (`int` defaults to 3000): Maximum sequence length to use with the `prompt`.
+            do_patching (`bool`, *optional*, defaults to `False`): Whether to use patching.
         Examples:
         Returns:
             [`~pipelines.flux.BriaFiboPipelineOutput`] or `tuple`: [`~pipelines.flux.BriaFiboPipelineOutput`] if
diff --git a/tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py b/tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py
index 969634f597..76b41114f8 100644
--- a/tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py
+++ b/tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import tempfile
 import unittest
 
 import numpy as np
@@ -26,11 +25,10 @@ from diffusers import (
     FlowMatchEulerDiscreteScheduler,
 )
 from diffusers.models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel
-from tests.pipelines.test_pipelines_common import PipelineTesterMixin, to_np
+from tests.pipelines.test_pipelines_common import PipelineTesterMixin
 
 from ...testing_utils import (
     enable_full_determinism,
-    require_torch_accelerator,
     torch_device,
 )
 
@@ -45,6 +43,7 @@ class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     test_xformers_attention = False
     test_layerwise_casting = False
     test_group_offloading = False
+    supports_dduf = False
 
     def get_dummy_components(self):
         torch.manual_seed(0)
@@ -107,6 +106,7 @@ class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         }
         return inputs
 
+    @unittest.skip(reason="will not be supported due to dim-fusion")
     def test_encode_prompt_works_in_isolation(self):
         pass
 
@@ -137,62 +137,3 @@ class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         image = pipe(**inputs).images[0]
         output_height, output_width, _ = image.shape
         assert (output_height, output_width) == (expected_height, expected_width)
-
-    @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
-    @require_torch_accelerator
-    def test_save_load_float16(self, expected_max_diff=1e-2):
-        components = self.get_dummy_components()
-        for name, module in components.items():
-            if hasattr(module, "half"):
-                components[name] = module.to(torch_device).half()
-
-        pipe = self.pipeline_class(**components)
-        for component in pipe.components.values():
-            if hasattr(component, "set_default_attn_processor"):
-                component.set_default_attn_processor()
-        pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
-        inputs = self.get_dummy_inputs(torch_device)
-        output = pipe(**inputs)[0]
-
-        with tempfile.TemporaryDirectory() as tmpdir:
-            pipe.save_pretrained(tmpdir)
-            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16)
-            for component in pipe_loaded.components.values():
-                if hasattr(component, "set_default_attn_processor"):
-                    component.set_default_attn_processor()
-            pipe_loaded.to(torch_device)
-            pipe_loaded.set_progress_bar_config(disable=None)
-
-        for name, component in pipe_loaded.components.items():
-            if name == "vae":
-                continue
-            if hasattr(component, "dtype"):
-                self.assertTrue(
-                    component.dtype == torch.float16,
-                    f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.",
-                )
-
-        inputs = self.get_dummy_inputs(torch_device)
-        output_loaded = pipe_loaded(**inputs)[0]
-        max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
-        self.assertLess(
-            max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading."
-        )
-
-    # def test_to_dtype(self):
-    #     components = self.get_dummy_components()
-    #     pipe = self.pipeline_class(**components)
-    #     pipe.set_progress_bar_config(disable=None)
-
-    #     model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
-    #     self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
-
-    #     pipe.to(dtype=torch.float16)
-    #     model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
-    #     self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
-
-    @unittest.skip("")
-    def test_save_load_dduf(self):
-        pass
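
Note (not part of the patch): a minimal usage sketch of the updated pipeline parameters described above. It assumes BriaFiboPipeline is exported from the top-level diffusers namespace, as diffusers pipelines usually are; the checkpoint id, prompt, and device are illustrative placeholders, while max_sequence_length=3000 and do_patching=False mirror the new defaults introduced by this patch.

import torch
from diffusers import BriaFiboPipeline  # assumes the pipeline is exposed at the package top level

# Hypothetical checkpoint id, used here only for illustration.
pipe = BriaFiboPipeline.from_pretrained("briaai/FIBO", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe(
    prompt="a photo of an astronaut riding a horse",
    max_sequence_length=3000,  # new default introduced by this patch
    do_patching=False,         # new optional flag, defaults to False per the updated docstring
).images[0]
image.save("bria_fibo_sample.png")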