1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Refactor BriaFibo classes and update pipeline parameters

- Updated BriaFiboAttnProcessor and BriaFiboAttention classes to reflect changes from Flux equivalents.
- Modified the _unpack_latents method in BriaFiboPipeline to improve clarity.
- Increased the default max_sequence_length to 3000 and added a new optional parameter do_patching.
- Cleaned up test_pipeline_bria_fibo.py by removing unused imports and skipping unsupported tests.
This commit is contained in:
galbria
2025-10-27 15:54:29 +00:00
parent a617433ace
commit f1b5232730
3 changed files with 8 additions and 66 deletions

View File

@@ -66,8 +66,8 @@ def _get_qkv_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidde
return _get_projections(attn, hidden_states, encoder_hidden_states)
# Copied from diffusers.models.transformers.transformer_flux.FluxAttnProcessor FluxAttnProcessor->BriaFiboAttnProcessor, FluxAttention-> BriaFiboAttention
class BriaFiboAttnProcessor:
# Copied from diffusers.models.transformers.transformer_flux.FluxAttnProcessor
_attention_backend = None
_parallel_config = None
@@ -134,8 +134,8 @@ class BriaFiboAttnProcessor:
return hidden_states
# Copied from diffusers.models.transformers.transformer_flux.FluxAttention -> BriaFiboAttention
class BriaFiboAttention(torch.nn.Module, AttentionModuleMixin):
# Copied from diffusers.models.transformers.transformer_flux.FluxAttention
_default_processor_cls = BriaFiboAttnProcessor
_available_processors = [
BriaFiboAttnProcessor,

View File

@@ -353,8 +353,8 @@ class BriaFiboPipeline(DiffusionPipeline):
return self._interrupt
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
def _unpack_latents(latents, height, width, vae_scale_factor):
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline
batch_size, num_patches, channels = latents.shape
height = height // vae_scale_factor
@@ -542,7 +542,8 @@ class BriaFiboPipeline(DiffusionPipeline):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
max_sequence_length (`int` defaults to 3000): Maximum sequence length to use with the `prompt`.
do_patching (`bool`, *optional*, defaults to `False`): Whether to use patching.
Examples:
Returns:
[`~pipelines.flux.BriaFiboPipelineOutput`] or `tuple`: [`~pipelines.flux.BriaFiboPipelineOutput`] if

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import numpy as np
@@ -26,11 +25,10 @@ from diffusers import (
FlowMatchEulerDiscreteScheduler,
)
from diffusers.models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel
from tests.pipelines.test_pipelines_common import PipelineTesterMixin, to_np
from tests.pipelines.test_pipelines_common import PipelineTesterMixin
from ...testing_utils import (
enable_full_determinism,
require_torch_accelerator,
torch_device,
)
@@ -45,6 +43,7 @@ class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
test_xformers_attention = False
test_layerwise_casting = False
test_group_offloading = False
supports_dduf = False
def get_dummy_components(self):
torch.manual_seed(0)
@@ -107,6 +106,7 @@ class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
}
return inputs
@unittest.skip(reason="will not be supported due to dim-fusion")
def test_encode_prompt_works_in_isolation(self):
pass
@@ -137,62 +137,3 @@ class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
image = pipe(**inputs).images[0]
output_height, output_width, _ = image.shape
assert (output_height, output_width) == (expected_height, expected_width)
@unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
@require_torch_accelerator
def test_save_load_float16(self, expected_max_diff=1e-2):
components = self.get_dummy_components()
for name, module in components.items():
if hasattr(module, "half"):
components[name] = module.to(torch_device).half()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
output = pipe(**inputs)[0]
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir)
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16)
for component in pipe_loaded.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe_loaded.to(torch_device)
pipe_loaded.set_progress_bar_config(disable=None)
for name, component in pipe_loaded.components.items():
if name == "vae":
continue
if hasattr(component, "dtype"):
self.assertTrue(
component.dtype == torch.float16,
f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.",
)
inputs = self.get_dummy_inputs(torch_device)
output_loaded = pipe_loaded(**inputs)[0]
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
self.assertLess(
max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading."
)
# def test_to_dtype(self):
# components = self.get_dummy_components()
# pipe = self.pipeline_class(**components)
# pipe.set_progress_bar_config(disable=None)
# model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
# self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
# pipe.to(dtype=torch.float16)
# model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
# self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
@unittest.skip("")
def test_save_load_dduf(self):
pass