From d7a1c31f4f85bae5a9e01cdce49bd7346bd8ccd6 Mon Sep 17 00:00:00 2001 From: Gal Davidi <158810732+galbria@users.noreply.github.com> Date: Mon, 19 Jan 2026 18:39:53 +0200 Subject: [PATCH] Fibo edit pipeline (#12930) * Feature: Add BriaFiboEditPipeline to diffusers * Introduced BriaFiboEditPipeline class with necessary backend requirements. * Updated import structures in relevant modules to include BriaFiboEditPipeline. * Ensured compatibility with existing pipelines and type checking. * Feature: Introduce Bria Fibo Edit Pipeline * Added BriaFiboEditPipeline class for structured JSON-native image editing. * Created documentation for the new pipeline in bria_fibo_edit.md. * Updated import structures to include the new pipeline and its components. * Added unit tests for the BriaFiboEditPipeline to ensure functionality and correctness. * Enhancement: Update Bria Fibo Edit Pipeline and Documentation * Refined the Bria Fibo Edit model description for clarity and detail. * Added usage instructions for model authentication and login. * Implemented mask handling functions in the BriaFiboEditPipeline for improved image editing capabilities. * Updated unit tests to cover new mask functionalities and ensure input validation. * Adjusted example code in documentation to reflect changes in the pipeline's usage. * Update Bria Fibo Edit documentation with corrected Hugging Face page link * add dreambooth training script * style and quality * Delete temp.py * Enhancement: Improve JSON caption validation in DreamBoothDataset * Updated the clean_json_caption function to handle both string and dictionary inputs for captions. * Added error handling to raise a ValueError for invalid caption types, ensuring better input validation. * Add datasets dependency to requirements_fibo_edit.txt * Add bria_fibo_edit to docs table of contents * Fix dummy objects ordering * Fix BriaFiboEditPipeline to use passed generator parameter The pipeline was ignoring the generator parameter and only using the seed parameter. This caused non-deterministic outputs in tests that pass a seeded generator. * Remove fibo_edit training script and related files --------- Co-authored-by: kfirbria --- docs/source/en/_toctree.yml | 2 + .../source/en/api/pipelines/bria_fibo_edit.md | 33 + src/diffusers/__init__.py | 2 + src/diffusers/pipelines/__init__.py | 4 +- src/diffusers/pipelines/bria_fibo/__init__.py | 3 + .../bria_fibo/pipeline_bria_fibo_edit.py | 1133 +++++++++++++++++ .../dummy_torch_and_transformers_objects.py | 15 + tests/pipelines/bria_fibo_edit/__init__.py | 0 .../test_pipeline_bria_fibo_edit.py | 192 +++ 9 files changed, 1382 insertions(+), 2 deletions(-) create mode 100644 docs/source/en/api/pipelines/bria_fibo_edit.md create mode 100644 src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo_edit.py create mode 100644 tests/pipelines/bria_fibo_edit/__init__.py create mode 100644 tests/pipelines/bria_fibo_edit/test_pipeline_bria_fibo_edit.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index d193ab0308..46e241d817 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -496,6 +496,8 @@ title: Bria 3.2 - local: api/pipelines/bria_fibo title: Bria Fibo + - local: api/pipelines/bria_fibo_edit + title: Bria Fibo Edit - local: api/pipelines/chroma title: Chroma - local: api/pipelines/cogview3 diff --git a/docs/source/en/api/pipelines/bria_fibo_edit.md b/docs/source/en/api/pipelines/bria_fibo_edit.md new file mode 100644 index 0000000000..b46dd78cdb --- /dev/null +++ b/docs/source/en/api/pipelines/bria_fibo_edit.md @@ -0,0 +1,33 @@ + + +# Bria Fibo Edit + +Fibo Edit is an 8B parameter image-to-image model that introduces a new paradigm of structured control, operating on JSON inputs paired with source images to enable deterministic and repeatable editing workflows. +Featuring native masking for granular precision, it moves beyond simple prompt-based diffusion to offer explicit, interpretable control optimized for production environments. +Its lightweight architecture is designed for deep customization, empowering researchers to build specialized "Edit" models for domain-specific tasks while delivering top-tier aesthetic quality + +## Usage +_As the model is gated, before using it with diffusers you first need to go to the [Bria Fibo Hugging Face page](https://huggingface.co/briaai/Fibo-Edit), fill in the form and accept the gate. Once you are in, you need to login so that your system knows you’ve accepted the gate._ + +Use the command below to log in: + +```bash +hf auth login +``` + + +## BriaFiboEditPipeline + +[[autodoc]] BriaFiboEditPipeline + - all + - __call__ \ No newline at end of file diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 5617e9b65d..24b9c12db6 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -457,6 +457,7 @@ else: "AuraFlowPipeline", "BlipDiffusionControlNetPipeline", "BlipDiffusionPipeline", + "BriaFiboEditPipeline", "BriaFiboPipeline", "BriaPipeline", "ChromaImg2ImgPipeline", @@ -1185,6 +1186,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: AudioLDM2UNet2DConditionModel, AudioLDMPipeline, AuraFlowPipeline, + BriaFiboEditPipeline, BriaFiboPipeline, BriaPipeline, ChromaImg2ImgPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index ca77ee9036..65378631a1 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -129,7 +129,7 @@ else: "AnimateDiffVideoToVideoControlNetPipeline", ] _import_structure["bria"] = ["BriaPipeline"] - _import_structure["bria_fibo"] = ["BriaFiboPipeline"] + _import_structure["bria_fibo"] = ["BriaFiboPipeline", "BriaFiboEditPipeline"] _import_structure["flux2"] = ["Flux2Pipeline", "Flux2KleinPipeline"] _import_structure["flux"] = [ "FluxControlPipeline", @@ -597,7 +597,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: from .aura_flow import AuraFlowPipeline from .blip_diffusion import BlipDiffusionPipeline from .bria import BriaPipeline - from .bria_fibo import BriaFiboPipeline + from .bria_fibo import BriaFiboEditPipeline, BriaFiboPipeline from .chroma import ChromaImg2ImgPipeline, ChromaInpaintPipeline, ChromaPipeline from .chronoedit import ChronoEditPipeline from .cogvideo import ( diff --git a/src/diffusers/pipelines/bria_fibo/__init__.py b/src/diffusers/pipelines/bria_fibo/__init__.py index 206a463b39..8dd7727090 100644 --- a/src/diffusers/pipelines/bria_fibo/__init__.py +++ b/src/diffusers/pipelines/bria_fibo/__init__.py @@ -23,6 +23,8 @@ except OptionalDependencyNotAvailable: _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_bria_fibo"] = ["BriaFiboPipeline"] + _import_structure["pipeline_bria_fibo_edit"] = ["BriaFiboEditPipeline"] + if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -33,6 +35,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_bria_fibo import BriaFiboPipeline + from .pipeline_bria_fibo_edit import BriaFiboEditPipeline else: import sys diff --git a/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo_edit.py b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo_edit.py new file mode 100644 index 0000000000..aae8fc7367 --- /dev/null +++ b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo_edit.py @@ -0,0 +1,1133 @@ +# Copyright (c) Bria.ai. All rights reserved. +# +# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0). +# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/ +# +# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit, +# indicate if changes were made, and do not use the material for commercial purposes. +# +# See the license for further details. + +import json +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from PIL import Image +from transformers import AutoTokenizer +from transformers.models.smollm3.modeling_smollm3 import SmolLM3ForCausalLM + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin +from ...models.autoencoders.autoencoder_kl_wan import AutoencoderKLWan +from ...models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel +from ...pipelines.bria_fibo.pipeline_output import BriaFiboPipelineOutput +from ...pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +PipelineMaskInput = Union[ + torch.FloatTensor, Image.Image, List[Image.Image], List[torch.FloatTensor], np.ndarray, List[np.ndarray] +] + +# TODO: Update example docstring +EXAMPLE_DOC_STRING = """ + Example: + ```python + import torch + from diffusers import BriaFiboEditPipeline + from diffusers.modular_pipelines import ModularPipeline + + torch.set_grad_enabled(False) + vlm_pipe = ModularPipelineBlocks.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True) + vlm_pipe = vlm_pipe.init_pipeline() + + pipe = BriaFiboEditPipeline.from_pretrained( + "briaai/fibo-edit", + torch_dtype=torch.bfloat16, + ) + pipe.to("cuda") + + output = vlm_pipe( + prompt="A hyper-detailed, ultra-fluffy owl sitting in the trees at night, looking directly at the camera with wide, adorable, expressive eyes. Its feathers are soft and voluminous, catching the cool moonlight with subtle silver highlights. The owl's gaze is curious and full of charm, giving it a whimsical, storybook-like personality." + ) + json_prompt_generate = json.loads(output.values["json_prompt"]) + + image = Image.open("image_generate.png") + + edit_prompt = "Make the owl to be a cat" + + json_prompt_generate["edit_instruction"] = edit_prompt + + results_generate = pipe( + prompt=json_prompt_generate, num_inference_steps=50, guidance_scale=3.5, image=image, output_type="np" + ) + ``` +""" + +PREFERRED_RESOLUTION = { + 256 * 256: [(208, 304), (224, 288), (256, 256), (288, 224), (304, 208), (320, 192), (336, 192)], + 512 * 512: [ + (416, 624), + (432, 592), + (464, 560), + (512, 512), + (544, 480), + (576, 448), + (592, 432), + (608, 416), + (624, 416), + (640, 400), + (672, 384), + (704, 368), + ], + 1024 * 1024: [ + (832, 1248), + (880, 1184), + (912, 1136), + (1024, 1024), + (1136, 912), + (1184, 880), + (1216, 848), + (1248, 832), + (1248, 832), + (1264, 816), + (1296, 800), + (1360, 768), + ], +} + + +def is_valid_edit_json(json_input: str | dict): + """ + Check if the input is a valid JSON string or dict with an "edit_instruction" key. + + Args: + json_input (`str` or `dict`): + The JSON string or dict to check. + + Returns: + `bool`: True if the input is a valid JSON string or dict with an "edit_instruction" key, False otherwise. + """ + try: + if isinstance(json_input, str) and "edit_instruction" in json_input: + json.loads(json_input) + return True + elif isinstance(json_input, dict) and "edit_instruction" in json_input: + return True + else: + return False + except json.JSONDecodeError: + return False + + +def is_valid_mask(mask: PipelineMaskInput): + """ + Check if the mask is a valid mask. + """ + if isinstance(mask, torch.Tensor): + return True + elif isinstance(mask, Image.Image): + return True + elif isinstance(mask, list): + return all(isinstance(m, (torch.Tensor, Image.Image, np.ndarray)) for m in mask) + elif isinstance(mask, np.ndarray): + return mask.ndim in [2, 3] and mask.min() >= 0 and mask.max() <= 1 + else: + return False + + +def get_mask_size(mask: PipelineMaskInput): + """ + Get the size of the mask. + """ + if isinstance(mask, torch.Tensor): + return mask.shape[-2:] + elif isinstance(mask, Image.Image): + return mask.size[::-1] # (height, width) + elif isinstance(mask, list): + return [get_mask_size(m) for m in mask] + elif isinstance(mask, np.ndarray): + return mask.shape[-2:] + else: + return None + + +def get_image_size(image: PipelineImageInput): + """ + Get the size of the image. + """ + if isinstance(image, torch.Tensor): + return image.shape[-2:] + elif isinstance(image, Image.Image): + return image.size[::-1] # (height, width) + elif isinstance(image, list): + return [get_image_size(i) for i in image] + else: + return None + + +def paste_mask_on_image(mask: PipelineMaskInput, image: PipelineImageInput): + """convert mask and image to PIL Images and paste the mask on the image""" + if isinstance(mask, torch.Tensor): + if mask.ndim == 3 and mask.shape[0] == 1: + mask = mask.squeeze(0) + mask = Image.fromarray((mask.cpu().numpy() * 255).astype(np.uint8)) + elif isinstance(mask, Image.Image): + pass + elif isinstance(mask, list): + mask = mask[0] + if isinstance(mask, torch.Tensor): + if mask.ndim == 3 and mask.shape[0] == 1: + mask = mask.squeeze(0) + mask = Image.fromarray((mask.cpu().numpy() * 255).astype(np.uint8)) + elif isinstance(mask, np.ndarray): + mask = Image.fromarray((mask * 255).astype(np.uint8)) + elif isinstance(mask, np.ndarray): + mask = Image.fromarray((mask * 255).astype(np.uint8)) + + if isinstance(image, torch.Tensor): + if image.ndim == 3: + image = image.permute(1, 2, 0) + image = Image.fromarray((image.cpu().numpy() * 255).astype(np.uint8)) + elif isinstance(image, Image.Image): + pass + elif isinstance(image, list): + image = image[0] + if isinstance(image, torch.Tensor): + if image.ndim == 3: + image = image.permute(1, 2, 0) + image = Image.fromarray((image.cpu().numpy() * 255).astype(np.uint8)) + elif isinstance(image, np.ndarray): + image = Image.fromarray((image * 255).astype(np.uint8)) + elif isinstance(image, np.ndarray): + image = Image.fromarray((image * 255).astype(np.uint8)) + + mask = mask.convert("L") + image = image.convert("RGB") + gray_color = (128, 128, 128) + gray_img = Image.new("RGB", image.size, gray_color) + image = Image.composite(gray_img, image, mask) + return image + + +class BriaFiboEditPipeline(DiffusionPipeline, FluxLoraLoaderMixin): + r""" + Args: + transformer (`BriaFiboTransformer2DModel`): + The transformer model for 2D diffusion modeling. + scheduler (`FlowMatchEulerDiscreteScheduler` or `KarrasDiffusionSchedulers`): + Scheduler to be used with `transformer` to denoise the encoded latents. + vae (`AutoencoderKLWan`): + Variational Auto-Encoder for encoding and decoding images to and from latent representations. + text_encoder (`SmolLM3ForCausalLM`): + Text encoder for processing input prompts. + tokenizer (`AutoTokenizer`): + Tokenizer used for processing the input text prompts for the text_encoder. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + transformer: BriaFiboTransformer2DModel, + scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers], + vae: AutoencoderKLWan, + text_encoder: SmolLM3ForCausalLM, + tokenizer: AutoTokenizer, + ): + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.vae_scale_factor = 16 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) # * 2) + self.default_sample_size = 32 # 64 + + def get_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + max_sequence_length: int = 2048, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + if not prompt: + raise ValueError("`prompt` must be a non-empty string or list of strings.") + + batch_size = len(prompt) + bot_token_id = 128000 + + text_encoder_device = device if device is not None else torch.device("cpu") + if not isinstance(text_encoder_device, torch.device): + text_encoder_device = torch.device(text_encoder_device) + + if all(p == "" for p in prompt): + input_ids = torch.full((batch_size, 1), bot_token_id, dtype=torch.long, device=text_encoder_device) + attention_mask = torch.ones_like(input_ids) + else: + tokenized = self.tokenizer( + prompt, + padding="longest", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + input_ids = tokenized.input_ids.to(text_encoder_device) + attention_mask = tokenized.attention_mask.to(text_encoder_device) + + if any(p == "" for p in prompt): + empty_rows = torch.tensor([p == "" for p in prompt], dtype=torch.bool, device=text_encoder_device) + input_ids[empty_rows] = bot_token_id + attention_mask[empty_rows] = 1 + + encoder_outputs = self.text_encoder( + input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + ) + hidden_states = encoder_outputs.hidden_states + + prompt_embeds = torch.cat([hidden_states[-1], hidden_states[-2]], dim=-1) + prompt_embeds = prompt_embeds.to(device=device, dtype=dtype) + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + hidden_states = tuple( + layer.repeat_interleave(num_images_per_prompt, dim=0).to(device=device) for layer in hidden_states + ) + attention_mask = attention_mask.repeat_interleave(num_images_per_prompt, dim=0).to(device=device) + + return prompt_embeds, hidden_states, attention_mask + + @staticmethod + def pad_embedding(prompt_embeds, max_tokens, attention_mask=None): + # Pad embeddings to `max_tokens` while preserving the mask of real tokens. + batch_size, seq_len, dim = prompt_embeds.shape + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_len), dtype=prompt_embeds.dtype, device=prompt_embeds.device) + else: + attention_mask = attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype) + + if max_tokens < seq_len: + raise ValueError("`max_tokens` must be greater or equal to the current sequence length.") + + if max_tokens > seq_len: + pad_length = max_tokens - seq_len + padding = torch.zeros( + (batch_size, pad_length, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device + ) + prompt_embeds = torch.cat([prompt_embeds, padding], dim=1) + + mask_padding = torch.zeros( + (batch_size, pad_length), dtype=prompt_embeds.dtype, device=prompt_embeds.device + ) + attention_mask = torch.cat([attention_mask, mask_padding], dim=1) + + return prompt_embeds, attention_mask + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + guidance_scale: float = 5, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 3000, + lora_scale: Optional[float] = None, + ): + r""" + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + guidance_scale (`float`): + Guidance scale for classifier free guidance. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + prompt_attention_mask = None + negative_prompt_attention_mask = None + if prompt_embeds is None: + prompt_embeds, prompt_layers, prompt_attention_mask = self.get_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype) + prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in prompt_layers] + + if guidance_scale > 1: + if isinstance(negative_prompt, list) and negative_prompt[0] is None: + negative_prompt = "" + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_prompt_layers, negative_prompt_attention_mask = self.get_prompt_embeds( + prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.transformer.dtype) + negative_prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in negative_prompt_layers] + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + # Pad to longest + if prompt_attention_mask is not None: + prompt_attention_mask = prompt_attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype) + + if negative_prompt_embeds is not None: + if negative_prompt_attention_mask is not None: + negative_prompt_attention_mask = negative_prompt_attention_mask.to( + device=negative_prompt_embeds.device, dtype=negative_prompt_embeds.dtype + ) + max_tokens = max(negative_prompt_embeds.shape[1], prompt_embeds.shape[1]) + + prompt_embeds, prompt_attention_mask = self.pad_embedding( + prompt_embeds, max_tokens, attention_mask=prompt_attention_mask + ) + prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in prompt_layers] + + negative_prompt_embeds, negative_prompt_attention_mask = self.pad_embedding( + negative_prompt_embeds, max_tokens, attention_mask=negative_prompt_attention_mask + ) + negative_prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in negative_prompt_layers] + else: + max_tokens = prompt_embeds.shape[1] + prompt_embeds, prompt_attention_mask = self.pad_embedding( + prompt_embeds, max_tokens, attention_mask=prompt_attention_mask + ) + negative_prompt_layers = None + + dtype = self.text_encoder.dtype + text_ids = torch.zeros(prompt_embeds.shape[0], max_tokens, 3).to(device=device, dtype=dtype) + + return ( + prompt_embeds, + negative_prompt_embeds, + text_ids, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_layers, + negative_prompt_layers, + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @staticmethod + # Based on diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + def _unpack_latents_no_patch(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height, width, channels) + latents = latents.permute(0, 3, 1, 2) + + return latents + + @staticmethod + def _pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width): + latents = latents.permute(0, 2, 3, 1) + latents = latents.reshape(batch_size, height * width, num_channels_latents) + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + do_patching=False, + ): + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + if do_patching: + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + else: + latents = self._pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + + return latents, latent_image_ids + + @staticmethod + def _prepare_attention_mask(attention_mask): + attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask) + + # convert to 0 - keep, -inf ignore + attention_matrix = torch.where( + attention_matrix == 1, 0.0, -torch.inf + ) # Apply -inf to ignored tokens for nulling softmax score + return attention_matrix + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Optional[PipelineImageInput] = None, + mask: Optional[PipelineMaskInput] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 30, + timesteps: List[int] = None, + seed: Optional[int] = None, + guidance_scale: float = 5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 3000, + do_patching=False, + _auto_resize: bool = True, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`PIL.Image.Image` or `torch.FloatTensor`, *optional*): + The image to guide the image generation. If not defined, the pipeline will generate an image from + scratch. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + seed (`int`, *optional*): + A seed used to make generation deterministic. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 3000): Maximum sequence length to use with the `prompt`. + do_patching (`bool`, *optional*, defaults to `False`): Whether to use patching. + Examples: + Returns: + [`~pipelines.flux.BriaFiboPipelineOutput`] or `tuple`: [`~pipelines.flux.BriaFiboPipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + if height is None or width is None: + if image is not None: + image_height, image_width = self.image_processor.get_default_height_width(image) + if _auto_resize: + image_width, image_height = min( + PREFERRED_RESOLUTION[1024 * 1024], + key=lambda size: abs(size[0] / size[1] - image_width / image_height), + ) + width, height = image_width, image_height + else: + raise ValueError("You must provide either an image or both height and width.") + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + seed=seed, + image=image, + mask=mask, + prompt=prompt, + height=height, + width=width, + prompt_embeds=prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + if mask is not None and image is not None: + image = paste_mask_on_image(mask, image) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + + if prompt is not None and is_valid_edit_json(prompt): + prompt = json.dumps(prompt) + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + if generator is None and seed is not None: + generator = torch.Generator(device=device).manual_seed(seed) + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + text_ids, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_layers, + negative_prompt_layers, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + guidance_scale=guidance_scale, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + num_images_per_prompt=num_images_per_prompt, + lora_scale=lora_scale, + ) + prompt_batch_size = prompt_embeds.shape[0] + + if guidance_scale > 1: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_layers = [ + torch.cat([negative_prompt_layers[i], prompt_layers[i]], dim=0) for i in range(len(prompt_layers)) + ] + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + total_num_layers_transformer = len(self.transformer.transformer_blocks) + len( + self.transformer.single_transformer_blocks + ) + if len(prompt_layers) >= total_num_layers_transformer: + # remove first layers + prompt_layers = prompt_layers[len(prompt_layers) - total_num_layers_transformer :] + else: + # duplicate last layer + prompt_layers = prompt_layers + [prompt_layers[-1]] * (total_num_layers_transformer - len(prompt_layers)) + + # Preprocess image + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + image = self.image_processor.resize(image, height, width) + image = self.image_processor.preprocess(image, height, width) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + if do_patching: + num_channels_latents = int(num_channels_latents / 4) + + latents, latent_image_ids = self.prepare_latents( + prompt_batch_size, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + do_patching, + ) + + if image is not None: + image_latents, image_ids = self.prepare_image_latents( + image=image, + batch_size=batch_size * num_images_per_prompt, + num_channels_latents=num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + ) + latent_image_ids = torch.cat([latent_image_ids, image_ids], dim=0) # dim 0 is sequence dimension + else: + image_latents = None + + latent_attention_mask = torch.ones( + [latents.shape[0], latents.shape[1]], dtype=latents.dtype, device=latents.device + ) + if guidance_scale > 1: + latent_attention_mask = latent_attention_mask.repeat(2, 1) + + if image_latents is None: + attention_mask = torch.cat([prompt_attention_mask, latent_attention_mask], dim=1) + else: + image_latent_attention_mask = torch.ones( + [image_latents.shape[0], image_latents.shape[1]], + dtype=image_latents.dtype, + device=image_latents.device, + ) + if guidance_scale > 1: + image_latent_attention_mask = image_latent_attention_mask.repeat(2, 1) + attention_mask = torch.cat( + [prompt_attention_mask, latent_attention_mask, image_latent_attention_mask], dim=1 + ) + + attention_mask = self.create_attention_matrix(attention_mask) # batch, seq => batch, seq, seq + attention_mask = attention_mask.unsqueeze(dim=1).to(dtype=self.transformer.dtype) # for head broadcasting + + if self._joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + self._joint_attention_kwargs["attention_mask"] = attention_mask + + # Adapt scheduler to dynamic shifting (resolution dependent) + + if do_patching: + seq_len = (height // (self.vae_scale_factor * 2)) * (width // (self.vae_scale_factor * 2)) + else: + seq_len = (height // self.vae_scale_factor) * (width // self.vae_scale_factor) + + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + + mu = calculate_shift( + seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + + # Init sigmas and timesteps according to shift size + # This changes the scheduler in-place according to the dynamic scheduling + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps=num_inference_steps, + device=device, + timesteps=None, + sigmas=sigmas, + mu=mu, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # Support old different diffusers versions + if len(latent_image_ids.shape) == 3: + latent_image_ids = latent_image_ids[0] + + if len(text_ids.shape) == 3: + text_ids = text_ids[0] + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = latents + + if image_latents is not None: + latent_model_input = torch.cat([latent_model_input, image_latents], dim=1) + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latent_model_input] * 2) if guidance_scale > 1 else latent_model_input + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]).to( + device=latent_model_input.device, dtype=latent_model_input.dtype + ) + + # This is predicts "v" from flow-matching or eps from diffusion + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + text_encoder_layers=prompt_layers, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + txt_ids=text_ids, + img_ids=latent_image_ids, + )[0] + + # perform guidance + if guidance_scale > 1: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred[:, : latents.shape[1], ...], t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + if do_patching: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + else: + latents = self._unpack_latents_no_patch(latents, height, width, self.vae_scale_factor) + + latents = latents.unsqueeze(dim=2) + latents_device = latents[0].device + latents_dtype = latents[0].dtype + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents_device, latents_dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents_device, latents_dtype + ) + latents_scaled = [latent / latents_std + latents_mean for latent in latents] + latents_scaled = torch.cat(latents_scaled, dim=0) + image = [] + for scaled_latent in latents_scaled: + curr_image = self.vae.decode(scaled_latent.unsqueeze(0), return_dict=False)[0] + curr_image = self.image_processor.postprocess(curr_image.squeeze(dim=2), output_type=output_type) + image.append(curr_image) + if len(image) == 1: + image = image[0] + else: + image = np.stack(image, axis=0) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return BriaFiboPipelineOutput(images=image) + + def prepare_image_latents( + self, + image: torch.Tensor, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + ): + image = image.to(device=device, dtype=dtype) + + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor + + # scaling + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + device, dtype + ) + + image_latents_cthw = self.vae.encode(image.unsqueeze(2)).latent_dist.mean + latents_scaled = [(latent - latents_mean) * latents_std for latent in image_latents_cthw] + image_latents_cthw = torch.concat(latents_scaled, dim=0) + image_latents_bchw = image_latents_cthw[:, :, 0, :, :] + + image_latent_height, image_latent_width = image_latents_bchw.shape[2:] + image_latents_bsd = self._pack_latents_no_patch( + latents=image_latents_bchw, + batch_size=batch_size, + num_channels_latents=num_channels_latents, + height=image_latent_height, + width=image_latent_width, + ) + # breakpoint() + image_ids = self._prepare_latent_image_ids( + batch_size=batch_size, height=image_latent_height, width=image_latent_width, device=device, dtype=dtype + ) + # image ids are the same as latent ids with the first dimension set to 1 instead of 0 + image_ids[..., 0] = 1 + return image_latents_bsd, image_ids + + def check_inputs( + self, + prompt, + seed, + image, + mask, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if seed is not None and not isinstance(seed, int): + raise ValueError("Seed must be an integer") + if image is not None and not isinstance(image, (torch.Tensor, Image.Image, list)): + raise ValueError("Image must be a valid image") + if image is None and mask is not None: + raise ValueError("If mask is provided, image must also be provided") + + if mask is not None and not is_valid_mask(mask): + raise ValueError("Mask must be a valid mask") + + if mask is not None and image is not None and not (get_mask_size(mask) == get_image_size(image)): + raise ValueError("Mask and image must have the same size") + + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and not is_valid_edit_json(prompt): + raise ValueError(f"`prompt` has to be a valid JSON string or dict but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if max_sequence_length is not None and max_sequence_length > 3000: + raise ValueError(f"`max_sequence_length` cannot be greater than 3000 but is {max_sequence_length}") + + def create_attention_matrix(self, attention_mask): + attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask) + + # convert to 0 - keep, -inf ignore + attention_matrix = torch.where( + attention_matrix == 1, 0.0, -torch.inf + ) # Apply -inf to ignored tokens for nulling softmax score + return attention_matrix diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index d48981f380..63f381419f 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -587,6 +587,21 @@ class AuraFlowPipeline(metaclass=DummyObject): requires_backends(cls, ["torch", "transformers"]) +class BriaFiboEditPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class BriaFiboPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/bria_fibo_edit/__init__.py b/tests/pipelines/bria_fibo_edit/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/pipelines/bria_fibo_edit/test_pipeline_bria_fibo_edit.py b/tests/pipelines/bria_fibo_edit/test_pipeline_bria_fibo_edit.py new file mode 100644 index 0000000000..5376c4b5e0 --- /dev/null +++ b/tests/pipelines/bria_fibo_edit/test_pipeline_bria_fibo_edit.py @@ -0,0 +1,192 @@ +# Copyright 2024 Bria AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import AutoTokenizer +from transformers.models.smollm3.modeling_smollm3 import SmolLM3Config, SmolLM3ForCausalLM + +from diffusers import ( + AutoencoderKLWan, + BriaFiboEditPipeline, + FlowMatchEulerDiscreteScheduler, +) +from diffusers.models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel +from tests.pipelines.test_pipelines_common import PipelineTesterMixin + +from ...testing_utils import ( + enable_full_determinism, + torch_device, +) + + +enable_full_determinism() + + +class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = BriaFiboEditPipeline + params = frozenset(["prompt", "height", "width", "guidance_scale"]) + batch_params = frozenset(["prompt"]) + test_xformers_attention = False + test_layerwise_casting = False + test_group_offloading = False + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = BriaFiboTransformer2DModel( + patch_size=1, + in_channels=16, + num_layers=1, + num_single_layers=1, + attention_head_dim=8, + num_attention_heads=2, + joint_attention_dim=64, + text_encoder_dim=32, + pooled_projection_dim=None, + axes_dims_rope=[0, 4, 4], + ) + + vae = AutoencoderKLWan( + base_dim=80, + decoder_base_dim=128, + dim_mult=[1, 2, 4, 4], + dropout=0.0, + in_channels=12, + latents_mean=[0.0] * 16, + latents_std=[1.0] * 16, + is_residual=True, + num_res_blocks=2, + out_channels=12, + patch_size=2, + scale_factor_spatial=16, + scale_factor_temporal=4, + temperal_downsample=[False, True, True], + z_dim=16, + ) + scheduler = FlowMatchEulerDiscreteScheduler() + text_encoder = SmolLM3ForCausalLM(SmolLM3Config(hidden_size=32)) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + components = { + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "transformer": transformer, + "vae": vae, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + inputs = { + "prompt": '{"text": "A painting of a squirrel eating a burger","edit_instruction": "A painting of a squirrel eating a burger"}', + "negative_prompt": "bad, ugly", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "height": 192, + "width": 336, + "output_type": "np", + } + image = Image.new("RGB", (336, 192), (255, 255, 255)) + inputs["image"] = image + return inputs + + @unittest.skip(reason="will not be supported due to dim-fusion") + def test_encode_prompt_works_in_isolation(self): + pass + + @unittest.skip(reason="Batching is not supported yet") + def test_num_images_per_prompt(self): + pass + + @unittest.skip(reason="Batching is not supported yet") + def test_inference_batch_consistent(self): + pass + + @unittest.skip(reason="Batching is not supported yet") + def test_inference_batch_single_identical(self): + pass + + def test_bria_fibo_different_prompts(self): + pipe = self.pipeline_class(**self.get_dummy_components()) + pipe = pipe.to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + output_same_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + inputs["prompt"] = {"edit_instruction": "a different prompt"} + output_different_prompts = pipe(**inputs).images[0] + + max_diff = np.abs(output_same_prompt - output_different_prompts).max() + assert max_diff > 1e-6 + + def test_image_output_shape(self): + pipe = self.pipeline_class(**self.get_dummy_components()) + pipe = pipe.to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + height_width_pairs = [(32, 32), (64, 64), (32, 64)] + for height, width in height_width_pairs: + expected_height = height + expected_width = width + + inputs.update({"height": height, "width": width}) + image = pipe(**inputs).images[0] + output_height, output_width, _ = image.shape + assert (output_height, output_width) == (expected_height, expected_width) + + def test_bria_fibo_edit_mask(self): + pipe = self.pipeline_class(**self.get_dummy_components()) + pipe = pipe.to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + mask = Image.fromarray((np.ones((192, 336)) * 255).astype(np.uint8), mode="L") + + inputs.update({"mask": mask}) + output = pipe(**inputs).images[0] + + assert output.shape == (192, 336, 3) + + def test_bria_fibo_edit_mask_image_size_mismatch(self): + pipe = self.pipeline_class(**self.get_dummy_components()) + pipe = pipe.to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + mask = Image.fromarray((np.ones((64, 64)) * 255).astype(np.uint8), mode="L") + + inputs.update({"mask": mask}) + with self.assertRaisesRegex(ValueError, "Mask and image must have the same size"): + pipe(**inputs) + + def test_bria_fibo_edit_mask_no_image(self): + pipe = self.pipeline_class(**self.get_dummy_components()) + pipe = pipe.to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + mask = Image.fromarray((np.ones((32, 32)) * 255).astype(np.uint8), mode="L") + + # Remove image from inputs if it's there (it shouldn't be by default from get_dummy_inputs) + inputs.pop("image", None) + inputs.update({"mask": mask}) + + with self.assertRaisesRegex(ValueError, "If mask is provided, image must also be provided"): + pipe(**inputs)