mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Sana Sprint] add image-to-image pipeline (#11602)
* sana sprint img2img * fix import * fix name * fix image encoding * fix image encoding * fix image encoding * fix image encoding * fix image encoding * fix image encoding * try w/o strength * try scaling differently * try with strength * revert unnecessary changes to scheduler * revert unnecessary changes to scheduler * Apply style fixes * remove comment * add copy statements * add copy statements * add to doc * add to doc * add to doc * add to doc * Apply style fixes * empty commit * fix copies * fix copies * fix copies * fix copies * fix copies * docs * make fix-copies. * fix doc building error. * initial commit - add img2img test * initial commit - add img2img test * fix import * fix imports * Apply style fixes * empty commit * remove * empty commit * test vocab size * fix * fix prompt missing from last commits * small changes * fix image processing when input is tensor * fix order * Apply style fixes * empty commit * fix shape * remove comment * image processing * remove comment * skip vae tiling test for now * Apply style fixes * empty commit --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -88,12 +88,46 @@ image.save("sana.png")
|
||||
|
||||
Users can tweak the `max_timesteps` value for experimenting with the visual quality of the generated outputs. The default `max_timesteps` value was obtained with an inference-time search process. For more details about it, check out the paper.
|
||||
|
||||
## Image to Image
|
||||
|
||||
The [`SanaSprintImg2ImgPipeline`] is a pipeline for image-to-image generation. It takes an input image and a prompt, and generates a new image based on the input image and the prompt.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import SanaSprintImg2ImgPipeline
|
||||
from diffusers.utils.loading_utils import load_image
|
||||
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png"
|
||||
)
|
||||
|
||||
pipe = SanaSprintImg2ImgPipeline.from_pretrained(
|
||||
"Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
|
||||
torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
image = pipe(
|
||||
prompt="a cute pink bear",
|
||||
image=image,
|
||||
strength=0.5,
|
||||
height=832,
|
||||
width=480
|
||||
).images[0]
|
||||
image[0].save("output.png")
|
||||
```
|
||||
|
||||
## SanaSprintPipeline
|
||||
|
||||
[[autodoc]] SanaSprintPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## SanaSprintImg2ImgPipeline
|
||||
|
||||
[[autodoc]] SanaSprintImg2ImgPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
|
||||
## SanaPipelineOutput
|
||||
|
||||
|
||||
@@ -441,6 +441,7 @@ else:
|
||||
"SanaControlNetPipeline",
|
||||
"SanaPAGPipeline",
|
||||
"SanaPipeline",
|
||||
"SanaSprintImg2ImgPipeline",
|
||||
"SanaSprintPipeline",
|
||||
"SemanticStableDiffusionPipeline",
|
||||
"ShapEImg2ImgPipeline",
|
||||
@@ -1025,6 +1026,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
SanaControlNetPipeline,
|
||||
SanaPAGPipeline,
|
||||
SanaPipeline,
|
||||
SanaSprintImg2ImgPipeline,
|
||||
SanaSprintPipeline,
|
||||
SemanticStableDiffusionPipeline,
|
||||
ShapEImg2ImgPipeline,
|
||||
|
||||
@@ -290,7 +290,12 @@ else:
|
||||
_import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
|
||||
_import_structure["pia"] = ["PIAPipeline"]
|
||||
_import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"]
|
||||
_import_structure["sana"] = ["SanaPipeline", "SanaSprintPipeline", "SanaControlNetPipeline"]
|
||||
_import_structure["sana"] = [
|
||||
"SanaPipeline",
|
||||
"SanaSprintPipeline",
|
||||
"SanaControlNetPipeline",
|
||||
"SanaSprintImg2ImgPipeline",
|
||||
]
|
||||
_import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
|
||||
_import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
|
||||
_import_structure["stable_audio"] = [
|
||||
@@ -675,7 +680,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .paint_by_example import PaintByExamplePipeline
|
||||
from .pia import PIAPipeline
|
||||
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
|
||||
from .sana import SanaControlNetPipeline, SanaPipeline, SanaSprintPipeline
|
||||
from .sana import SanaControlNetPipeline, SanaPipeline, SanaSprintImg2ImgPipeline, SanaSprintPipeline
|
||||
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
|
||||
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
|
||||
from .stable_audio import StableAudioPipeline, StableAudioProjectionModel
|
||||
|
||||
@@ -25,6 +25,7 @@ else:
|
||||
_import_structure["pipeline_sana"] = ["SanaPipeline"]
|
||||
_import_structure["pipeline_sana_controlnet"] = ["SanaControlNetPipeline"]
|
||||
_import_structure["pipeline_sana_sprint"] = ["SanaSprintPipeline"]
|
||||
_import_structure["pipeline_sana_sprint_img2img"] = ["SanaSprintImg2ImgPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
@@ -37,6 +38,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pipeline_sana import SanaPipeline
|
||||
from .pipeline_sana_controlnet import SanaControlNetPipeline
|
||||
from .pipeline_sana_sprint import SanaSprintPipeline
|
||||
from .pipeline_sana_sprint_img2img import SanaSprintImg2ImgPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
975
src/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py
Normal file
975
src/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py
Normal file
@@ -0,0 +1,975 @@
|
||||
# Copyright 2024 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import html
|
||||
import inspect
|
||||
import re
|
||||
import urllib.parse as ul
|
||||
import warnings
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import PipelineImageInput, PixArtImageProcessor
|
||||
from ...loaders import SanaLoraLoaderMixin
|
||||
from ...models import AutoencoderDC, SanaTransformer2DModel
|
||||
from ...schedulers import DPMSolverMultistepScheduler
|
||||
from ...utils import (
|
||||
BACKENDS_MAPPING,
|
||||
USE_PEFT_BACKEND,
|
||||
is_bs4_available,
|
||||
is_ftfy_available,
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..pixart_alpha.pipeline_pixart_alpha import ASPECT_RATIO_1024_BIN
|
||||
from .pipeline_output import SanaPipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
if is_bs4_available():
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
if is_ftfy_available():
|
||||
import ftfy
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import SanaSprintImg2ImgPipeline
|
||||
>>> from diffusers.utils.loading_utils import load_image
|
||||
|
||||
>>> pipe = SanaSprintImg2ImgPipeline.from_pretrained(
|
||||
... "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> image = load_image(
|
||||
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png"
|
||||
... )
|
||||
|
||||
|
||||
>>> image = pipe(prompt="a cute pink bear", image=image, strength=0.5, height=832, width=480).images[0]
|
||||
>>> image[0].save("output.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class SanaSprintImg2ImgPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using [SANA-Sprint](https://huggingface.co/papers/2503.09641).
|
||||
"""
|
||||
|
||||
# fmt: off
|
||||
bad_punct_regex = re.compile(
|
||||
r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}")
|
||||
# fmt: on
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
|
||||
text_encoder: Gemma2PreTrainedModel,
|
||||
vae: AutoencoderDC,
|
||||
transformer: SanaTransformer2DModel,
|
||||
scheduler: DPMSolverMultistepScheduler,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
|
||||
)
|
||||
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.encoder_block_out_channels) - 1)
|
||||
if hasattr(self, "vae") and self.vae is not None
|
||||
else 32
|
||||
)
|
||||
self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
# Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.enable_vae_slicing
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.vae.enable_slicing()
|
||||
|
||||
# Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.disable_vae_slicing
|
||||
def disable_vae_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_slicing()
|
||||
|
||||
# Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.enable_vae_tiling
|
||||
def enable_vae_tiling(self):
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
self.vae.enable_tiling()
|
||||
|
||||
def disable_vae_tiling(self):
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
# Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds
|
||||
def _get_gemma_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
clean_caption: bool = False,
|
||||
max_sequence_length: int = 300,
|
||||
complex_human_instruction: Optional[List[str]] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
clean_caption (`bool`, defaults to `False`):
|
||||
If `True`, the function will preprocess and clean the provided caption before encoding.
|
||||
max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
|
||||
complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
|
||||
If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
|
||||
the prompt.
|
||||
"""
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
if getattr(self, "tokenizer", None) is not None:
|
||||
self.tokenizer.padding_side = "right"
|
||||
|
||||
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
|
||||
|
||||
# prepare complex human instruction
|
||||
if not complex_human_instruction:
|
||||
max_length_all = max_sequence_length
|
||||
else:
|
||||
chi_prompt = "\n".join(complex_human_instruction)
|
||||
prompt = [chi_prompt + p for p in prompt]
|
||||
num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
|
||||
max_length_all = num_chi_prompt_tokens + max_sequence_length - 2
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_length_all,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
|
||||
prompt_attention_mask = text_inputs.attention_mask
|
||||
prompt_attention_mask = prompt_attention_mask.to(device)
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
|
||||
prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask
|
||||
|
||||
# Copied from diffusers.pipelines.sana.pipeline_sana_sprint.SanaSprintPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
num_images_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
clean_caption: bool = False,
|
||||
max_sequence_length: int = 300,
|
||||
complex_human_instruction: Optional[List[str]] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
number of images that should be generated per prompt
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
clean_caption (`bool`, defaults to `False`):
|
||||
If `True`, the function will preprocess and clean the provided caption before encoding.
|
||||
max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
|
||||
complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
|
||||
If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
|
||||
the prompt.
|
||||
"""
|
||||
|
||||
if device is None:
|
||||
device = self._execution_device
|
||||
|
||||
if self.text_encoder is not None:
|
||||
dtype = self.text_encoder.dtype
|
||||
else:
|
||||
dtype = None
|
||||
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if self.text_encoder is not None and USE_PEFT_BACKEND:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if getattr(self, "tokenizer", None) is not None:
|
||||
self.tokenizer.padding_side = "right"
|
||||
|
||||
# See Section 3.1. of the paper.
|
||||
max_length = max_sequence_length
|
||||
select_index = [0] + list(range(-max_length + 1, 0))
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
clean_caption=clean_caption,
|
||||
max_sequence_length=max_sequence_length,
|
||||
complex_human_instruction=complex_human_instruction,
|
||||
)
|
||||
|
||||
prompt_embeds = prompt_embeds[:, select_index]
|
||||
prompt_attention_mask = prompt_attention_mask[:, select_index]
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
|
||||
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
||||
|
||||
if self.text_encoder is not None:
|
||||
if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
|
||||
def get_timesteps(self, num_inference_steps, strength, device):
|
||||
# get the original timestep using init_timestep
|
||||
init_timestep = min(num_inference_steps * strength, num_inference_steps)
|
||||
|
||||
t_start = int(max(num_inference_steps - init_timestep, 0))
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
strength,
|
||||
height,
|
||||
width,
|
||||
num_inference_steps,
|
||||
timesteps,
|
||||
max_timesteps,
|
||||
intermediate_timesteps,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
prompt_embeds=None,
|
||||
prompt_attention_mask=None,
|
||||
):
|
||||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
||||
|
||||
if height % 32 != 0 or width % 32 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if prompt_embeds is not None and prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
|
||||
|
||||
if timesteps is not None and len(timesteps) != num_inference_steps + 1:
|
||||
raise ValueError("If providing custom timesteps, `timesteps` must be of length `num_inference_steps + 1`.")
|
||||
|
||||
if timesteps is not None and max_timesteps is not None:
|
||||
raise ValueError("If providing custom timesteps, `max_timesteps` should not be provided.")
|
||||
|
||||
if timesteps is None and max_timesteps is None:
|
||||
raise ValueError("Should provide either `timesteps` or `max_timesteps`.")
|
||||
|
||||
if intermediate_timesteps is not None and num_inference_steps != 2:
|
||||
raise ValueError("Intermediate timesteps for SCM is not supported when num_inference_steps != 2.")
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
|
||||
def _text_preprocessing(self, text, clean_caption=False):
|
||||
if clean_caption and not is_bs4_available():
|
||||
logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warning("Setting `clean_caption` to False...")
|
||||
clean_caption = False
|
||||
|
||||
if clean_caption and not is_ftfy_available():
|
||||
logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warning("Setting `clean_caption` to False...")
|
||||
clean_caption = False
|
||||
|
||||
if not isinstance(text, (tuple, list)):
|
||||
text = [text]
|
||||
|
||||
def process(text: str):
|
||||
if clean_caption:
|
||||
text = self._clean_caption(text)
|
||||
text = self._clean_caption(text)
|
||||
else:
|
||||
text = text.lower().strip()
|
||||
return text
|
||||
|
||||
return [process(t) for t in text]
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
|
||||
def _clean_caption(self, caption):
|
||||
caption = str(caption)
|
||||
caption = ul.unquote_plus(caption)
|
||||
caption = caption.strip().lower()
|
||||
caption = re.sub("<person>", "person", caption)
|
||||
# urls:
|
||||
caption = re.sub(
|
||||
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
||||
"",
|
||||
caption,
|
||||
) # regex for urls
|
||||
caption = re.sub(
|
||||
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
||||
"",
|
||||
caption,
|
||||
) # regex for urls
|
||||
# html:
|
||||
caption = BeautifulSoup(caption, features="html.parser").text
|
||||
|
||||
# @<nickname>
|
||||
caption = re.sub(r"@[\w\d]+\b", "", caption)
|
||||
|
||||
# 31C0—31EF CJK Strokes
|
||||
# 31F0—31FF Katakana Phonetic Extensions
|
||||
# 3200—32FF Enclosed CJK Letters and Months
|
||||
# 3300—33FF CJK Compatibility
|
||||
# 3400—4DBF CJK Unified Ideographs Extension A
|
||||
# 4DC0—4DFF Yijing Hexagram Symbols
|
||||
# 4E00—9FFF CJK Unified Ideographs
|
||||
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
|
||||
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
|
||||
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
|
||||
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
|
||||
#######################################################
|
||||
|
||||
# все виды тире / all types of dash --> "-"
|
||||
caption = re.sub(
|
||||
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
|
||||
"-",
|
||||
caption,
|
||||
)
|
||||
|
||||
# кавычки к одному стандарту
|
||||
caption = re.sub(r"[`´«»“”¨]", '"', caption)
|
||||
caption = re.sub(r"[‘’]", "'", caption)
|
||||
|
||||
# "
|
||||
caption = re.sub(r""?", "", caption)
|
||||
# &
|
||||
caption = re.sub(r"&", "", caption)
|
||||
|
||||
# ip addresses:
|
||||
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
|
||||
|
||||
# article ids:
|
||||
caption = re.sub(r"\d:\d\d\s+$", "", caption)
|
||||
|
||||
# \n
|
||||
caption = re.sub(r"\\n", " ", caption)
|
||||
|
||||
# "#123"
|
||||
caption = re.sub(r"#\d{1,3}\b", "", caption)
|
||||
# "#12345.."
|
||||
caption = re.sub(r"#\d{5,}\b", "", caption)
|
||||
# "123456.."
|
||||
caption = re.sub(r"\b\d{6,}\b", "", caption)
|
||||
# filenames:
|
||||
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
|
||||
|
||||
#
|
||||
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
|
||||
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
|
||||
|
||||
caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
|
||||
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
|
||||
|
||||
# this-is-my-cute-cat / this_is_my_cute_cat
|
||||
regex2 = re.compile(r"(?:\-|\_)")
|
||||
if len(re.findall(regex2, caption)) > 3:
|
||||
caption = re.sub(regex2, " ", caption)
|
||||
|
||||
caption = ftfy.fix_text(caption)
|
||||
caption = html.unescape(html.unescape(caption))
|
||||
|
||||
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
|
||||
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
|
||||
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
|
||||
|
||||
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
|
||||
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
|
||||
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
|
||||
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
|
||||
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
|
||||
|
||||
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
|
||||
|
||||
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
|
||||
|
||||
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
|
||||
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
|
||||
caption = re.sub(r"\s+", " ", caption)
|
||||
|
||||
caption.strip()
|
||||
|
||||
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
|
||||
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
|
||||
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
|
||||
caption = re.sub(r"^\.\S+$", "", caption)
|
||||
|
||||
return caption.strip()
|
||||
|
||||
def prepare_image(
|
||||
self,
|
||||
image: PipelineImageInput,
|
||||
width: int,
|
||||
height: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
if isinstance(image, torch.Tensor):
|
||||
if image.ndim == 3:
|
||||
image = image.unsqueeze(0)
|
||||
# Resize if current dimensions do not match target dimensions.
|
||||
if image.shape[2] != height or image.shape[3] != width:
|
||||
image = F.interpolate(image, size=(height, width), mode="bilinear", align_corners=False)
|
||||
|
||||
image = self.image_processor.preprocess(image, height=height, width=width)
|
||||
|
||||
else:
|
||||
image = self.image_processor.preprocess(image, height=height, width=width)
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
|
||||
return image
|
||||
|
||||
def prepare_latents(
|
||||
self, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None
|
||||
):
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
int(height) // self.vae_scale_factor,
|
||||
int(width) // self.vae_scale_factor,
|
||||
)
|
||||
|
||||
if image.shape[1] != num_channels_latents:
|
||||
image = self.vae.encode(image).latent
|
||||
image_latents = image * self.vae.config.scaling_factor * self.scheduler.config.sigma_data
|
||||
else:
|
||||
image_latents = image
|
||||
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
||||
# expand init_latents for batch_size
|
||||
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
||||
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
|
||||
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
|
||||
)
|
||||
else:
|
||||
image_latents = torch.cat([image_latents], dim=0)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
# adapt from https://github.com/huggingface/diffusers/blob/c36f8487df35895421c15f351c7d360bd680[…]/examples/research_projects/sana/train_sana_sprint_diffusers.py
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) * self.scheduler.config.sigma_data
|
||||
latents = torch.cos(timestep) * image_latents + torch.sin(timestep) * noise
|
||||
return latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_inference_steps: int = 2,
|
||||
timesteps: List[int] = None,
|
||||
max_timesteps: float = 1.57080,
|
||||
intermediate_timesteps: float = 1.3,
|
||||
guidance_scale: float = 4.5,
|
||||
image: PipelineImageInput = None,
|
||||
strength: float = 0.6,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
height: int = 1024,
|
||||
width: int = 1024,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
clean_caption: bool = False,
|
||||
use_resolution_binning: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 300,
|
||||
complex_human_instruction: List[str] = [
|
||||
"Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:",
|
||||
"- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.",
|
||||
"- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.",
|
||||
"Here are examples of how to transform or refine prompts:",
|
||||
"- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.",
|
||||
"- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.",
|
||||
"Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:",
|
||||
"User Prompt: ",
|
||||
],
|
||||
) -> Union[SanaPipelineOutput, Tuple]:
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
num_inference_steps (`int`, *optional*, defaults to 20):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
max_timesteps (`float`, *optional*, defaults to 1.57080):
|
||||
The maximum timestep value used in the SCM scheduler.
|
||||
intermediate_timesteps (`float`, *optional*, defaults to 1.3):
|
||||
The intermediate timestep value used in SCM scheduler (only used when num_inference_steps=2).
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 4.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size):
|
||||
The width in pixels of the generated image.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs:
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
clean_caption (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
|
||||
be installed. If the dependencies are not installed, the embeddings will be created from the raw
|
||||
prompt.
|
||||
use_resolution_binning (`bool` defaults to `True`):
|
||||
If set to `True`, the requested height and width are first mapped to the closest resolutions using
|
||||
`ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
|
||||
the requested resolution. Useful for generating non-square images.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int` defaults to `300`):
|
||||
Maximum sequence length to use with the `prompt`.
|
||||
complex_human_instruction (`List[str]`, *optional*):
|
||||
Instructions for complex human attention:
|
||||
https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.sana.pipeline_output.SanaPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] is returned,
|
||||
otherwise a `tuple` is returned where the first element is a list with the generated images
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
if use_resolution_binning:
|
||||
if self.transformer.config.sample_size == 32:
|
||||
aspect_ratio_bin = ASPECT_RATIO_1024_BIN
|
||||
else:
|
||||
raise ValueError("Invalid sample size")
|
||||
orig_height, orig_width = height, width
|
||||
height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
|
||||
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
strength=strength,
|
||||
height=height,
|
||||
width=width,
|
||||
num_inference_steps=num_inference_steps,
|
||||
timesteps=timesteps,
|
||||
max_timesteps=max_timesteps,
|
||||
intermediate_timesteps=intermediate_timesteps,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Default height and width to transformer
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
|
||||
|
||||
# 2. Preprocess image
|
||||
init_image = self.prepare_image(image, width, height, device, self.vae.dtype)
|
||||
|
||||
# 3. Encode input prompt
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
) = self.encode_prompt(
|
||||
prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
clean_caption=clean_caption,
|
||||
max_sequence_length=max_sequence_length,
|
||||
complex_human_instruction=complex_human_instruction,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
# 5. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
timesteps,
|
||||
sigmas=None,
|
||||
max_timesteps=max_timesteps,
|
||||
intermediate_timesteps=intermediate_timesteps,
|
||||
)
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(0)
|
||||
|
||||
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
|
||||
if num_inference_steps < 1:
|
||||
raise ValueError(
|
||||
f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
|
||||
f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
|
||||
)
|
||||
latent_timestep = timesteps[:1]
|
||||
|
||||
# 5. Prepare latents.
|
||||
latent_channels = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
init_image,
|
||||
latent_timestep,
|
||||
batch_size * num_images_per_prompt,
|
||||
latent_channels,
|
||||
height,
|
||||
width,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# I think this is redundant given the scaling in prepare_latents
|
||||
# latents = latents * self.scheduler.config.sigma_data
|
||||
|
||||
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
||||
guidance = guidance.expand(latents.shape[0]).to(prompt_embeds.dtype)
|
||||
guidance = guidance * self.transformer.config.guidance_embeds_scale
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Denoising loop
|
||||
timesteps = timesteps[:-1]
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
transformer_dtype = self.transformer.dtype
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0])
|
||||
latents_model_input = latents / self.scheduler.config.sigma_data
|
||||
|
||||
scm_timestep = torch.sin(timestep) / (torch.cos(timestep) + torch.sin(timestep))
|
||||
|
||||
scm_timestep_expanded = scm_timestep.view(-1, 1, 1, 1)
|
||||
latent_model_input = latents_model_input * torch.sqrt(
|
||||
scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2
|
||||
)
|
||||
|
||||
# predict noise model_output
|
||||
noise_pred = self.transformer(
|
||||
latent_model_input.to(dtype=transformer_dtype),
|
||||
encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
guidance=guidance,
|
||||
timestep=scm_timestep,
|
||||
return_dict=False,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
)[0]
|
||||
|
||||
noise_pred = (
|
||||
(1 - 2 * scm_timestep_expanded) * latent_model_input
|
||||
+ (1 - 2 * scm_timestep_expanded + 2 * scm_timestep_expanded**2) * noise_pred
|
||||
) / torch.sqrt(scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2)
|
||||
noise_pred = noise_pred.float() * self.scheduler.config.sigma_data
|
||||
|
||||
# compute previous image: x_t -> x_t-1
|
||||
latents, denoised = self.scheduler.step(
|
||||
noise_pred, timestep, latents, **extra_step_kwargs, return_dict=False
|
||||
)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
latents = denoised / self.scheduler.config.sigma_data
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
else:
|
||||
latents = latents.to(self.vae.dtype)
|
||||
try:
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
except torch.cuda.OutOfMemoryError as e:
|
||||
warnings.warn(
|
||||
f"{e}. \n"
|
||||
f"Try to use VAE tiling for large images. For example: \n"
|
||||
f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
|
||||
)
|
||||
if use_resolution_binning:
|
||||
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return SanaPipelineOutput(images=image)
|
||||
@@ -1622,6 +1622,21 @@ class SanaPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class SanaSprintImg2ImgPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class SanaSprintPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
313
tests/pipelines/sana/test_sana_sprint_img2img.py
Normal file
313
tests/pipelines/sana/test_sana_sprint_img2img.py
Normal file
@@ -0,0 +1,313 @@
|
||||
# Copyright 2024 The HuggingFace Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer
|
||||
|
||||
from diffusers import AutoencoderDC, SanaSprintImg2ImgPipeline, SanaTransformer2DModel, SCMScheduler
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..pipeline_params import (
|
||||
IMAGE_TO_IMAGE_IMAGE_PARAMS,
|
||||
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
|
||||
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import PipelineTesterMixin, to_np
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class SanaSprintImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = SanaSprintImg2ImgPipeline
|
||||
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {
|
||||
"negative_prompt",
|
||||
"negative_prompt_embeds",
|
||||
}
|
||||
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS - {"negative_prompt"}
|
||||
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"generator",
|
||||
"latents",
|
||||
"return_dict",
|
||||
"callback_on_step_end",
|
||||
"callback_on_step_end_tensor_inputs",
|
||||
]
|
||||
)
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
transformer = SanaTransformer2DModel(
|
||||
patch_size=1,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
num_layers=1,
|
||||
num_attention_heads=2,
|
||||
attention_head_dim=4,
|
||||
num_cross_attention_heads=2,
|
||||
cross_attention_head_dim=4,
|
||||
cross_attention_dim=8,
|
||||
caption_channels=8,
|
||||
sample_size=32,
|
||||
qk_norm="rms_norm_across_heads",
|
||||
guidance_embeds=True,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderDC(
|
||||
in_channels=3,
|
||||
latent_channels=4,
|
||||
attention_head_dim=2,
|
||||
encoder_block_types=(
|
||||
"ResBlock",
|
||||
"EfficientViTBlock",
|
||||
),
|
||||
decoder_block_types=(
|
||||
"ResBlock",
|
||||
"EfficientViTBlock",
|
||||
),
|
||||
encoder_block_out_channels=(8, 8),
|
||||
decoder_block_out_channels=(8, 8),
|
||||
encoder_qkv_multiscales=((), (5,)),
|
||||
decoder_qkv_multiscales=((), (5,)),
|
||||
encoder_layers_per_block=(1, 1),
|
||||
decoder_layers_per_block=[1, 1],
|
||||
downsample_block_type="conv",
|
||||
upsample_block_type="interpolate",
|
||||
decoder_norm_types="rms_norm",
|
||||
decoder_act_fns="silu",
|
||||
scaling_factor=0.41407,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = SCMScheduler()
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder_config = Gemma2Config(
|
||||
head_dim=16,
|
||||
hidden_size=8,
|
||||
initializer_range=0.02,
|
||||
intermediate_size=64,
|
||||
max_position_embeddings=8192,
|
||||
model_type="gemma2",
|
||||
num_attention_heads=2,
|
||||
num_hidden_layers=1,
|
||||
num_key_value_heads=2,
|
||||
vocab_size=8,
|
||||
attn_implementation="eager",
|
||||
)
|
||||
text_encoder = Gemma2Model(text_encoder_config)
|
||||
tokenizer = GemmaTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
|
||||
|
||||
components = {
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
image = torch.randn(1, 3, 32, 32, generator=generator)
|
||||
inputs = {
|
||||
"prompt": "",
|
||||
"image": image,
|
||||
"strength": 0.5,
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"height": 32,
|
||||
"width": 32,
|
||||
"max_sequence_length": 16,
|
||||
"output_type": "pt",
|
||||
"complex_human_instruction": None,
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_inference(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs)[0]
|
||||
generated_image = image[0]
|
||||
|
||||
self.assertEqual(generated_image.shape, (3, 32, 32))
|
||||
expected_image = torch.randn(3, 32, 32)
|
||||
max_diff = np.abs(generated_image - expected_image).max()
|
||||
self.assertLessEqual(max_diff, 1e10)
|
||||
|
||||
def test_callback_inputs(self):
|
||||
sig = inspect.signature(self.pipeline_class.__call__)
|
||||
has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
|
||||
has_callback_step_end = "callback_on_step_end" in sig.parameters
|
||||
|
||||
if not (has_callback_tensor_inputs and has_callback_step_end):
|
||||
return
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
self.assertTrue(
|
||||
hasattr(pipe, "_callback_tensor_inputs"),
|
||||
f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
|
||||
)
|
||||
|
||||
def callback_inputs_subset(pipe, i, t, callback_kwargs):
|
||||
# iterate over callback args
|
||||
for tensor_name, tensor_value in callback_kwargs.items():
|
||||
# check that we're only passing in allowed tensor inputs
|
||||
assert tensor_name in pipe._callback_tensor_inputs
|
||||
|
||||
return callback_kwargs
|
||||
|
||||
def callback_inputs_all(pipe, i, t, callback_kwargs):
|
||||
for tensor_name in pipe._callback_tensor_inputs:
|
||||
assert tensor_name in callback_kwargs
|
||||
|
||||
# iterate over callback args
|
||||
for tensor_name, tensor_value in callback_kwargs.items():
|
||||
# check that we're only passing in allowed tensor inputs
|
||||
assert tensor_name in pipe._callback_tensor_inputs
|
||||
|
||||
return callback_kwargs
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
# Test passing in a subset
|
||||
inputs["callback_on_step_end"] = callback_inputs_subset
|
||||
inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
|
||||
output = pipe(**inputs)[0]
|
||||
|
||||
# Test passing in a everything
|
||||
inputs["callback_on_step_end"] = callback_inputs_all
|
||||
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
|
||||
output = pipe(**inputs)[0]
|
||||
|
||||
def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
|
||||
is_last = i == (pipe.num_timesteps - 1)
|
||||
if is_last:
|
||||
callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
|
||||
return callback_kwargs
|
||||
|
||||
inputs["callback_on_step_end"] = callback_inputs_change_tensor
|
||||
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
|
||||
output = pipe(**inputs)[0]
|
||||
assert output.abs().sum() < 1e10
|
||||
|
||||
def test_attention_slicing_forward_pass(
|
||||
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
|
||||
):
|
||||
if not self.test_attention_slicing:
|
||||
return
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
for component in pipe.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator_device = "cpu"
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_without_slicing = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_attention_slicing(slice_size=1)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_slicing1 = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_attention_slicing(slice_size=2)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_slicing2 = pipe(**inputs)[0]
|
||||
|
||||
if test_max_difference:
|
||||
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
|
||||
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
|
||||
self.assertLess(
|
||||
max(max_diff1, max_diff2),
|
||||
expected_max_diff,
|
||||
"Attention slicing should not affect the inference results",
|
||||
)
|
||||
|
||||
@unittest.skip("vae tiling resulted in a small margin over the expected max diff, so skipping this test for now")
|
||||
def test_vae_tiling(self, expected_diff_max: float = 0.2):
|
||||
generator_device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to("cpu")
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# Without tiling
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
inputs["height"] = inputs["width"] = 128
|
||||
output_without_tiling = pipe(**inputs)[0]
|
||||
|
||||
# With tiling
|
||||
pipe.vae.enable_tiling(
|
||||
tile_sample_min_height=96,
|
||||
tile_sample_min_width=96,
|
||||
tile_sample_stride_height=64,
|
||||
tile_sample_stride_width=64,
|
||||
)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
inputs["height"] = inputs["width"] = 128
|
||||
output_with_tiling = pipe(**inputs)[0]
|
||||
|
||||
self.assertLess(
|
||||
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
|
||||
expected_diff_max,
|
||||
"VAE tiling should not affect the inference results",
|
||||
)
|
||||
|
||||
# TODO(aryan): Create a dummy gemma model with smol vocab size
|
||||
@unittest.skip(
|
||||
"A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
|
||||
)
|
||||
def test_inference_batch_consistent(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
"A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
|
||||
)
|
||||
def test_inference_batch_single_identical(self):
|
||||
pass
|
||||
|
||||
def test_float16_inference(self):
|
||||
# Requires higher tolerance as model seems very sensitive to dtype
|
||||
super().test_float16_inference(expected_max_diff=0.08)
|
||||
Reference in New Issue
Block a user