diff --git a/docker/diffusers-pytorch-cuda/Dockerfile b/docker/diffusers-pytorch-cuda/Dockerfile index a700d1db72..134b47215d 100644 --- a/docker/diffusers-pytorch-cuda/Dockerfile +++ b/docker/diffusers-pytorch-cuda/Dockerfile @@ -2,7 +2,7 @@ FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04 LABEL maintainer="Hugging Face" LABEL repository="diffusers" -ARG PYTHON_VERSION=3.12 +ARG PYTHON_VERSION=3.11 ENV DEBIAN_FRONTEND=noninteractive RUN apt-get -y update \ @@ -32,10 +32,12 @@ RUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV} ENV PATH="$VIRTUAL_ENV/bin:$PATH" # pre-install the heavy dependencies (these can later be overridden by the deps from setup.py) +# Install torch, torchvision, and torchaudio together to ensure compatibility RUN uv pip install --no-cache-dir \ torch \ torchvision \ - torchaudio + torchaudio \ + --index-url https://download.pytorch.org/whl/cu121 RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]" diff --git a/docker/diffusers-pytorch-xformers-cuda/Dockerfile b/docker/diffusers-pytorch-xformers-cuda/Dockerfile index eae7eaf4fa..7714821c82 100644 --- a/docker/diffusers-pytorch-xformers-cuda/Dockerfile +++ b/docker/diffusers-pytorch-xformers-cuda/Dockerfile @@ -2,7 +2,7 @@ FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04 LABEL maintainer="Hugging Face" LABEL repository="diffusers" -ARG PYTHON_VERSION=3.12 +ARG PYTHON_VERSION=3.11 ENV DEBIAN_FRONTEND=noninteractive RUN apt-get -y update \ @@ -32,10 +32,12 @@ RUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV} ENV PATH="$VIRTUAL_ENV/bin:$PATH" # pre-install the heavy dependencies (these can later be overridden by the deps from setup.py) +# Install torch, torchvision, and torchaudio together to ensure compatibility RUN uv pip install --no-cache-dir \ torch \ torchvision \ - torchaudio + torchaudio \ + --index-url https://download.pytorch.org/whl/cu121 RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]" diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index d193ab0308..46e241d817 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -496,6 +496,8 @@ title: Bria 3.2 - local: api/pipelines/bria_fibo title: Bria Fibo + - local: api/pipelines/bria_fibo_edit + title: Bria Fibo Edit - local: api/pipelines/chroma title: Chroma - local: api/pipelines/cogview3 diff --git a/docs/source/en/api/pipelines/bria_fibo_edit.md b/docs/source/en/api/pipelines/bria_fibo_edit.md new file mode 100644 index 0000000000..b46dd78cdb --- /dev/null +++ b/docs/source/en/api/pipelines/bria_fibo_edit.md @@ -0,0 +1,33 @@ + + +# Bria Fibo Edit + +Fibo Edit is an 8B parameter image-to-image model that introduces a new paradigm of structured control, operating on JSON inputs paired with source images to enable deterministic and repeatable editing workflows. +Featuring native masking for granular precision, it moves beyond simple prompt-based diffusion to offer explicit, interpretable control optimized for production environments. +Its lightweight architecture is designed for deep customization, empowering researchers to build specialized "Edit" models for domain-specific tasks while delivering top-tier aesthetic quality + +## Usage +_As the model is gated, before using it with diffusers you first need to go to the [Bria Fibo Hugging Face page](https://huggingface.co/briaai/Fibo-Edit), fill in the form and accept the gate. 
Once you are in, you need to log in so that your system knows you’ve accepted the gate._ + +Use the command below to log in: + +```bash +hf auth login +``` + + +## BriaFiboEditPipeline + +[[autodoc]] BriaFiboEditPipeline + - all + - __call__ \ No newline at end of file diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 5617e9b65d..24b9c12db6 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -457,6 +457,7 @@ else: "AuraFlowPipeline", "BlipDiffusionControlNetPipeline", "BlipDiffusionPipeline", + "BriaFiboEditPipeline", "BriaFiboPipeline", "BriaPipeline", "ChromaImg2ImgPipeline", @@ -1185,6 +1186,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: AudioLDM2UNet2DConditionModel, AudioLDMPipeline, AuraFlowPipeline, + BriaFiboEditPipeline, BriaFiboPipeline, BriaPipeline, ChromaImg2ImgPipeline, diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 16f1a5d1ec..6cdd66a0f2 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -478,7 +478,7 @@ class PeftAdapterMixin: Args: adapter_names (`List[str]` or `str`): The names of the adapters to use. - adapter_weights (`Union[List[float], float]`, *optional*): + weights (`Union[List[float], float]`, *optional*): The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the adapters. @@ -495,7 +495,7 @@ class PeftAdapterMixin: "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic" ) pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - pipeline.unet.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5]) + pipeline.unet.set_adapters(["cinematic", "pixel"], weights=[0.5, 0.5]) ``` """ if not USE_PEFT_BACKEND: diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index c0b4ad4005..a98cb49114 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -366,7 +366,12 @@ class ResnetBlock2D(nn.Module): hidden_states = self.conv2(hidden_states) if self.conv_shortcut is not None: - input_tensor = self.conv_shortcut(input_tensor.contiguous()) + # Only use contiguous() during training to avoid DDP gradient stride mismatch warning. + # In inference mode (eval or no_grad), skip contiguous() for better performance, especially on CPU.
+ # Issue: https://github.com/huggingface/diffusers/issues/12975 + if self.training: + input_tensor = input_tensor.contiguous() + input_tensor = self.conv_shortcut(input_tensor) output_tensor = (input_tensor + hidden_states) / self.output_scale_factor diff --git a/src/diffusers/models/transformers/transformer_longcat_image.py b/src/diffusers/models/transformers/transformer_longcat_image.py index 2696f5e787..3d38da1dfc 100644 --- a/src/diffusers/models/transformers/transformer_longcat_image.py +++ b/src/diffusers/models/transformers/transformer_longcat_image.py @@ -23,7 +23,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import logging from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import AttentionModuleMixin, FeedForward +from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed @@ -400,12 +400,14 @@ class LongCatImageTransformer2DModel( PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, + AttentionMixin, ): """ The Transformer model introduced in Longcat-Image. """ _supports_gradient_checkpointing = True + _repeated_blocks = ["LongCatImageTransformerBlock", "LongCatImageSingleTransformerBlock"] @register_to_config def __init__( diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index ca77ee9036..65378631a1 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -129,7 +129,7 @@ else: "AnimateDiffVideoToVideoControlNetPipeline", ] _import_structure["bria"] = ["BriaPipeline"] - _import_structure["bria_fibo"] = ["BriaFiboPipeline"] + _import_structure["bria_fibo"] = ["BriaFiboPipeline", "BriaFiboEditPipeline"] _import_structure["flux2"] = ["Flux2Pipeline", "Flux2KleinPipeline"] _import_structure["flux"] = [ "FluxControlPipeline", @@ -597,7 +597,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: from .aura_flow import AuraFlowPipeline from .blip_diffusion import BlipDiffusionPipeline from .bria import BriaPipeline - from .bria_fibo import BriaFiboPipeline + from .bria_fibo import BriaFiboEditPipeline, BriaFiboPipeline from .chroma import ChromaImg2ImgPipeline, ChromaInpaintPipeline, ChromaPipeline from .chronoedit import ChronoEditPipeline from .cogvideo import ( diff --git a/src/diffusers/pipelines/bria_fibo/__init__.py b/src/diffusers/pipelines/bria_fibo/__init__.py index 206a463b39..8dd7727090 100644 --- a/src/diffusers/pipelines/bria_fibo/__init__.py +++ b/src/diffusers/pipelines/bria_fibo/__init__.py @@ -23,6 +23,8 @@ except OptionalDependencyNotAvailable: _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_bria_fibo"] = ["BriaFiboPipeline"] + _import_structure["pipeline_bria_fibo_edit"] = ["BriaFiboEditPipeline"] + if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -33,6 +35,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_bria_fibo import BriaFiboPipeline + from .pipeline_bria_fibo_edit import BriaFiboEditPipeline else: import sys diff --git a/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo_edit.py b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo_edit.py new file mode 100644 index 0000000000..aae8fc7367 --- /dev/null +++ 
b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo_edit.py @@ -0,0 +1,1133 @@ +# Copyright (c) Bria.ai. All rights reserved. +# +# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0). +# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/ +# +# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit, +# indicate if changes were made, and do not use the material for commercial purposes. +# +# See the license for further details. + +import json +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from PIL import Image +from transformers import AutoTokenizer +from transformers.models.smollm3.modeling_smollm3 import SmolLM3ForCausalLM + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin +from ...models.autoencoders.autoencoder_kl_wan import AutoencoderKLWan +from ...models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel +from ...pipelines.bria_fibo.pipeline_output import BriaFiboPipelineOutput +from ...pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +PipelineMaskInput = Union[ + torch.FloatTensor, Image.Image, List[Image.Image], List[torch.FloatTensor], np.ndarray, List[np.ndarray] +] + +# TODO: Update example docstring +EXAMPLE_DOC_STRING = """ + Example: + ```python + import torch + from diffusers import BriaFiboEditPipeline + from diffusers.modular_pipelines import ModularPipeline + + torch.set_grad_enabled(False) + vlm_pipe = ModularPipelineBlocks.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True) + vlm_pipe = vlm_pipe.init_pipeline() + + pipe = BriaFiboEditPipeline.from_pretrained( + "briaai/fibo-edit", + torch_dtype=torch.bfloat16, + ) + pipe.to("cuda") + + output = vlm_pipe( + prompt="A hyper-detailed, ultra-fluffy owl sitting in the trees at night, looking directly at the camera with wide, adorable, expressive eyes. Its feathers are soft and voluminous, catching the cool moonlight with subtle silver highlights. The owl's gaze is curious and full of charm, giving it a whimsical, storybook-like personality." 
+ ) + json_prompt_generate = json.loads(output.values["json_prompt"]) + + image = Image.open("image_generate.png") + + edit_prompt = "Make the owl to be a cat" + + json_prompt_generate["edit_instruction"] = edit_prompt + + results_generate = pipe( + prompt=json_prompt_generate, num_inference_steps=50, guidance_scale=3.5, image=image, output_type="np" + ) + ``` +""" + +PREFERRED_RESOLUTION = { + 256 * 256: [(208, 304), (224, 288), (256, 256), (288, 224), (304, 208), (320, 192), (336, 192)], + 512 * 512: [ + (416, 624), + (432, 592), + (464, 560), + (512, 512), + (544, 480), + (576, 448), + (592, 432), + (608, 416), + (624, 416), + (640, 400), + (672, 384), + (704, 368), + ], + 1024 * 1024: [ + (832, 1248), + (880, 1184), + (912, 1136), + (1024, 1024), + (1136, 912), + (1184, 880), + (1216, 848), + (1248, 832), + (1248, 832), + (1264, 816), + (1296, 800), + (1360, 768), + ], +} + + +def is_valid_edit_json(json_input: str | dict): + """ + Check if the input is a valid JSON string or dict with an "edit_instruction" key. + + Args: + json_input (`str` or `dict`): + The JSON string or dict to check. + + Returns: + `bool`: True if the input is a valid JSON string or dict with an "edit_instruction" key, False otherwise. + """ + try: + if isinstance(json_input, str) and "edit_instruction" in json_input: + json.loads(json_input) + return True + elif isinstance(json_input, dict) and "edit_instruction" in json_input: + return True + else: + return False + except json.JSONDecodeError: + return False + + +def is_valid_mask(mask: PipelineMaskInput): + """ + Check if the mask is a valid mask. + """ + if isinstance(mask, torch.Tensor): + return True + elif isinstance(mask, Image.Image): + return True + elif isinstance(mask, list): + return all(isinstance(m, (torch.Tensor, Image.Image, np.ndarray)) for m in mask) + elif isinstance(mask, np.ndarray): + return mask.ndim in [2, 3] and mask.min() >= 0 and mask.max() <= 1 + else: + return False + + +def get_mask_size(mask: PipelineMaskInput): + """ + Get the size of the mask. + """ + if isinstance(mask, torch.Tensor): + return mask.shape[-2:] + elif isinstance(mask, Image.Image): + return mask.size[::-1] # (height, width) + elif isinstance(mask, list): + return [get_mask_size(m) for m in mask] + elif isinstance(mask, np.ndarray): + return mask.shape[-2:] + else: + return None + + +def get_image_size(image: PipelineImageInput): + """ + Get the size of the image. 
+ """ + if isinstance(image, torch.Tensor): + return image.shape[-2:] + elif isinstance(image, Image.Image): + return image.size[::-1] # (height, width) + elif isinstance(image, list): + return [get_image_size(i) for i in image] + else: + return None + + +def paste_mask_on_image(mask: PipelineMaskInput, image: PipelineImageInput): + """convert mask and image to PIL Images and paste the mask on the image""" + if isinstance(mask, torch.Tensor): + if mask.ndim == 3 and mask.shape[0] == 1: + mask = mask.squeeze(0) + mask = Image.fromarray((mask.cpu().numpy() * 255).astype(np.uint8)) + elif isinstance(mask, Image.Image): + pass + elif isinstance(mask, list): + mask = mask[0] + if isinstance(mask, torch.Tensor): + if mask.ndim == 3 and mask.shape[0] == 1: + mask = mask.squeeze(0) + mask = Image.fromarray((mask.cpu().numpy() * 255).astype(np.uint8)) + elif isinstance(mask, np.ndarray): + mask = Image.fromarray((mask * 255).astype(np.uint8)) + elif isinstance(mask, np.ndarray): + mask = Image.fromarray((mask * 255).astype(np.uint8)) + + if isinstance(image, torch.Tensor): + if image.ndim == 3: + image = image.permute(1, 2, 0) + image = Image.fromarray((image.cpu().numpy() * 255).astype(np.uint8)) + elif isinstance(image, Image.Image): + pass + elif isinstance(image, list): + image = image[0] + if isinstance(image, torch.Tensor): + if image.ndim == 3: + image = image.permute(1, 2, 0) + image = Image.fromarray((image.cpu().numpy() * 255).astype(np.uint8)) + elif isinstance(image, np.ndarray): + image = Image.fromarray((image * 255).astype(np.uint8)) + elif isinstance(image, np.ndarray): + image = Image.fromarray((image * 255).astype(np.uint8)) + + mask = mask.convert("L") + image = image.convert("RGB") + gray_color = (128, 128, 128) + gray_img = Image.new("RGB", image.size, gray_color) + image = Image.composite(gray_img, image, mask) + return image + + +class BriaFiboEditPipeline(DiffusionPipeline, FluxLoraLoaderMixin): + r""" + Args: + transformer (`BriaFiboTransformer2DModel`): + The transformer model for 2D diffusion modeling. + scheduler (`FlowMatchEulerDiscreteScheduler` or `KarrasDiffusionSchedulers`): + Scheduler to be used with `transformer` to denoise the encoded latents. + vae (`AutoencoderKLWan`): + Variational Auto-Encoder for encoding and decoding images to and from latent representations. + text_encoder (`SmolLM3ForCausalLM`): + Text encoder for processing input prompts. + tokenizer (`AutoTokenizer`): + Tokenizer used for processing the input text prompts for the text_encoder. 
+ """ + + model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + transformer: BriaFiboTransformer2DModel, + scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers], + vae: AutoencoderKLWan, + text_encoder: SmolLM3ForCausalLM, + tokenizer: AutoTokenizer, + ): + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.vae_scale_factor = 16 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) # * 2) + self.default_sample_size = 32 # 64 + + def get_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + max_sequence_length: int = 2048, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + if not prompt: + raise ValueError("`prompt` must be a non-empty string or list of strings.") + + batch_size = len(prompt) + bot_token_id = 128000 + + text_encoder_device = device if device is not None else torch.device("cpu") + if not isinstance(text_encoder_device, torch.device): + text_encoder_device = torch.device(text_encoder_device) + + if all(p == "" for p in prompt): + input_ids = torch.full((batch_size, 1), bot_token_id, dtype=torch.long, device=text_encoder_device) + attention_mask = torch.ones_like(input_ids) + else: + tokenized = self.tokenizer( + prompt, + padding="longest", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + input_ids = tokenized.input_ids.to(text_encoder_device) + attention_mask = tokenized.attention_mask.to(text_encoder_device) + + if any(p == "" for p in prompt): + empty_rows = torch.tensor([p == "" for p in prompt], dtype=torch.bool, device=text_encoder_device) + input_ids[empty_rows] = bot_token_id + attention_mask[empty_rows] = 1 + + encoder_outputs = self.text_encoder( + input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + ) + hidden_states = encoder_outputs.hidden_states + + prompt_embeds = torch.cat([hidden_states[-1], hidden_states[-2]], dim=-1) + prompt_embeds = prompt_embeds.to(device=device, dtype=dtype) + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + hidden_states = tuple( + layer.repeat_interleave(num_images_per_prompt, dim=0).to(device=device) for layer in hidden_states + ) + attention_mask = attention_mask.repeat_interleave(num_images_per_prompt, dim=0).to(device=device) + + return prompt_embeds, hidden_states, attention_mask + + @staticmethod + def pad_embedding(prompt_embeds, max_tokens, attention_mask=None): + # Pad embeddings to `max_tokens` while preserving the mask of real tokens. 
+ batch_size, seq_len, dim = prompt_embeds.shape + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_len), dtype=prompt_embeds.dtype, device=prompt_embeds.device) + else: + attention_mask = attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype) + + if max_tokens < seq_len: + raise ValueError("`max_tokens` must be greater or equal to the current sequence length.") + + if max_tokens > seq_len: + pad_length = max_tokens - seq_len + padding = torch.zeros( + (batch_size, pad_length, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device + ) + prompt_embeds = torch.cat([prompt_embeds, padding], dim=1) + + mask_padding = torch.zeros( + (batch_size, pad_length), dtype=prompt_embeds.dtype, device=prompt_embeds.device + ) + attention_mask = torch.cat([attention_mask, mask_padding], dim=1) + + return prompt_embeds, attention_mask + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + guidance_scale: float = 5, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 3000, + lora_scale: Optional[float] = None, + ): + r""" + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + guidance_scale (`float`): + Guidance scale for classifier free guidance. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. 
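+ max_sequence_length (`int`, defaults to 3000): + Maximum sequence length to use with the `prompt`. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.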
+ """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + prompt_attention_mask = None + negative_prompt_attention_mask = None + if prompt_embeds is None: + prompt_embeds, prompt_layers, prompt_attention_mask = self.get_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype) + prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in prompt_layers] + + if guidance_scale > 1: + if isinstance(negative_prompt, list) and negative_prompt[0] is None: + negative_prompt = "" + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + + negative_prompt_embeds, negative_prompt_layers, negative_prompt_attention_mask = self.get_prompt_embeds( + prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.transformer.dtype) + negative_prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in negative_prompt_layers] + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + # Pad to longest + if prompt_attention_mask is not None: + prompt_attention_mask = prompt_attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype) + + if negative_prompt_embeds is not None: + if negative_prompt_attention_mask is not None: + negative_prompt_attention_mask = negative_prompt_attention_mask.to( + device=negative_prompt_embeds.device, dtype=negative_prompt_embeds.dtype + ) + max_tokens = max(negative_prompt_embeds.shape[1], prompt_embeds.shape[1]) + + prompt_embeds, prompt_attention_mask = self.pad_embedding( + prompt_embeds, max_tokens, attention_mask=prompt_attention_mask + ) + prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in prompt_layers] + + negative_prompt_embeds, negative_prompt_attention_mask = self.pad_embedding( + negative_prompt_embeds, max_tokens, attention_mask=negative_prompt_attention_mask + ) + negative_prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in negative_prompt_layers] + else: + max_tokens = prompt_embeds.shape[1] + prompt_embeds, prompt_attention_mask = self.pad_embedding( + prompt_embeds, max_tokens, attention_mask=prompt_attention_mask + ) + negative_prompt_layers = None + + dtype = self.text_encoder.dtype + text_ids = torch.zeros(prompt_embeds.shape[0], max_tokens, 3).to(device=device, dtype=dtype) + + return ( + prompt_embeds, + negative_prompt_embeds, + text_ids, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_layers, + negative_prompt_layers, + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. 
+ + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @staticmethod + # Based on diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + def _unpack_latents_no_patch(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height, width, channels) + latents = latents.permute(0, 3, 1, 2) + + return latents + + @staticmethod + def _pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width): + latents = latents.permute(0, 2, 3, 1) + latents = latents.reshape(batch_size, height * width, num_channels_latents) + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + do_patching=False, + ): + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + if do_patching: + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + else: + latents = self._pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + + return latents, latent_image_ids + + @staticmethod + def _prepare_attention_mask(attention_mask): + attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask) + + # convert to 0 - keep, -inf ignore + attention_matrix = torch.where( + attention_matrix == 1, 0.0, -torch.inf + ) # Apply -inf to ignored tokens for nulling softmax score + return attention_matrix + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Optional[PipelineImageInput] = None, + mask: Optional[PipelineMaskInput] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 30, + timesteps: List[int] = None, + seed: Optional[int] = None, + guidance_scale: float = 5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 3000, + do_patching=False, + _auto_resize: bool = True, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`PIL.Image.Image` or `torch.FloatTensor`, *optional*): + The image to guide the image generation. If not defined, the pipeline will generate an image from + scratch. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + seed (`int`, *optional*): + A seed used to make generation deterministic. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). 
`guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages the model to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.bria_fibo.pipeline_output.BriaFiboPipelineOutput`] instead + of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that is called at the end of each denoising step during inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to 3000): Maximum sequence length to use with the `prompt`. + do_patching (`bool`, *optional*, defaults to `False`): Whether to use patching. + Examples: + Returns: + [`~pipelines.bria_fibo.pipeline_output.BriaFiboPipelineOutput`] or `tuple`: [`~pipelines.bria_fibo.pipeline_output.BriaFiboPipelineOutput`] if + `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is a list with the + generated images. + """ + + if height is None or width is None: + if image is not None: + image_height, image_width = self.image_processor.get_default_height_width(image) + if _auto_resize: + image_width, image_height = min( + PREFERRED_RESOLUTION[1024 * 1024], + key=lambda size: abs(size[0] / size[1] - image_width / image_height), + ) + width, height = image_width, image_height + else: + raise ValueError("You must provide either an image or both height and width.") + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + seed=seed, + image=image, + mask=mask, + prompt=prompt, + height=height, + width=width, + prompt_embeds=prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + if mask is not None and image is not None: + image = paste_mask_on_image(mask, image) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + + if prompt is not None and is_valid_edit_json(prompt): + prompt = json.dumps(prompt) + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + if generator is None and seed is not None: + generator = torch.Generator(device=device).manual_seed(seed) + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + text_ids, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_layers, + negative_prompt_layers, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + guidance_scale=guidance_scale, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + num_images_per_prompt=num_images_per_prompt, + lora_scale=lora_scale, + ) + prompt_batch_size = prompt_embeds.shape[0] + + if guidance_scale > 1: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_layers = [ + torch.cat([negative_prompt_layers[i], prompt_layers[i]], dim=0) for i in range(len(prompt_layers)) + ] + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + total_num_layers_transformer = len(self.transformer.transformer_blocks) + len( + self.transformer.single_transformer_blocks + ) + if len(prompt_layers) >= total_num_layers_transformer: + # remove first layers + prompt_layers = prompt_layers[len(prompt_layers) - total_num_layers_transformer :] + else: + # duplicate last layer + prompt_layers = prompt_layers + [prompt_layers[-1]] * (total_num_layers_transformer - len(prompt_layers)) + + # Preprocess image + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + image = self.image_processor.resize(image, height, width) + image = self.image_processor.preprocess(image, height, width) + + # 5. 
Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + if do_patching: + num_channels_latents = int(num_channels_latents / 4) + + latents, latent_image_ids = self.prepare_latents( + prompt_batch_size, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + do_patching, + ) + + if image is not None: + image_latents, image_ids = self.prepare_image_latents( + image=image, + batch_size=batch_size * num_images_per_prompt, + num_channels_latents=num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + ) + latent_image_ids = torch.cat([latent_image_ids, image_ids], dim=0) # dim 0 is sequence dimension + else: + image_latents = None + + latent_attention_mask = torch.ones( + [latents.shape[0], latents.shape[1]], dtype=latents.dtype, device=latents.device + ) + if guidance_scale > 1: + latent_attention_mask = latent_attention_mask.repeat(2, 1) + + if image_latents is None: + attention_mask = torch.cat([prompt_attention_mask, latent_attention_mask], dim=1) + else: + image_latent_attention_mask = torch.ones( + [image_latents.shape[0], image_latents.shape[1]], + dtype=image_latents.dtype, + device=image_latents.device, + ) + if guidance_scale > 1: + image_latent_attention_mask = image_latent_attention_mask.repeat(2, 1) + attention_mask = torch.cat( + [prompt_attention_mask, latent_attention_mask, image_latent_attention_mask], dim=1 + ) + + attention_mask = self.create_attention_matrix(attention_mask) # batch, seq => batch, seq, seq + attention_mask = attention_mask.unsqueeze(dim=1).to(dtype=self.transformer.dtype) # for head broadcasting + + if self._joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + self._joint_attention_kwargs["attention_mask"] = attention_mask + + # Adapt scheduler to dynamic shifting (resolution dependent) + + if do_patching: + seq_len = (height // (self.vae_scale_factor * 2)) * (width // (self.vae_scale_factor * 2)) + else: + seq_len = (height // self.vae_scale_factor) * (width // self.vae_scale_factor) + + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + + mu = calculate_shift( + seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + + # Init sigmas and timesteps according to shift size + # This changes the scheduler in-place according to the dynamic scheduling + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps=num_inference_steps, + device=device, + timesteps=None, + sigmas=sigmas, + mu=mu, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # Support old different diffusers versions + if len(latent_image_ids.shape) == 3: + latent_image_ids = latent_image_ids[0] + + if len(text_ids.shape) == 3: + text_ids = text_ids[0] + + # 6. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = latents + + if image_latents is not None: + latent_model_input = torch.cat([latent_model_input, image_latents], dim=1) + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latent_model_input] * 2) if guidance_scale > 1 else latent_model_input + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]).to( + device=latent_model_input.device, dtype=latent_model_input.dtype + ) + + # This predicts "v" from flow-matching or eps from diffusion + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + text_encoder_layers=prompt_layers, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + txt_ids=text_ids, + img_ids=latent_image_ids, + )[0] + + # perform guidance + if guidance_scale > 1: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred[:, : latents.shape[1], ...], t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + if do_patching: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + else: + latents = self._unpack_latents_no_patch(latents, height, width, self.vae_scale_factor) + + latents = latents.unsqueeze(dim=2) + latents_device = latents[0].device + latents_dtype = latents[0].dtype + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents_device, latents_dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents_device, latents_dtype + ) + latents_scaled = [latent / latents_std + latents_mean for latent in latents] + latents_scaled = torch.cat(latents_scaled, dim=0) + image = [] + for scaled_latent in latents_scaled: + curr_image = self.vae.decode(scaled_latent.unsqueeze(0), return_dict=False)[0] + curr_image = self.image_processor.postprocess(curr_image.squeeze(dim=2), output_type=output_type) + image.append(curr_image) + if len(image) == 1: + image = image[0] + else: + image = np.stack(image, axis=0) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return
BriaFiboPipelineOutput(images=image) + + def prepare_image_latents( + self, + image: torch.Tensor, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + ): + image = image.to(device=device, dtype=dtype) + + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor + + # scaling + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + device, dtype + ) + + image_latents_cthw = self.vae.encode(image.unsqueeze(2)).latent_dist.mean + latents_scaled = [(latent - latents_mean) * latents_std for latent in image_latents_cthw] + image_latents_cthw = torch.concat(latents_scaled, dim=0) + image_latents_bchw = image_latents_cthw[:, :, 0, :, :] + + image_latent_height, image_latent_width = image_latents_bchw.shape[2:] + image_latents_bsd = self._pack_latents_no_patch( + latents=image_latents_bchw, + batch_size=batch_size, + num_channels_latents=num_channels_latents, + height=image_latent_height, + width=image_latent_width, + ) + # breakpoint() + image_ids = self._prepare_latent_image_ids( + batch_size=batch_size, height=image_latent_height, width=image_latent_width, device=device, dtype=dtype + ) + # image ids are the same as latent ids with the first dimension set to 1 instead of 0 + image_ids[..., 0] = 1 + return image_latents_bsd, image_ids + + def check_inputs( + self, + prompt, + seed, + image, + mask, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if seed is not None and not isinstance(seed, int): + raise ValueError("Seed must be an integer") + if image is not None and not isinstance(image, (torch.Tensor, Image.Image, list)): + raise ValueError("Image must be a valid image") + if image is None and mask is not None: + raise ValueError("If mask is provided, image must also be provided") + + if mask is not None and not is_valid_mask(mask): + raise ValueError("Mask must be a valid mask") + + if mask is not None and image is not None and not (get_mask_size(mask) == get_image_size(image)): + raise ValueError("Mask and image must have the same size") + + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and not is_valid_edit_json(prompt): + raise ValueError(f"`prompt` has to be a valid JSON string or dict but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if max_sequence_length is not None and max_sequence_length > 3000: + raise ValueError(f"`max_sequence_length` cannot be greater than 3000 but is {max_sequence_length}") + + def create_attention_matrix(self, attention_mask): + attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask) + + # convert to 0 - keep, -inf ignore + attention_matrix = torch.where( + attention_matrix == 1, 0.0, -torch.inf + ) # Apply -inf to ignored tokens for nulling softmax score + return attention_matrix diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma_inpainting.py b/src/diffusers/pipelines/chroma/pipeline_chroma_inpainting.py index 019c144152..3ea1ece36c 100644 --- a/src/diffusers/pipelines/chroma/pipeline_chroma_inpainting.py +++ b/src/diffusers/pipelines/chroma/pipeline_chroma_inpainting.py @@ -482,8 +482,6 @@ class ChromaInpaintPipeline( negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, - pooled_prompt_embeds=None, - negative_pooled_prompt_embeds=None, callback_on_step_end_tensor_inputs=None, padding_mask_crop=None, max_sequence_length=None, @@ -531,15 +529,6 @@ class ChromaInpaintPipeline( f" {negative_prompt_embeds.shape}." ) - if prompt_embeds is not None and pooled_prompt_embeds is None: - raise ValueError( - "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." - ) - if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: - raise ValueError( - "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." 
- ) - if prompt_embeds is not None and prompt_attention_mask is None: raise ValueError("Cannot provide `prompt_embeds` without also providing `prompt_attention_mask") @@ -793,13 +782,11 @@ class ChromaInpaintPipeline( generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, ip_adapter_image: Optional[PipelineImageInput] = None, ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, negative_ip_adapter_image: Optional[PipelineImageInput] = None, negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, - negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, joint_attention_kwargs: Optional[Dict[str, Any]] = None, diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index 94c4c39446..2ea7307fec 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -84,7 +84,6 @@ EXAMPLE_DOC_STRING = """ >>> from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL >>> from diffusers.utils import load_image - >>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda") >>> feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas") >>> controlnet = ControlNetModel.from_pretrained( diff --git a/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py b/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py index d259f7ee78..b41d9772a7 100644 --- a/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py +++ b/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py @@ -53,7 +53,6 @@ EXAMPLE_DOC_STRING = """ >>> from transformers import AutoTokenizer, LlamaForCausalLM >>> from diffusers import HiDreamImagePipeline - >>> tokenizer_4 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct") >>> text_encoder_4 = LlamaForCausalLM.from_pretrained( ... 
"meta-llama/Meta-Llama-3.1-8B-Instruct", diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py index df5b3f5c10..5a6b8d5e9f 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py @@ -85,7 +85,6 @@ EXAMPLE_DOC_STRING = """ >>> from diffusers import ControlNetModel, StableDiffusionXLControlNetPAGImg2ImgPipeline, AutoencoderKL >>> from diffusers.utils import load_image - >>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda") >>> feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas") >>> controlnet = ControlNetModel.from_pretrained( diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py index 66d5ffa6b8..a1d0407caf 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py @@ -459,7 +459,6 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix >>> from diffusers import StableDiffusionLatentUpscalePipeline, StableDiffusionPipeline >>> import torch - >>> pipeline = StableDiffusionPipeline.from_pretrained( ... "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16 ... ) diff --git a/src/diffusers/schedulers/scheduling_consistency_decoder.py b/src/diffusers/schedulers/scheduling_consistency_decoder.py index f4bd0cc2d7..23c0e138c4 100644 --- a/src/diffusers/schedulers/scheduling_consistency_decoder.py +++ b/src/diffusers/schedulers/scheduling_consistency_decoder.py @@ -14,7 +14,7 @@ from .scheduling_utils import SchedulerMixin def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -28,8 +28,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 74ade1d8bb..92c3e20013 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -51,7 +51,7 @@ class DDIMSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -65,8 +65,8 @@ def betas_for_alpha_bar( The number of betas to produce. 
max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: diff --git a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py index 92f7a5ab3a..1a77a65278 100644 --- a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py +++ b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py @@ -51,7 +51,7 @@ class DDIMSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -65,8 +65,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: @@ -100,14 +100,13 @@ def betas_for_alpha_bar( return torch.tensor(betas, dtype=torch.float32) -def rescale_zero_terminal_snr(alphas_cumprod): +def rescale_zero_terminal_snr(alphas_cumprod: torch.Tensor) -> torch.Tensor: """ - Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) - + Rescales betas to have zero terminal SNR based on [Algorithm 1](https://huggingface.co/papers/2305.08891) Args: - betas (`torch.Tensor`): - the betas that the scheduler is being initialized with. + alphas_cumprod (`torch.Tensor`): + The cumulative products of the alphas that the scheduler is being initialized with. Returns: `torch.Tensor`: rescaled betas with zero terminal SNR @@ -142,11 +141,11 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin): Args: num_train_timesteps (`int`, defaults to 1000): The number of diffusion steps to train the model. - beta_start (`float`, defaults to 0.0001): + beta_start (`float`, defaults to 0.00085): The starting `beta` value of inference. - beta_end (`float`, defaults to 0.02): + beta_end (`float`, defaults to 0.0120): The final `beta` value. - beta_schedule (`str`, defaults to `"linear"`): + beta_schedule (`str`, defaults to `"scaled_linear"`): The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`np.ndarray`, *optional*): @@ -179,6 +178,8 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin): Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and dark samples instead of limiting it to samples with medium brightness. Loosely related to [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
+ snr_shift_scale (`float`, defaults to 3.0): + Shift scale for SNR. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -190,15 +191,15 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin): num_train_timesteps: int = 1000, beta_start: float = 0.00085, beta_end: float = 0.0120, - beta_schedule: str = "scaled_linear", + beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "scaled_linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, clip_sample: bool = True, set_alpha_to_one: bool = True, steps_offset: int = 0, - prediction_type: str = "epsilon", + prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon", clip_sample_range: float = 1.0, sample_max_value: float = 1.0, - timestep_spacing: str = "leading", + timestep_spacing: Literal["linspace", "leading", "trailing"] = "leading", rescale_betas_zero_snr: bool = False, snr_shift_scale: float = 3.0, ): @@ -208,7 +209,15 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin): self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. - self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float64) ** 2 + self.betas = ( + torch.linspace( + beta_start**0.5, + beta_end**0.5, + num_train_timesteps, + dtype=torch.float64, + ) + ** 2 + ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) @@ -238,7 +247,7 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin): self.num_inference_steps = None self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) - def _get_variance(self, timestep, prev_timestep): + def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor: alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t @@ -265,7 +274,11 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin): """ return sample - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + def set_timesteps( + self, + num_inference_steps: int, + device: Optional[Union[str, torch.device]] = None, + ) -> None: """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -317,7 +330,7 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin): sample: torch.Tensor, eta: float = 0.0, use_clipped_model_output: bool = False, - generator=None, + generator: Optional[torch.Generator] = None, variance_noise: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[DDIMSchedulerOutput, Tuple]: @@ -328,7 +341,7 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin): Args: model_output (`torch.Tensor`): The direct output from learned diffusion model. - timestep (`float`): + timestep (`int`): The current discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. 
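The `scaled_linear` branch reformatted in the hunk above only changes layout, not values. As a point of reference, this is a minimal standalone sketch of what that expression computes, using the defaults shown for `CogVideoXDDIMScheduler`; the variable names here are illustrative and not part of the patch.

```python
import torch

# Defaults shown in the CogVideoXDDIMScheduler hunk above.
beta_start, beta_end, num_train_timesteps = 0.00085, 0.0120, 1000

# "scaled_linear": interpolate linearly in sqrt(beta) space, then square.
# The reformatted code above computes exactly this, just wrapped across lines.
betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float64) ** 2
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

print(float(betas[0]), float(betas[-1]))  # endpoints recover beta_start and beta_end
print(float(alphas_cumprod[-1]))          # terminal alpha_bar before any zero-SNR rescaling
```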
@@ -487,5 +500,5 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin): velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample return velocity - def __len__(self): + def __len__(self) -> int: return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py index 802d8f7977..476f741bcd 100644 --- a/src/diffusers/schedulers/scheduling_ddim_flax.py +++ b/src/diffusers/schedulers/scheduling_ddim_flax.py @@ -22,6 +22,7 @@ import flax import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import logging from .scheduling_utils_flax import ( CommonSchedulerState, FlaxKarrasDiffusionSchedulers, @@ -32,6 +33,9 @@ from .scheduling_utils_flax import ( ) +logger = logging.get_logger(__name__) + + @flax.struct.dataclass class DDIMSchedulerState: common: CommonSchedulerState @@ -125,6 +129,10 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): prediction_type: str = "epsilon", dtype: jnp.dtype = jnp.float32, ): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) self.dtype = dtype def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSchedulerState: @@ -152,7 +160,10 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): ) def scale_model_input( - self, state: DDIMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None + self, + state: DDIMSchedulerState, + sample: jnp.ndarray, + timestep: Optional[int] = None, ) -> jnp.ndarray: """ Args: @@ -190,7 +201,9 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): def _get_variance(self, state: DDIMSchedulerState, timestep, prev_timestep): alpha_prod_t = state.common.alphas_cumprod[timestep] alpha_prod_t_prev = jnp.where( - prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod + prev_timestep >= 0, + state.common.alphas_cumprod[prev_timestep], + state.final_alpha_cumprod, ) beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev diff --git a/src/diffusers/schedulers/scheduling_ddim_inverse.py b/src/diffusers/schedulers/scheduling_ddim_inverse.py index e76ad9aa6c..a3c9ed1f62 100644 --- a/src/diffusers/schedulers/scheduling_ddim_inverse.py +++ b/src/diffusers/schedulers/scheduling_ddim_inverse.py @@ -49,7 +49,7 @@ class DDIMSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -63,8 +63,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. 
Returns: `torch.Tensor`: @@ -99,7 +99,7 @@ def betas_for_alpha_bar( # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr -def rescale_zero_terminal_snr(betas): +def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor: """ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) @@ -187,14 +187,14 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, - beta_schedule: str = "linear", + beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, clip_sample: bool = True, set_alpha_to_one: bool = True, steps_offset: int = 0, - prediction_type: str = "epsilon", + prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon", clip_sample_range: float = 1.0, - timestep_spacing: str = "leading", + timestep_spacing: Literal["leading", "trailing"] = "leading", rescale_betas_zero_snr: bool = False, **kwargs, ): @@ -210,7 +210,15 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. - self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + self.betas = ( + torch.linspace( + beta_start**0.5, + beta_end**0.5, + num_train_timesteps, + dtype=torch.float32, + ) + ** 2 + ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) @@ -256,7 +264,11 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): """ return sample - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + def set_timesteps( + self, + num_inference_steps: int, + device: Optional[Union[str, torch.device]] = None, + ) -> None: """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -308,20 +320,10 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): Args: model_output (`torch.Tensor`): The direct output from learned diffusion model. - timestep (`float`): + timestep (`int`): The current discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. - eta (`float`): - The weight of noise for added noise in diffusion step. - use_clipped_model_output (`bool`, defaults to `False`): - If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary - because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no - clipping has happened, "corrected" `model_output` would coincide with the one provided as input and - `use_clipped_model_output` has no effect. - variance_noise (`torch.Tensor`): - Alternative to generating noise with `generator` by directly providing the noise for the variance - itself. Useful for methods such as [`CycleDiffusion`]. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~schedulers.scheduling_ddim_inverse.DDIMInverseSchedulerOutput`] or `tuple`. @@ -335,7 +337,8 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): # 1. 
get previous step value (=t+1) prev_timestep = timestep timestep = min( - timestep - self.config.num_train_timesteps // self.num_inference_steps, self.config.num_train_timesteps - 1 + timestep - self.config.num_train_timesteps // self.num_inference_steps, + self.config.num_train_timesteps - 1, ) # 2. compute alphas, betas @@ -378,5 +381,5 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): return (prev_sample, pred_original_sample) return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) - def __len__(self): + def __len__(self) -> int: return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_ddim_parallel.py b/src/diffusers/schedulers/scheduling_ddim_parallel.py index 09f55ee4c2..76f0636fbf 100644 --- a/src/diffusers/schedulers/scheduling_ddim_parallel.py +++ b/src/diffusers/schedulers/scheduling_ddim_parallel.py @@ -51,7 +51,7 @@ class DDIMParallelSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -65,8 +65,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: @@ -101,7 +101,7 @@ def betas_for_alpha_bar( # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr -def rescale_zero_terminal_snr(betas): +def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor: """ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) @@ -266,7 +266,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin): """ return sample - def _get_variance(self, timestep, prev_timestep=None): + def _get_variance(self, timestep: int, prev_timestep: Optional[int] = None) -> torch.Tensor: if prev_timestep is None: prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps @@ -279,7 +279,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin): return variance - def _batch_get_variance(self, t, prev_t): + def _batch_get_variance(self, t: torch.Tensor, prev_t: torch.Tensor) -> torch.Tensor: alpha_prod_t = self.alphas_cumprod[t] alpha_prod_t_prev = self.alphas_cumprod[torch.clip(prev_t, min=0)] alpha_prod_t_prev[prev_t < 0] = torch.tensor(1.0) @@ -335,7 +335,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin): return sample # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.set_timesteps - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None) -> None: """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). 
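Several files in this patch widen `alpha_transform_type` from `Literal["cosine", "exp"]` to also allow `"laplace"`. The hunks only touch the annotation and docstring, so whether each function body actually implements a `laplace` branch should be verified against the file itself; a defensive sanity-check sketch (import path taken from one of the touched modules) could look like this:

```python
from diffusers.schedulers.scheduling_ddpm import betas_for_alpha_bar

# Assumption: the "laplace" branch may or may not exist in a given function body;
# this loop reports which transform types the current implementation accepts.
for transform in ("cosine", "exp", "laplace"):
    try:
        betas = betas_for_alpha_bar(1000, max_beta=0.999, alpha_transform_type=transform)
        print(transform, tuple(betas.shape), float(betas.min()), float(betas.max()))
    except ValueError as err:  # raised for unsupported alpha_transform_type values
        print(transform, "->", err)
```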
@@ -392,7 +392,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin): sample: torch.Tensor, eta: float = 0.0, use_clipped_model_output: bool = False, - generator=None, + generator: Optional[torch.Generator] = None, variance_noise: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[DDIMParallelSchedulerOutput, Tuple]: @@ -406,11 +406,13 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin): sample (`torch.Tensor`): current instance of sample being created by diffusion process. eta (`float`): weight of noise for added noise in diffusion step. - use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped - predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when - `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would - coincide with the one provided as input and `use_clipped_model_output` will have not effect. - generator: random number generator. + use_clipped_model_output (`bool`, defaults to `False`): + If `True`, compute "corrected" `model_output` from the clipped predicted original sample. This + correction is necessary because the predicted original sample is clipped to [-1, 1] when + `self.config.clip_sample` is `True`. If no clipping occurred, the "corrected" `model_output` matches + the input and `use_clipped_model_output` has no effect. + generator (`torch.Generator`, *optional*): + Random number generator. variance_noise (`torch.Tensor`): instead of generating noise for the variance using `generator`, we can directly provide the noise for the variance itself. This is useful for methods such as CycleDiffusion. (https://huggingface.co/papers/2210.05559) @@ -496,7 +498,10 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin): if variance_noise is None: variance_noise = randn_tensor( - model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + model_output.shape, + generator=generator, + device=model_output.device, + dtype=model_output.dtype, ) variance = std_dev_t * variance_noise @@ -513,7 +518,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin): def batch_step_no_noise( self, model_output: torch.Tensor, - timesteps: List[int], + timesteps: torch.Tensor, sample: torch.Tensor, eta: float = 0.0, use_clipped_model_output: bool = False, @@ -528,7 +533,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin): Args: model_output (`torch.Tensor`): direct output from learned diffusion model. - timesteps (`List[int]`): + timesteps (`torch.Tensor`): current discrete timesteps in the diffusion chain. This is now a list of integers. sample (`torch.Tensor`): current instance of sample being created by diffusion process. 
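Several `step` signatures in this patch now annotate `generator` as `Optional[torch.Generator]`. For context, this is how a caller typically threads that argument through a denoising loop so the injected variance noise is reproducible; the model output below is a stand-in tensor, not a real denoiser prediction.

```python
import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(50)

generator = torch.Generator().manual_seed(0)      # drives the randn_tensor(...) call inside step()
sample = torch.randn(1, 4, 64, 64, generator=generator)

for t in scheduler.timesteps:
    model_output = torch.randn_like(sample)       # stand-in for a real denoiser prediction
    sample = scheduler.step(model_output, t, sample, generator=generator).prev_sample
```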
@@ -696,5 +701,5 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin): velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample return velocity - def __len__(self): + def __len__(self) -> int: return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index d0596bb918..2e2816bbf3 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -48,7 +48,7 @@ class DDPMSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -62,8 +62,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: @@ -192,7 +192,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2", "sigmoid"] = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, variance_type: Literal[ - "fixed_small", "fixed_small_log", "fixed_large", "fixed_large_log", "learned", "learned_range" + "fixed_small", + "fixed_small_log", + "fixed_large", + "fixed_large_log", + "learned", + "learned_range", ] = "fixed_small", clip_sample: bool = True, prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon", @@ -210,7 +215,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. - self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + self.betas = ( + torch.linspace( + beta_start**0.5, + beta_end**0.5, + num_train_timesteps, + dtype=torch.float32, + ) + ** 2 + ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) @@ -268,7 +281,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): Sets the discrete timesteps used for the diffusion chain (to be run before inference). Args: - num_inference_steps (`int`): + num_inference_steps (`int`, *optional*): The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` must be `None`. 
device (`str` or `torch.device`, *optional*): @@ -337,7 +350,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): t: int, predicted_variance: Optional[torch.Tensor] = None, variance_type: Optional[ - Literal["fixed_small", "fixed_small_log", "fixed_large", "fixed_large_log", "learned", "learned_range"] + Literal[ + "fixed_small", + "fixed_small_log", + "fixed_large", + "fixed_large_log", + "learned", + "learned_range", + ] ] = None, ) -> torch.Tensor: """ @@ -472,7 +492,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): prev_t = self.previous_timestep(t) - if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [ + "learned", + "learned_range", + ]: model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) else: predicted_variance = None @@ -521,7 +544,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): if t > 0: device = model_output.device variance_noise = randn_tensor( - model_output.shape, generator=generator, device=device, dtype=model_output.dtype + model_output.shape, + generator=generator, + device=device, + dtype=model_output.dtype, ) if self.variance_type == "fixed_small_log": variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise @@ -620,7 +646,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): def __len__(self) -> int: return self.config.num_train_timesteps - def previous_timestep(self, timestep: int) -> int: + def previous_timestep(self, timestep: int) -> Union[int, torch.Tensor]: """ Compute the previous timestep in the diffusion chain. @@ -629,7 +655,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): The current timestep. Returns: - `int`: + `int` or `torch.Tensor`: The previous timestep. """ if self.custom_timesteps or self.num_inference_steps: diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py index a3264f54f5..e02b7ea0c0 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_flax.py +++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py @@ -22,6 +22,7 @@ import jax import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import logging from .scheduling_utils_flax import ( CommonSchedulerState, FlaxKarrasDiffusionSchedulers, @@ -32,6 +33,9 @@ from .scheduling_utils_flax import ( ) +logger = logging.get_logger(__name__) + + @flax.struct.dataclass class DDPMSchedulerState: common: CommonSchedulerState @@ -42,7 +46,12 @@ class DDPMSchedulerState: num_inference_steps: Optional[int] = None @classmethod - def create(cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray): + def create( + cls, + common: CommonSchedulerState, + init_noise_sigma: jnp.ndarray, + timesteps: jnp.ndarray, + ): return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps) @@ -105,6 +114,10 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): prediction_type: str = "epsilon", dtype: jnp.dtype = jnp.float32, ): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." 
+ ) self.dtype = dtype def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDPMSchedulerState: @@ -123,7 +136,10 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): ) def scale_model_input( - self, state: DDPMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None + self, + state: DDPMSchedulerState, + sample: jnp.ndarray, + timestep: Optional[int] = None, ) -> jnp.ndarray: """ Args: diff --git a/src/diffusers/schedulers/scheduling_ddpm_parallel.py b/src/diffusers/schedulers/scheduling_ddpm_parallel.py index ee7ab66be4..b02c5376f2 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_parallel.py +++ b/src/diffusers/schedulers/scheduling_ddpm_parallel.py @@ -50,7 +50,7 @@ class DDPMParallelSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -64,8 +64,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: @@ -149,38 +149,41 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): For more details, see the original paper: https://huggingface.co/papers/2006.11239 Args: - num_train_timesteps (`int`): number of diffusion steps used to train the model. - beta_start (`float`): the starting `beta` value of inference. - beta_end (`float`): the final `beta` value. - beta_schedule (`str`): - the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, `squaredcos_cap_v2` or `sigmoid`. - trained_betas (`np.ndarray`, optional): - option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. - variance_type (`str`): - options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, + trained_betas (`np.ndarray`, *optional*): + Option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + variance_type (`str`, defaults to `"fixed_small"`): + Options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. - clip_sample (`bool`, default `True`): - option to clip predicted sample for numerical stability. - clip_sample_range (`float`, default `1.0`): - the maximum magnitude for sample clipping. Valid only when `clip_sample=True`. 
- prediction_type (`str`, default `epsilon`, optional): - prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + clip_sample (`bool`, defaults to `True`): + Option to clip predicted sample for numerical stability. + prediction_type (`str`, defaults to `"epsilon"`): + Prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://huggingface.co/papers/2210.02303) - thresholding (`bool`, default `False`): - whether to use the "dynamic thresholding" method (introduced by Imagen, + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method (introduced by Imagen, https://huggingface.co/papers/2205.11487). Note that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). - dynamic_thresholding_ratio (`float`, default `0.995`): - the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen (https://huggingface.co/papers/2205.11487). Valid only when `thresholding=True`. - sample_max_value (`float`, default `1.0`): - the threshold value for dynamic thresholding. Valid only when `thresholding=True`. - timestep_spacing (`str`, default `"leading"`): + clip_sample_range (`float`, defaults to 1.0): + The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, defaults to `"leading"`): The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. - steps_offset (`int`, default `0`): + steps_offset (`int`, defaults to 0): An offset added to the inference steps, as required by some model families. rescale_betas_zero_snr (`bool`, defaults to `False`): Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and @@ -202,7 +205,12 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2", "sigmoid"] = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, variance_type: Literal[ - "fixed_small", "fixed_small_log", "fixed_large", "fixed_large_log", "learned", "learned_range" + "fixed_small", + "fixed_small_log", + "fixed_large", + "fixed_large_log", + "learned", + "learned_range", ] = "fixed_small", clip_sample: bool = True, prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon", @@ -220,7 +228,15 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. 
- self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + self.betas = ( + torch.linspace( + beta_start**0.5, + beta_end**0.5, + num_train_timesteps, + dtype=torch.float32, + ) + ** 2 + ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) @@ -280,7 +296,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): Sets the discrete timesteps used for the diffusion chain (to be run before inference). Args: - num_inference_steps (`int`): + num_inference_steps (`int`, *optional*): The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` must be `None`. device (`str` or `torch.device`, *optional*): @@ -350,7 +366,14 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): t: int, predicted_variance: Optional[torch.Tensor] = None, variance_type: Optional[ - Literal["fixed_small", "fixed_small_log", "fixed_large", "fixed_large_log", "learned", "learned_range"] + Literal[ + "fixed_small", + "fixed_small_log", + "fixed_large", + "fixed_large_log", + "learned", + "learned_range", + ] ] = None, ) -> torch.Tensor: """ @@ -458,7 +481,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): model_output: torch.Tensor, timestep: int, sample: torch.Tensor, - generator=None, + generator: Optional[torch.Generator] = None, return_dict: bool = True, ) -> Union[DDPMParallelSchedulerOutput, Tuple]: """ @@ -470,7 +493,8 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): timestep (`int`): current discrete timestep in the diffusion chain. sample (`torch.Tensor`): current instance of sample being created by diffusion process. - generator: random number generator. + generator (`torch.Generator`, *optional*): + Random number generator. return_dict (`bool`): option for returning tuple rather than DDPMParallelSchedulerOutput class Returns: @@ -483,7 +507,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): prev_t = self.previous_timestep(t) - if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [ + "learned", + "learned_range", + ]: model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) else: predicted_variance = None @@ -532,7 +559,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): if t > 0: device = model_output.device variance_noise = randn_tensor( - model_output.shape, generator=generator, device=device, dtype=model_output.dtype + model_output.shape, + generator=generator, + device=device, + dtype=model_output.dtype, ) if self.variance_type == "fixed_small_log": variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise @@ -555,7 +585,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): def batch_step_no_noise( self, model_output: torch.Tensor, - timesteps: List[int], + timesteps: torch.Tensor, sample: torch.Tensor, ) -> torch.Tensor: """ @@ -568,8 +598,8 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): Args: model_output (`torch.Tensor`): direct output from learned diffusion model. - timesteps (`List[int]`): - current discrete timesteps in the diffusion chain. This is now a list of integers. + timesteps (`torch.Tensor`): + Current discrete timesteps in the diffusion chain. This is a tensor of integers. sample (`torch.Tensor`): current instance of sample being created by diffusion process. 
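The `variance_type in ["learned", "learned_range"]` blocks reformatted above all perform the same split of the model output into a mean prediction and a predicted variance. A small illustration of that split, with made-up shapes that are not taken from the patch:

```python
import torch

sample = torch.randn(2, 4, 32, 32)        # latent batch with 4 channels
model_output = torch.randn(2, 8, 32, 32)  # denoiser predicts noise plus per-channel variance

# Mirrors the branch in the hunks above for "learned"/"learned_range" variance types.
if model_output.shape[1] == sample.shape[1] * 2:
    model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)

print(model_output.shape, predicted_variance.shape)  # both torch.Size([2, 4, 32, 32])
```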
@@ -583,7 +613,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): t = t.view(-1, *([1] * (model_output.ndim - 1))) prev_t = prev_t.view(-1, *([1] * (model_output.ndim - 1))) - if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [ + "learned", + "learned_range", + ]: model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) else: pass @@ -714,7 +747,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): return self.config.num_train_timesteps # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep - def previous_timestep(self, timestep): + def previous_timestep(self, timestep: int) -> Union[int, torch.Tensor]: """ Compute the previous timestep in the diffusion chain. @@ -723,7 +756,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): The current timestep. Returns: - `int`: + `int` or `torch.Tensor`: The previous timestep. """ if self.custom_timesteps or self.num_inference_steps: diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index ebc3a33b27..7c2dfd8e50 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -34,7 +34,7 @@ if is_scipy_available(): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -48,8 +48,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: diff --git a/src/diffusers/schedulers/scheduling_dpm_cogvideox.py b/src/diffusers/schedulers/scheduling_dpm_cogvideox.py index 66fb39c0bc..3e50ebbfe0 100644 --- a/src/diffusers/schedulers/scheduling_dpm_cogvideox.py +++ b/src/diffusers/schedulers/scheduling_dpm_cogvideox.py @@ -52,7 +52,7 @@ class DDIMSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -66,8 +66,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. 
Returns: `torch.Tensor`: diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 990129f584..07cb64f32b 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -34,7 +34,7 @@ if is_scipy_available(): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -48,8 +48,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index a9c4fe57b6..2da90d287c 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -34,7 +34,7 @@ if is_scipy_available(): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -48,8 +48,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py index 5f9ce1393d..6f905a623d 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py @@ -117,7 +117,7 @@ class BrownianTreeNoiseSampler: def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -131,8 +131,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. 
+ alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index e92f880e5b..e9bf815aba 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -36,7 +36,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -50,8 +50,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index 0258ea7777..11fec60c9c 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -51,7 +51,7 @@ class EulerAncestralDiscreteSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -65,8 +65,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 4238c976e4..8b141325fb 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -54,7 +54,7 @@ class EulerDiscreteSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -68,8 +68,8 @@ def betas_for_alpha_bar( The number of betas to produce. 
max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py index 011f97ba5c..0c5e28ad06 100644 --- a/src/diffusers/schedulers/scheduling_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -51,7 +51,7 @@ class HeunDiscreteSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -65,8 +65,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py index 37849e28b2..ee49ae67b9 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py @@ -52,7 +52,7 @@ class KDPM2AncestralDiscreteSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -66,8 +66,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. 
Returns: `torch.Tensor`: diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py index 1c2791837c..6effb3699b 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py @@ -51,7 +51,7 @@ class KDPM2DiscreteSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -65,8 +65,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: diff --git a/src/diffusers/schedulers/scheduling_lcm.py b/src/diffusers/schedulers/scheduling_lcm.py index 66dedd5a6e..ada8806e8c 100644 --- a/src/diffusers/schedulers/scheduling_lcm.py +++ b/src/diffusers/schedulers/scheduling_lcm.py @@ -53,7 +53,7 @@ class LCMSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -67,8 +67,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: @@ -722,7 +722,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin): The current timestep. Returns: - `int`: + `int` or `torch.Tensor`: The previous timestep. """ if self.custom_timesteps or self.num_inference_steps: diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 9fc9b1e64b..a1f9d27fd9 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -49,7 +49,7 @@ class LMSDiscreteSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -63,8 +63,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. 
- alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index e95a374457..0820f5baa8 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -28,7 +28,7 @@ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, Schedul def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -42,8 +42,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: diff --git a/src/diffusers/schedulers/scheduling_repaint.py b/src/diffusers/schedulers/scheduling_repaint.py index fcebe7e21c..bec4a1bdf6 100644 --- a/src/diffusers/schedulers/scheduling_repaint.py +++ b/src/diffusers/schedulers/scheduling_repaint.py @@ -47,7 +47,7 @@ class RePaintSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -61,8 +61,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: diff --git a/src/diffusers/schedulers/scheduling_sasolver.py b/src/diffusers/schedulers/scheduling_sasolver.py index 7c679a255c..565fae1c0d 100644 --- a/src/diffusers/schedulers/scheduling_sasolver.py +++ b/src/diffusers/schedulers/scheduling_sasolver.py @@ -35,7 +35,7 @@ if is_scipy_available(): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -49,8 +49,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. 
- alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: diff --git a/src/diffusers/schedulers/scheduling_tcd.py b/src/diffusers/schedulers/scheduling_tcd.py index 7a385f6291..a1303436cd 100644 --- a/src/diffusers/schedulers/scheduling_tcd.py +++ b/src/diffusers/schedulers/scheduling_tcd.py @@ -52,7 +52,7 @@ class TCDSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -66,8 +66,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: @@ -777,7 +777,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin): The current timestep. Returns: - `int`: + `int` or `torch.Tensor`: The previous timestep. """ if self.custom_timesteps or self.num_inference_steps: diff --git a/src/diffusers/schedulers/scheduling_unclip.py b/src/diffusers/schedulers/scheduling_unclip.py index bdc4feb0b1..14b09277da 100644 --- a/src/diffusers/schedulers/scheduling_unclip.py +++ b/src/diffusers/schedulers/scheduling_unclip.py @@ -48,7 +48,7 @@ class UnCLIPSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -62,8 +62,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. 
Returns: `torch.Tensor`: diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 0536e8d1ed..d8e24d1964 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -34,7 +34,7 @@ if is_scipy_available(): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -48,8 +48,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: @@ -226,6 +226,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): time_shift_type: Literal["exponential"] = "exponential", sigma_min: Optional[float] = None, sigma_max: Optional[float] = None, + shift_terminal: Optional[float] = None, ) -> None: if self.config.use_beta_sigmas and not is_scipy_available(): raise ImportError("Make sure to install scipy if you want to use beta sigmas.") @@ -245,6 +246,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): self.betas = betas_for_alpha_bar(num_train_timesteps) else: raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + if shift_terminal is not None and not use_flow_sigmas: + raise ValueError("`shift_terminal` is only supported when `use_flow_sigmas=True`.") if rescale_betas_zero_snr: self.betas = rescale_zero_terminal_snr(self.betas) @@ -313,8 +316,12 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): self._begin_index = begin_index def set_timesteps( - self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None, mu: Optional[float] = None - ) -> None: + self, + num_inference_steps: Optional[int] = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[float] = None, + ): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -323,13 +330,24 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): The number of diffusion steps used when generating samples with a pre-trained model. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + sigmas (`List[float]`, *optional*): + Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed + automatically. mu (`float`, *optional*): Optional mu parameter for dynamic shifting when using exponential time shift type. """ + if self.config.use_dynamic_shifting and mu is None: + raise ValueError("`mu` must be passed when `use_dynamic_shifting` is set to be `True`") + + if sigmas is not None: + if not self.config.use_flow_sigmas: + raise ValueError( + "Passing `sigmas` is only supported when `use_flow_sigmas=True`. " + "Please set `use_flow_sigmas=True` during scheduler initialization." 
+ ) + num_inference_steps = len(sigmas) + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891 - if mu is not None: - assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential" - self.config.flow_shift = np.exp(mu) if self.config.timestep_spacing == "linspace": timesteps = ( np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) @@ -354,8 +372,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." ) - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) if self.config.use_karras_sigmas: + if sigmas is None: + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) log_sigmas = np.log(sigmas) sigmas = np.flip(sigmas).copy() sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) @@ -375,6 +394,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ) sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) elif self.config.use_exponential_sigmas: + if sigmas is None: + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) log_sigmas = np.log(sigmas) sigmas = np.flip(sigmas).copy() sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps) @@ -389,6 +410,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ) sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) elif self.config.use_beta_sigmas: + if sigmas is None: + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) log_sigmas = np.log(sigmas) sigmas = np.flip(sigmas).copy() sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) @@ -403,9 +426,18 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ) sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) elif self.config.use_flow_sigmas: - alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1) - sigmas = 1.0 - alphas - sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy() + if sigmas is None: + sigmas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)[:-1] + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) + else: + sigmas = self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas) + if self.config.shift_terminal: + sigmas = self.stretch_shift_to_terminal(sigmas) + eps = 1e-6 + if np.fabs(sigmas[0] - 1) < eps: + # to avoid inf torch.log(alpha_si) in multistep_uni_p_bh_update during first/second update + sigmas[0] -= eps timesteps = (sigmas * self.config.num_train_timesteps).copy() if self.config.final_sigmas_type == "sigma_min": sigma_last = sigmas[-1] @@ -417,6 +449,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ) sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) else: + if sigmas is None: + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) if self.config.final_sigmas_type == "sigma_min": sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 @@ -446,6 +480,43 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): self._begin_index = None self.sigmas = 
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.time_shift + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + if self.config.time_shift_type == "exponential": + return self._time_shift_exponential(mu, sigma, t) + elif self.config.time_shift_type == "linear": + return self._time_shift_linear(mu, sigma, t) + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.stretch_shift_to_terminal + def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor: + r""" + Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config + value. + + Reference: + https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51 + + Args: + t (`torch.Tensor`): + A tensor of timesteps to be stretched and shifted. + + Returns: + `torch.Tensor`: + A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`. + """ + one_minus_z = 1 - t + scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal) + stretched_t = 1 - (one_minus_z / scale_factor) + return stretched_t + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_exponential + def _time_shift_exponential(self, mu, sigma, t): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_linear + def _time_shift_linear(self, mu, sigma, t): + return mu / (mu + (1 / t - 1) ** sigma) + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: """ diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index d48981f380..63f381419f 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -587,6 +587,21 @@ class AuraFlowPipeline(metaclass=DummyObject): requires_backends(cls, ["torch", "transformers"]) +class BriaFiboEditPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class BriaFiboPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/bria_fibo_edit/__init__.py b/tests/pipelines/bria_fibo_edit/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/pipelines/bria_fibo_edit/test_pipeline_bria_fibo_edit.py b/tests/pipelines/bria_fibo_edit/test_pipeline_bria_fibo_edit.py new file mode 100644 index 0000000000..5376c4b5e0 --- /dev/null +++ b/tests/pipelines/bria_fibo_edit/test_pipeline_bria_fibo_edit.py @@ -0,0 +1,192 @@ +# Copyright 2024 Bria AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import AutoTokenizer +from transformers.models.smollm3.modeling_smollm3 import SmolLM3Config, SmolLM3ForCausalLM + +from diffusers import ( + AutoencoderKLWan, + BriaFiboEditPipeline, + FlowMatchEulerDiscreteScheduler, +) +from diffusers.models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel +from tests.pipelines.test_pipelines_common import PipelineTesterMixin + +from ...testing_utils import ( + enable_full_determinism, + torch_device, +) + + +enable_full_determinism() + + +class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = BriaFiboEditPipeline + params = frozenset(["prompt", "height", "width", "guidance_scale"]) + batch_params = frozenset(["prompt"]) + test_xformers_attention = False + test_layerwise_casting = False + test_group_offloading = False + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = BriaFiboTransformer2DModel( + patch_size=1, + in_channels=16, + num_layers=1, + num_single_layers=1, + attention_head_dim=8, + num_attention_heads=2, + joint_attention_dim=64, + text_encoder_dim=32, + pooled_projection_dim=None, + axes_dims_rope=[0, 4, 4], + ) + + vae = AutoencoderKLWan( + base_dim=80, + decoder_base_dim=128, + dim_mult=[1, 2, 4, 4], + dropout=0.0, + in_channels=12, + latents_mean=[0.0] * 16, + latents_std=[1.0] * 16, + is_residual=True, + num_res_blocks=2, + out_channels=12, + patch_size=2, + scale_factor_spatial=16, + scale_factor_temporal=4, + temperal_downsample=[False, True, True], + z_dim=16, + ) + scheduler = FlowMatchEulerDiscreteScheduler() + text_encoder = SmolLM3ForCausalLM(SmolLM3Config(hidden_size=32)) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + components = { + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "transformer": transformer, + "vae": vae, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + inputs = { + "prompt": '{"text": "A painting of a squirrel eating a burger","edit_instruction": "A painting of a squirrel eating a burger"}', + "negative_prompt": "bad, ugly", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "height": 192, + "width": 336, + "output_type": "np", + } + image = Image.new("RGB", (336, 192), (255, 255, 255)) + inputs["image"] = image + return inputs + + @unittest.skip(reason="will not be supported due to dim-fusion") + def test_encode_prompt_works_in_isolation(self): + pass + + @unittest.skip(reason="Batching is not supported yet") + def test_num_images_per_prompt(self): + pass + + @unittest.skip(reason="Batching is not supported yet") + def test_inference_batch_consistent(self): + pass + + @unittest.skip(reason="Batching is not supported yet") + def test_inference_batch_single_identical(self): + pass + + def test_bria_fibo_different_prompts(self): + pipe = 
self.pipeline_class(**self.get_dummy_components()) + pipe = pipe.to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + output_same_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + inputs["prompt"] = {"edit_instruction": "a different prompt"} + output_different_prompts = pipe(**inputs).images[0] + + max_diff = np.abs(output_same_prompt - output_different_prompts).max() + assert max_diff > 1e-6 + + def test_image_output_shape(self): + pipe = self.pipeline_class(**self.get_dummy_components()) + pipe = pipe.to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + height_width_pairs = [(32, 32), (64, 64), (32, 64)] + for height, width in height_width_pairs: + expected_height = height + expected_width = width + + inputs.update({"height": height, "width": width}) + image = pipe(**inputs).images[0] + output_height, output_width, _ = image.shape + assert (output_height, output_width) == (expected_height, expected_width) + + def test_bria_fibo_edit_mask(self): + pipe = self.pipeline_class(**self.get_dummy_components()) + pipe = pipe.to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + mask = Image.fromarray((np.ones((192, 336)) * 255).astype(np.uint8), mode="L") + + inputs.update({"mask": mask}) + output = pipe(**inputs).images[0] + + assert output.shape == (192, 336, 3) + + def test_bria_fibo_edit_mask_image_size_mismatch(self): + pipe = self.pipeline_class(**self.get_dummy_components()) + pipe = pipe.to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + mask = Image.fromarray((np.ones((64, 64)) * 255).astype(np.uint8), mode="L") + + inputs.update({"mask": mask}) + with self.assertRaisesRegex(ValueError, "Mask and image must have the same size"): + pipe(**inputs) + + def test_bria_fibo_edit_mask_no_image(self): + pipe = self.pipeline_class(**self.get_dummy_components()) + pipe = pipe.to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + mask = Image.fromarray((np.ones((32, 32)) * 255).astype(np.uint8), mode="L") + + # Remove image from inputs if it's there (it shouldn't be by default from get_dummy_inputs) + inputs.pop("image", None) + inputs.update({"mask": mask}) + + with self.assertRaisesRegex(ValueError, "If mask is provided, image must also be provided"): + pipe(**inputs) diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py b/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py index 8a693e9c2d..df1dd2d987 100644 --- a/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py +++ b/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py @@ -248,6 +248,9 @@ class KandinskyV22InpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCas def test_float16_inference(self): super().test_float16_inference(expected_max_diff=5e-1) + def test_save_load_dduf(self): + super().test_save_load_dduf(atol=1e-3, rtol=1e-3) + @is_flaky() def test_model_cpu_offload_forward_pass(self): super().test_inference_batch_single_identical(expected_max_diff=8e-4) diff --git a/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py b/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py index 503fdb242d..d3bfa4b308 100644 --- a/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py +++ b/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py @@ -191,6 +191,9 @@ class Kandinsky3Img2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase) def test_inference_batch_single_identical(self): super().test_inference_batch_single_identical(expected_max_diff=1e-2) + def test_save_load_dduf(self): + 
super().test_save_load_dduf(atol=1e-3, rtol=1e-3) + @slow @require_torch_accelerator
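The scheduler docstring updates above document a third `alpha_transform_type` value. A minimal sketch of selecting it, assuming the `laplace` branch is implemented in the function bodies as the updated `Literal` hints indicate (the module chosen for the import is just one of the files touched above):

```python
import torch
from diffusers.schedulers.scheduling_unipc_multistep import betas_for_alpha_bar

# Default cosine schedule vs. the newly documented laplace transform.
betas_cosine = betas_for_alpha_bar(1000, alpha_transform_type="cosine")
betas_laplace = betas_for_alpha_bar(1000, alpha_transform_type="laplace")

# Both return one beta per diffusion timestep.
assert betas_cosine.shape == betas_laplace.shape == torch.Size([1000])
```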
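The `UniPCMultistepScheduler` hunks add a `shift_terminal` config option (flow-sigma mode only) and extend `set_timesteps` to accept either a custom `sigmas` schedule or a `mu` value for dynamic shifting. A rough usage sketch under those constraints; the step counts, shift values, and the NumPy array passed as `sigmas` are illustrative assumptions, not values taken from the patch:

```python
import numpy as np
from diffusers import UniPCMultistepScheduler

# `shift_terminal` is only valid together with `use_flow_sigmas=True`;
# otherwise the new constructor check raises a ValueError.
scheduler = UniPCMultistepScheduler(use_flow_sigmas=True, flow_shift=3.0, shift_terminal=0.1)

# Let the scheduler derive the flow sigmas itself.
scheduler.set_timesteps(num_inference_steps=30)

# Or pass a custom sigma schedule; the number of steps is inferred from len(sigmas).
custom_sigmas = np.linspace(1.0, 1.0 / 1000, 30)
scheduler.set_timesteps(sigmas=custom_sigmas)

# With dynamic shifting enabled, `mu` must now be supplied explicitly.
dynamic = UniPCMultistepScheduler(use_flow_sigmas=True, use_dynamic_shifting=True)
dynamic.set_timesteps(num_inference_steps=30, mu=1.5)
```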
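The new fast tests exercise `BriaFiboEditPipeline` with a JSON-structured prompt, a source image, and an optional mask that must match the image size. A hedged end-to-end sketch of that call pattern; the checkpoint id, file names, dtype, and device are placeholders rather than values from this patch:

```python
import torch
from PIL import Image
from diffusers import BriaFiboEditPipeline

# Placeholder checkpoint id; substitute the gated Fibo Edit repository you have access to.
pipe = BriaFiboEditPipeline.from_pretrained("<fibo-edit-checkpoint>", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = Image.open("source.png").convert("RGB")
# The mask, if given, must have the same size as the image (see the size-mismatch test above).
mask = Image.open("mask.png").convert("L").resize(image.size)

# Prompts are structured JSON strings, mirroring get_dummy_inputs in the test file.
prompt = '{"edit_instruction": "replace the background with a sunset sky"}'

result = pipe(
    prompt=prompt,
    image=image,
    mask=mask,
    height=image.height,
    width=image.width,
    num_inference_steps=30,
    guidance_scale=5.0,
).images[0]
result.save("edited.png")
```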