Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-29 07:22:12 +03:00)
Refactor BriaFibo classes and update pipeline parameters
- Updated the BriaFiboAttnProcessor and BriaFiboAttention classes to reflect changes from their Flux equivalents.
- Modified the _unpack_latents method in BriaFiboPipeline to improve clarity.
- Increased the default max_sequence_length to 3000 and added a new optional parameter, do_patching.
- Cleaned up test_pipeline_bria_fibo.py by removing unused imports and skipping unsupported tests.
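The class updates in the first bullet lean on diffusers' "# Copied from" convention: a definition is annotated with the module path of its source, optionally with a rename mapping, and `make fix-copies` keeps the annotated body in sync with that source. A minimal illustration of the marker format follows; the class names are placeholders, not code from this commit.

# Copied from diffusers.models.attention_processor.AttnProcessor2_0 with AttnProcessor2_0 -> MyAttnProcessor
class MyAttnProcessor:
    ...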
@@ -66,8 +66,8 @@ def _get_qkv_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidde
    return _get_projections(attn, hidden_states, encoder_hidden_states)


# Copied from diffusers.models.transformers.transformer_flux.FluxAttnProcessor FluxAttnProcessor->BriaFiboAttnProcessor, FluxAttention->BriaFiboAttention
class BriaFiboAttnProcessor:
    # Copied from diffusers.models.transformers.transformer_flux.FluxAttnProcessor
    _attention_backend = None
    _parallel_config = None
@@ -134,8 +134,8 @@ class BriaFiboAttnProcessor:
        return hidden_states


# Copied from diffusers.models.transformers.transformer_flux.FluxAttention -> BriaFiboAttention
class BriaFiboAttention(torch.nn.Module, AttentionModuleMixin):
    # Copied from diffusers.models.transformers.transformer_flux.FluxAttention
    _default_processor_cls = BriaFiboAttnProcessor
    _available_processors = [
        BriaFiboAttnProcessor,
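For orientation on the processor split shown in the two hunks above: in diffusers, the attention module owns the q/k/v projection layers while a swappable processor object implements the attention math and is picked up through attributes such as `_default_processor_cls`. The sketch below only illustrates that call pattern; it is not the BriaFibo implementation (which also handles joint text/image streams and rotary embeddings), and the stand-in module uses the usual diffusers attribute names as an assumption.

import torch
import torch.nn as nn
import torch.nn.functional as F


class MinimalAttnProcessor:
    """Illustrative processor: the attention module delegates its forward pass here."""

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        query, key, value = attn.to_q(hidden_states), attn.to_k(context), attn.to_v(context)

        # Split heads: (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim).
        batch_size = query.shape[0]
        head_dim = query.shape[-1] // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # Scaled dot-product attention, then merge the heads back and project out.
        out = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
        out = out.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        return attn.to_out[0](out)


class MinimalAttention(nn.Module):
    """Dummy stand-in for an Attention/BriaFiboAttention-style owner of the projections."""

    def __init__(self, dim=64, heads=4):
        super().__init__()
        self.heads = heads
        self.to_q, self.to_k, self.to_v = nn.Linear(dim, dim), nn.Linear(dim, dim), nn.Linear(dim, dim)
        self.to_out = nn.ModuleList([nn.Linear(dim, dim)])
        self.processor = MinimalAttnProcessor()

    def forward(self, hidden_states, encoder_hidden_states=None):
        return self.processor(self, hidden_states, encoder_hidden_states)


print(MinimalAttention()(torch.randn(1, 8, 64)).shape)  # torch.Size([1, 8, 64])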
@@ -353,8 +353,8 @@ class BriaFiboPipeline(DiffusionPipeline):
        return self._interrupt

    @staticmethod
    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
    def _unpack_latents(latents, height, width, vae_scale_factor):
        # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline
        batch_size, num_patches, channels = latents.shape

        height = height // vae_scale_factor
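Only the opening lines of `_unpack_latents` appear in the hunk. As background, Flux-style pipelines pack each 2x2 block of latent pixels into a single sequence token, and unpacking reverses that reshape before the VAE decode. The helper below is a self-contained sketch of that scheme under those assumptions; the function name and the shapes in the round-trip check are made up for the example, and it is not the exact BriaFibo code.

import torch


def unpack_latents_sketch(latents: torch.Tensor, height: int, width: int, vae_scale_factor: int) -> torch.Tensor:
    """Reverse the 2x2 patch packing: (B, num_patches, C) -> (B, C // 4, H, W) in latent space."""
    batch_size, num_patches, channels = latents.shape

    # Image-space height/width -> latent-space height/width.
    height = height // vae_scale_factor
    width = width // vae_scale_factor

    # Each packed token holds a 2x2 block of latent pixels, so channels = 4 * latent_channels.
    latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
    latents = latents.permute(0, 3, 1, 4, 2, 5)
    latents = latents.reshape(batch_size, channels // 4, height, width)
    return latents


# Round-trip check with dummy data: 16 latent channels packed into 64-channel tokens.
x = torch.randn(1, (32 // 2) * (32 // 2), 64)
print(unpack_latents_sketch(x, height=32 * 8, width=32 * 8, vae_scale_factor=8).shape)  # torch.Size([1, 16, 32, 32])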
@@ -542,7 +542,8 @@ class BriaFiboPipeline(DiffusionPipeline):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
            max_sequence_length (`int` defaults to 3000): Maximum sequence length to use with the `prompt`.
            do_patching (`bool`, *optional*, defaults to `False`): Whether to use patching.
        Examples:
        Returns:
            [`~pipelines.flux.BriaFiboPipelineOutput`] or `tuple`: [`~pipelines.flux.BriaFiboPipelineOutput`] if
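To show how the two documented parameters are passed at call time, here is a hedged usage sketch. The checkpoint id, dtype, and the rest of the call signature are assumptions for the example; only `max_sequence_length` and `do_patching` come from the docstring above.

import torch
from diffusers import DiffusionPipeline

# Hypothetical checkpoint id; substitute whatever BriaFibo checkpoint you actually use.
pipe = DiffusionPipeline.from_pretrained("briaai/FIBO", torch_dtype=torch.bfloat16).to("cuda")

image = pipe(
    prompt="a photo of a red bicycle leaning against a brick wall",
    max_sequence_length=3000,  # new default documented above
    do_patching=False,         # new optional flag documented above
).images[0]
image.save("bria_fibo_sample.png")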
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
import unittest

import numpy as np
@@ -26,11 +25,10 @@ from diffusers import (
    FlowMatchEulerDiscreteScheduler,
)
from diffusers.models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel
from tests.pipelines.test_pipelines_common import PipelineTesterMixin, to_np
from tests.pipelines.test_pipelines_common import PipelineTesterMixin

from ...testing_utils import (
    enable_full_determinism,
    require_torch_accelerator,
    torch_device,
)
@@ -45,6 +43,7 @@ class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    test_xformers_attention = False
    test_layerwise_casting = False
    test_group_offloading = False
    supports_dduf = False

    def get_dummy_components(self):
        torch.manual_seed(0)
@@ -107,6 +106,7 @@ class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
        }
        return inputs

    @unittest.skip(reason="will not be supported due to dim-fusion")
    def test_encode_prompt_works_in_isolation(self):
        pass
@@ -137,62 +137,3 @@ class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
        image = pipe(**inputs).images[0]
        output_height, output_width, _ = image.shape
        assert (output_height, output_width) == (expected_height, expected_width)

    @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
    @require_torch_accelerator
    def test_save_load_float16(self, expected_max_diff=1e-2):
        components = self.get_dummy_components()
        for name, module in components.items():
            if hasattr(module, "half"):
                components[name] = module.to(torch_device).half()

        pipe = self.pipeline_class(**components)
        for component in pipe.components.values():
            if hasattr(component, "set_default_attn_processor"):
                component.set_default_attn_processor()
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        output = pipe(**inputs)[0]

        with tempfile.TemporaryDirectory() as tmpdir:
            pipe.save_pretrained(tmpdir)
            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16)
            for component in pipe_loaded.components.values():
                if hasattr(component, "set_default_attn_processor"):
                    component.set_default_attn_processor()
            pipe_loaded.to(torch_device)
            pipe_loaded.set_progress_bar_config(disable=None)

        for name, component in pipe_loaded.components.items():
            if name == "vae":
                continue
            if hasattr(component, "dtype"):
                self.assertTrue(
                    component.dtype == torch.float16,
                    f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.",
                )

        inputs = self.get_dummy_inputs(torch_device)
        output_loaded = pipe_loaded(**inputs)[0]
        max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
        self.assertLess(
            max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading."
        )

    # def test_to_dtype(self):
    #     components = self.get_dummy_components()
    #     pipe = self.pipeline_class(**components)
    #     pipe.set_progress_bar_config(disable=None)

    #     model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
    #     self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))

    #     pipe.to(dtype=torch.float16)
    #     model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
    #     self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))

    @unittest.skip("")
    def test_save_load_dduf(self):
        pass