# Copyright 2025 Junyang Chen, Jinshan Pan, Jiangxin Dong, IMAG Lab Team
# and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import inspect
import math
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import cv2
import numpy as np
import PIL.Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

from diffusers import UNet2DConditionModel as OriginalUNet2DConditionModel
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import (
    FromSingleFileMixin,
    IPAdapterMixin,
    PeftAdapterMixin,
    StableDiffusionXLLoraLoaderMixin,
    TextualInversionLoaderMixin,
    UNet2DConditionLoadersMixin,
)
from diffusers.models import AutoencoderKL
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D, get_down_block
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from diffusers.schedulers import DDPMScheduler, KarrasDiffusionSchedulers
from diffusers.utils import (
    USE_PEFT_BACKEND,
    deprecate,
    is_invisible_watermark_available,
    is_torch_version,
    is_torch_xla_available,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.utils.outputs import BaseOutput
from diffusers.utils.torch_utils import randn_tensor


if is_invisible_watermark_available():
    from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import random
        >>> import numpy as np
        >>> import torch
        >>> from diffusers import DiffusionPipeline, AutoencoderKL, UniPCMultistepScheduler
        >>> from huggingface_hub import hf_hub_download
        >>> from diffusers.utils import load_image
        >>> from PIL import Image

        >>> device = "cuda"
        >>> dtype = torch.float16
        >>> MAX_SEED = np.iinfo(np.int32).max

        >>> # Download weights for the additional UNet layers
        >>> model_file = hf_hub_download(
        ...     "jychen9811/FaithDiff",
        ...     filename="FaithDiff.bin", local_dir="./proc_data/faithdiff", local_dir_use_symlinks=False
        ... )

        >>> # Initialize the models and pipeline
        >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype)

        >>> model_id = "SG161222/RealVisXL_V4.0"
        >>> pipe = DiffusionPipeline.from_pretrained(
        ...     model_id,
        ...     torch_dtype=dtype,
        ...     vae=vae,
        ...     unet=None,  # <- Do not load the UNet together with the original model.
        ...     custom_pipeline="pipeline_faithdiff_stable_diffusion_xl",
        ...     use_safetensors=True,
        ...     variant="fp16",
        ... ).to(device)

        >>> # Here we need to use the pipeline's internal UNet class
        >>> pipe.unet = pipe.unet_model.from_pretrained(model_id, subfolder="unet", variant="fp16", use_safetensors=True)

        >>> # Load the additional layers into the model
        >>> pipe.unet.load_additional_layers(weight_path="proc_data/faithdiff/FaithDiff.bin", dtype=dtype)

        >>> # Enable VAE tiling
        >>> pipe.set_encoder_tile_settings()
        >>> pipe.enable_vae_tiling()

        >>> # Optimization
        >>> pipe.enable_model_cpu_offload()

        >>> # Set selected scheduler
        >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

        >>> # Input params
        >>> prompt = "The image features a woman in her 55s with blonde hair and a white shirt, smiling at the camera. She appears to be in a good mood and is wearing a white scarf around her neck. "
        >>> upscale = 2  # scale here
        >>> start_point = "lr"  # or "noise"
        >>> latent_tiled_overlap = 0.5
        >>> latent_tiled_size = 1024

        >>> # Load image
        >>> lq_image = load_image("https://huggingface.co/datasets/DEVAIEXP/assets/resolve/main/woman.png")
        >>> original_height = lq_image.height
        >>> original_width = lq_image.width
        >>> print(f"Current resolution: H:{original_height} x W:{original_width}")

        >>> width = original_width * int(upscale)
        >>> height = original_height * int(upscale)
        >>> print(f"Final resolution: H:{height} x W:{width}")

        >>> # Restoration
        >>> image = lq_image.resize((width, height), Image.LANCZOS)
        >>> input_image, width_init, height_init, width_now, height_now = pipe.check_image_size(image)

        >>> generator = torch.Generator(device=device).manual_seed(random.randint(0, MAX_SEED))
        >>> gen_image = pipe(
        ...     lr_img=input_image,
        ...     prompt=prompt,
        ...     num_inference_steps=20,
        ...     guidance_scale=5,
        ...     generator=generator,
        ...     start_point=start_point,
        ...     height=height_now,
        ...     width=width_now,
        ...     overlap=latent_tiled_overlap,
        ...     target_size=(latent_tiled_size, latent_tiled_size),
        ... ).images[0]

        >>> cropped_image = gen_image.crop((0, 0, width_init, height_init))
        >>> cropped_image.save("data/result.png")
        ```
"""

def zero_module(module):
    """Zero out the parameters of a module and return it."""
    for p in module.parameters():
        nn.init.zeros_(p)
    return module


class Encoder(nn.Module):
    """Encoder layer of a variational autoencoder that encodes input into a latent representation."""

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 4,
        down_block_types: Tuple[str, ...] = (
            "DownEncoderBlock2D",
            "DownEncoderBlock2D",
            "DownEncoderBlock2D",
            "DownEncoderBlock2D",
        ),
        block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        double_z: bool = True,
        mid_block_add_attention: bool = True,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = nn.Conv2d(
            in_channels,
            block_out_channels[0],
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.mid_block = None
        self.down_blocks = nn.ModuleList([])
        self.use_rgb = False
        self.down_block_type = down_block_types
        self.block_out_channels = block_out_channels

        self.tile_sample_min_size = 1024
        self.tile_latent_min_size = int(self.tile_sample_min_size / 8)
        self.tile_overlap_factor = 0.25
        self.use_tiling = False

        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=self.layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                add_downsample=not is_final_block,
                resnet_eps=1e-6,
                downsample_padding=0,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=None,
            )
            self.down_blocks.append(down_block)

        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=None,
            add_attention=mid_block_add_attention,
        )

        self.gradient_checkpointing = False

    def to_rgb_init(self):
        """Initialize layers to convert features to RGB."""
        self.to_rgbs = nn.ModuleList([])
        self.use_rgb = True
        for i, down_block_type in enumerate(self.down_block_type):
            output_channel = self.block_out_channels[i]
            self.to_rgbs.append(nn.Conv2d(output_channel, 3, kernel_size=3, padding=1))

    def enable_tiling(self):
        """Enable tiling for large inputs."""
        self.use_tiling = True

    def encode(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        """Encode the input tensor into a latent representation."""
        sample = self.conv_in(sample)
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            if is_torch_version(">=", "1.11.0"):
                for down_block in self.down_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(down_block), sample, use_reentrant=False
                    )
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block), sample, use_reentrant=False
                )
            else:
                for down_block in self.down_blocks:
                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
            return sample
        else:
            for down_block in self.down_blocks:
                sample = down_block(sample)
            sample = self.mid_block(sample)
            return sample

    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        """Blend two tensors vertically with a smooth transition."""
        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
        for y in range(blend_extent):
            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
        return b

    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        """Blend two tensors horizontally with a smooth transition."""
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
        return b

    def tiled_encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Encode the input tensor using tiling for large inputs."""
        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent

        rows = []
        for i in range(0, x.shape[2], overlap_size):
            row = []
            for j in range(0, x.shape[3], overlap_size):
                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
                tile = self.encode(tile)
                row.append(tile)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=3))

        moments = torch.cat(result_rows, dim=2)
        return moments

    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        """Forward pass of the encoder, using tiling if enabled for large inputs."""
        if self.use_tiling and (
            sample.shape[-1] > self.tile_latent_min_size or sample.shape[-2] > self.tile_latent_min_size
        ):
            return self.tiled_encode(sample)
        return self.encode(sample)
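
# A minimal usage sketch for the encoder above (illustrative only; the input size and
# random weights are assumptions — in practice the weights come from the FaithDiff
# checkpoint via `UNet2DConditionModel.load_additional_layers`):
#
#     encoder = Encoder().eval()
#     with torch.no_grad():
#         feats = encoder(torch.randn(1, 3, 256, 256))  # -> (1, 512, 32, 32)
#
# Call `encoder.enable_tiling()` first to process very large inputs tile by tile with
# smooth overlap blending.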

class ControlNetConditioningEmbedding(nn.Module):
    """A small network to preprocess conditioning inputs, inspired by ControlNet."""

    def __init__(self, conditioning_embedding_channels: int, conditioning_channels: int = 4):
        super().__init__()
        self.conv_in = nn.Conv2d(conditioning_channels, conditioning_channels, kernel_size=3, padding=1)
        self.norm_in = nn.GroupNorm(num_channels=conditioning_channels, num_groups=32, eps=1e-6)
        self.conv_out = zero_module(
            nn.Conv2d(conditioning_channels, conditioning_embedding_channels, kernel_size=3, padding=1)
        )

    def forward(self, conditioning):
        """Process the conditioning input through the network."""
        conditioning = self.norm_in(conditioning)
        embedding = self.conv_in(conditioning)
        embedding = F.silu(embedding)
        embedding = self.conv_out(embedding)
        return embedding
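
# Shape sketch for the conditioning embedder (illustrative; the 512-channel input
# matches the feature maps produced by the default FaithDiff `Encoder` above):
#
#     cond_embed = ControlNetConditioningEmbedding(
#         conditioning_embedding_channels=320, conditioning_channels=512
#     )
#     out = cond_embed(torch.randn(1, 512, 32, 32))  # -> (1, 320, 32, 32)
#
# `conv_out` is zero-initialized, so the embedding starts as a no-op and is learned
# during fine-tuning, as in ControlNet.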

class QuickGELU(nn.Module):
    """A fast approximation of the GELU activation function."""

    def forward(self, x: torch.Tensor):
        """Apply the QuickGELU activation to the input tensor."""
        return x * torch.sigmoid(1.702 * x)


class LayerNorm(nn.LayerNorm):
    """Subclass of torch's LayerNorm (kept from the CLIP implementation) that preserves the input dtype."""

    def forward(self, x: torch.Tensor):
        """Apply LayerNorm and return the result in the input dtype."""
        orig_type = x.dtype
        ret = super().forward(x)
        return ret.type(orig_type)

class ResidualAttentionBlock(nn.Module):
    """A transformer-style block with self-attention and an MLP."""

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(
            OrderedDict(
                [
                    ("c_fc", nn.Linear(d_model, d_model * 2)),
                    ("gelu", QuickGELU()),
                    ("c_proj", nn.Linear(d_model * 2, d_model)),
                ]
            )
        )
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        """Apply self-attention to the input tensor."""
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        """Forward pass through the residual attention block."""
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
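
# Shape sketch for the block above (illustrative). `nn.MultiheadAttention` is built
# with its default `batch_first=False`, so the first dimension of the input is
# treated as the sequence dimension:
#
#     block = ResidualAttentionBlock(d_model=640, n_head=8)
#     tokens = torch.randn(100, 2, 640)  # (sequence, batch, d_model)
#     out = block(tokens)                # same shape: attention + MLP, each with a residual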

@dataclass
class UNet2DConditionOutput(BaseOutput):
    """The output of `UNet2DConditionModel`."""

    sample: torch.FloatTensor = None

class UNet2DConditionModel(OriginalUNet2DConditionModel, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
    """A unified 2D UNet model extending `OriginalUNet2DConditionModel` with the FaithDiff additions."""

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        center_input_sample: bool = False,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str, ...] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
        up_block_types: Tuple[str, ...] = (
            "UpBlock2D",
            "CrossAttnUpBlock2D",
            "CrossAttnUpBlock2D",
            "CrossAttnUpBlock2D",
        ),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
        layers_per_block: Union[int, Tuple[int]] = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        dropout: float = 0.0,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: Union[int, Tuple[int]] = 1280,
        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
        reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
        encoder_hid_dim: Optional[int] = None,
        encoder_hid_dim_type: Optional[str] = None,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        addition_embed_type: Optional[str] = None,
        addition_time_embed_dim: Optional[int] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        resnet_skip_time_act: bool = False,
        resnet_out_scale_factor: float = 1.0,
        time_embedding_type: str = "positional",
        time_embedding_dim: Optional[int] = None,
        time_embedding_act_fn: Optional[str] = None,
        timestep_post_act: Optional[str] = None,
        time_cond_proj_dim: Optional[int] = None,
        conv_in_kernel: int = 3,
        conv_out_kernel: int = 3,
        projection_class_embeddings_input_dim: Optional[int] = None,
        attention_type: str = "default",
        class_embeddings_concat: bool = False,
        mid_block_only_cross_attention: Optional[bool] = None,
        cross_attention_norm: Optional[str] = None,
        addition_embed_type_num_heads: int = 64,
    ):
        """Initialize the UNet2DConditionModel."""
        super().__init__(
            sample_size=sample_size,
            in_channels=in_channels,
            out_channels=out_channels,
            center_input_sample=center_input_sample,
            flip_sin_to_cos=flip_sin_to_cos,
            freq_shift=freq_shift,
            down_block_types=down_block_types,
            mid_block_type=mid_block_type,
            up_block_types=up_block_types,
            only_cross_attention=only_cross_attention,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            downsample_padding=downsample_padding,
            mid_block_scale_factor=mid_block_scale_factor,
            dropout=dropout,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            cross_attention_dim=cross_attention_dim,
            transformer_layers_per_block=transformer_layers_per_block,
            reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
            encoder_hid_dim=encoder_hid_dim,
            encoder_hid_dim_type=encoder_hid_dim_type,
            attention_head_dim=attention_head_dim,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            class_embed_type=class_embed_type,
            addition_embed_type=addition_embed_type,
            addition_time_embed_dim=addition_time_embed_dim,
            num_class_embeds=num_class_embeds,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            resnet_skip_time_act=resnet_skip_time_act,
            resnet_out_scale_factor=resnet_out_scale_factor,
            time_embedding_type=time_embedding_type,
            time_embedding_dim=time_embedding_dim,
            time_embedding_act_fn=time_embedding_act_fn,
            timestep_post_act=timestep_post_act,
            time_cond_proj_dim=time_cond_proj_dim,
            conv_in_kernel=conv_in_kernel,
            conv_out_kernel=conv_out_kernel,
            projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
            attention_type=attention_type,
            class_embeddings_concat=class_embeddings_concat,
            mid_block_only_cross_attention=mid_block_only_cross_attention,
            cross_attention_norm=cross_attention_norm,
            addition_embed_type_num_heads=addition_embed_type_num_heads,
        )

        # Additional FaithDiff attributes. Note: `information_transformer_layes` keeps
        # its original (misspelled) name because it appears in checkpoint state-dict keys.
        self.denoise_encoder = None
        self.information_transformer_layes = None
        self.condition_embedding = None
        self.agg_net = None
        self.spatial_ch_projs = None

    def init_vae_encoder(self, dtype):
        self.denoise_encoder = Encoder()
        if dtype is not None:
            self.denoise_encoder.dtype = dtype

    def init_information_transformer_layes(self):
        num_trans_channel = 640
        num_trans_head = 8
        num_trans_layer = 2
        num_proj_channel = 320
        self.information_transformer_layes = nn.Sequential(
            *[ResidualAttentionBlock(num_trans_channel, num_trans_head) for _ in range(num_trans_layer)]
        )
        self.spatial_ch_projs = zero_module(nn.Linear(num_trans_channel, num_proj_channel))

    def init_ControlNetConditioningEmbedding(self, channel=512):
        self.condition_embedding = ControlNetConditioningEmbedding(320, channel)

    def init_extra_weights(self):
        self.agg_net = nn.ModuleList()

    def load_additional_layers(
        self, dtype: Optional[torch.dtype] = torch.float16, channel: int = 512, weight_path: Optional[str] = None
    ):
        """Load the additional FaithDiff layers and, optionally, their weights from a file.

        Args:
            dtype (torch.dtype, optional): Data type for the loaded weights. Defaults to torch.float16.
            channel (int): Output channel size of the conditioning embedding. Defaults to 512.
            weight_path (str, optional): Path to the weight file.
        """
        if self.denoise_encoder is None:
            self.init_vae_encoder(dtype)

        if self.information_transformer_layes is None:
            self.init_information_transformer_layes()

        if self.condition_embedding is None:
            self.init_ControlNetConditioningEmbedding(channel)

        if self.agg_net is None:
            self.init_extra_weights()

        # Load weights if provided
        if weight_path is not None:
            state_dict = torch.load(weight_path, weights_only=False)
            self.load_state_dict(state_dict, strict=True)

        # Move all modules to the same device and dtype as the model
        device = next(self.parameters()).device
        if dtype is not None or device is not None:
            self.to(device=device, dtype=dtype or next(self.parameters()).dtype)

    def to(self, *args, **kwargs):
        """Override to() to move all additional modules to the same device and dtype."""
        super().to(*args, **kwargs)
        for module in [
            self.denoise_encoder,
            self.information_transformer_layes,
            self.condition_embedding,
            self.agg_net,
            self.spatial_ch_projs,
        ]:
            if module is not None:
                module.to(*args, **kwargs)
        return self

    def load_state_dict(self, state_dict, strict=True):
        """Load state dictionary into the model.

        Args:
            state_dict (dict): State dictionary to load.
            strict (bool, optional): Whether to strictly enforce that all keys match. Defaults to True.
        """
        core_dict = {}
        additional_dicts = {
            "denoise_encoder": {},
            "information_transformer_layes": {},
            "condition_embedding": {},
            "agg_net": {},
            "spatial_ch_projs": {},
        }

        for key, value in state_dict.items():
            if key.startswith("denoise_encoder."):
                additional_dicts["denoise_encoder"][key[len("denoise_encoder.") :]] = value
            elif key.startswith("information_transformer_layes."):
                additional_dicts["information_transformer_layes"][key[len("information_transformer_layes.") :]] = value
            elif key.startswith("condition_embedding."):
                additional_dicts["condition_embedding"][key[len("condition_embedding.") :]] = value
            elif key.startswith("agg_net."):
                additional_dicts["agg_net"][key[len("agg_net.") :]] = value
            elif key.startswith("spatial_ch_projs."):
                additional_dicts["spatial_ch_projs"][key[len("spatial_ch_projs.") :]] = value
            else:
                core_dict[key] = value

        super().load_state_dict(core_dict, strict=False)
        for module_name, module_dict in additional_dicts.items():
            module = getattr(self, module_name, None)
            if module is not None and module_dict:
                module.load_state_dict(module_dict, strict=strict)

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        input_embedding: Optional[torch.Tensor] = None,
        add_sample: bool = True,
        return_dict: bool = True,
        use_condition_embedding: bool = True,
    ) -> Union[UNet2DConditionOutput, Tuple]:
        """Forward pass following the original FaithDiff-modified implementation.

        Args:
            sample (torch.FloatTensor): The noisy input tensor with shape `(batch, channel, height, width)`.
            timestep (Union[torch.Tensor, float, int]): The number of timesteps to denoise an input.
            encoder_hidden_states (torch.Tensor): The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
            class_labels (torch.Tensor, optional): Optional class labels for conditioning.
            timestep_cond (torch.Tensor, optional): Conditional embeddings for timestep.
            attention_mask (torch.Tensor, optional): An attention mask of shape `(batch, key_tokens)`.
            cross_attention_kwargs (Dict[str, Any], optional): A kwargs dictionary for the AttentionProcessor.
            added_cond_kwargs (Dict[str, torch.Tensor], optional): Additional embeddings to add to the UNet blocks.
            down_block_additional_residuals (Tuple[torch.Tensor], optional): Residuals for down UNet blocks.
            mid_block_additional_residual (torch.Tensor, optional): Residual for the middle UNet block.
            down_intrablock_additional_residuals (Tuple[torch.Tensor], optional): Additional residuals within down blocks.
            encoder_attention_mask (torch.Tensor, optional): A cross-attention mask of shape `(batch, sequence_length)`.
            input_embedding (torch.Tensor, optional): Additional input embedding for preprocessing.
            add_sample (bool): Whether to add the sample to the processed embedding. Defaults to True.
            return_dict (bool): Whether to return a UNet2DConditionOutput. Defaults to True.
            use_condition_embedding (bool): Whether to use the condition embedding. Defaults to True.

        Returns:
            Union[UNet2DConditionOutput, Tuple]: The processed sample tensor, either as a UNet2DConditionOutput or tuple.
        """
        default_overall_up_factor = 2**self.num_upsamplers
        forward_upsample_size = False
        upsample_size = None

        for dim in sample.shape[-2:]:
            if dim % default_overall_up_factor != 0:
                forward_upsample_size = True
                break

        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        t_emb = self.get_time_embed(sample=sample, timestep=timestep)
        emb = self.time_embedding(t_emb, timestep_cond)
        aug_emb = None

        class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
        if class_emb is not None:
            if self.config.class_embeddings_concat:
                emb = torch.cat([emb, class_emb], dim=-1)
            else:
                emb = emb + class_emb

        aug_emb = self.get_aug_embed(
            emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
        )
        if self.config.addition_embed_type == "image_hint":
            aug_emb, hint = aug_emb
            sample = torch.cat([sample, hint], dim=1)

        emb = emb + aug_emb if aug_emb is not None else emb

        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)

        encoder_hidden_states = self.process_encoder_hidden_states(
            encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
        )

        # 2. pre-process (following the original modified logic)
        sample = self.conv_in(sample)  # [B, 4, H, W] -> [B, 320, H, W]
        if (
            input_embedding is not None
            and self.condition_embedding is not None
            and self.information_transformer_layes is not None
        ):
            if use_condition_embedding:
                input_embedding = self.condition_embedding(input_embedding)  # [B, 320, H, W]
            batch_size, channel, height, width = input_embedding.shape
            concat_feat = (
                torch.cat([sample, input_embedding], dim=1)
                .view(batch_size, 2 * channel, height * width)
                .transpose(1, 2)
            )
            concat_feat = self.information_transformer_layes(concat_feat)
            feat_alpha = self.spatial_ch_projs(concat_feat).transpose(1, 2).view(batch_size, channel, height, width)
            sample = sample + feat_alpha if add_sample else feat_alpha  # Update sample as in the original version

        # 2.5 GLIGEN position net (kept from the original version)
        if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
            cross_attention_kwargs = cross_attention_kwargs.copy()
            gligen_args = cross_attention_kwargs.pop("gligen")
            cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}

        # 3. down (continues the standard flow)
        if cross_attention_kwargs is not None:
            cross_attention_kwargs = cross_attention_kwargs.copy()
            lora_scale = cross_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            scale_lora_layers(self, lora_scale)

        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
        is_adapter = down_intrablock_additional_residuals is not None
        if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
            deprecate(
                "T2I should not use down_block_additional_residuals",
                "1.3.0",
                "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
                       and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
                       for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
                standard_warn=False,
            )
            down_intrablock_additional_residuals = down_block_additional_residuals
            is_adapter = True

        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                additional_residuals = {}
                if is_adapter and len(down_intrablock_additional_residuals) > 0:
                    additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                    **additional_residuals,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
                if is_adapter and len(down_intrablock_additional_residuals) > 0:
                    sample += down_intrablock_additional_residuals.pop(0)
            down_block_res_samples += res_samples

        if is_controlnet:
            new_down_block_res_samples = ()
            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample = down_block_res_sample + down_block_additional_residual
                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
            down_block_res_samples = new_down_block_res_samples

        # 4. mid
        if self.mid_block is not None:
            if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
                sample = self.mid_block(
                    sample,
                    emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = self.mid_block(sample, emb)
            if (
                is_adapter
                and len(down_intrablock_additional_residuals) > 0
                and sample.shape == down_intrablock_additional_residuals[0].shape
            ):
                sample += down_intrablock_additional_residuals.pop(0)

        if is_controlnet:
            sample = sample + mid_block_additional_residual

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]
            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    upsample_size=upsample_size,
                )

        # 6. post-process
        if self.conv_norm_out:
            sample = self.conv_norm_out(sample)
            sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if USE_PEFT_BACKEND:
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (sample,)
        return UNet2DConditionOutput(sample=sample)
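
# A hedged wiring sketch for this UNet (mirrors the pipeline docstring above; the
# `model_id` and weight path are illustrative):
#
#     unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", variant="fp16")
#     unet.load_additional_layers(weight_path="proc_data/faithdiff/FaithDiff.bin", dtype=torch.float16)
#
# During denoising, the low-quality-image features enter through `input_embedding`,
# are fused with the latent by the information transformer, and the fused sample then
# follows the standard SDXL UNet path.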

class LocalAttention:
    """A class to handle local attention by splitting tensors into overlapping grids for processing."""

    def __init__(self, kernel_size=None, overlap=0.5):
        """Initialize the LocalAttention module.

        Args:
            kernel_size (tuple[int, int], optional): Size of the grid (height, width). Defaults to None.
            overlap (float): Overlap factor between adjacent grids (0.0 to 1.0). Defaults to 0.5.
        """
        super().__init__()
        self.kernel_size = kernel_size
        self.overlap = overlap

    def grids_list(self, x):
        """Split the input tensor into a list of non-overlapping grid patches.

        Args:
            x (torch.Tensor): Input tensor of shape (batch, channels, height, width).

        Returns:
            list[torch.Tensor]: List of tensor patches.
        """
        b, c, h, w = x.shape
        self.original_size = (b, c, h, w)
        assert b == 1
        k1, k2 = self.kernel_size
        if h < k1:
            k1 = h
        if w < k2:
            k2 = w
        num_row = (h - 1) // k1 + 1
        num_col = (w - 1) // k2 + 1
        self.nr = num_row
        self.nc = num_col

        step_j = k2 if num_col == 1 else math.ceil(k2 * self.overlap)
        step_i = k1 if num_row == 1 else math.ceil(k1 * self.overlap)
        parts = []
        idxes = []
        i = 0
        last_i = False
        while i < h and not last_i:
            j = 0
            if i + k1 >= h:
                i = h - k1
                last_i = True
            last_j = False
            while j < w and not last_j:
                if j + k2 >= w:
                    j = w - k2
                    last_j = True
                parts.append(x[:, :, i : i + k1, j : j + k2])
                idxes.append({"i": i, "j": j})
                j = j + step_j
            i = i + step_i
        return parts

    def grids(self, x):
        """Split the input tensor into overlapping grid patches and concatenate them.

        Args:
            x (torch.Tensor): Input tensor of shape (batch, channels, height, width).

        Returns:
            torch.Tensor: Concatenated tensor of all grid patches.
        """
        b, c, h, w = x.shape
        self.original_size = (b, c, h, w)
        assert b == 1
        k1, k2 = self.kernel_size
        if h < k1:
            k1 = h
        if w < k2:
            k2 = w
        self.tile_weights = self._gaussian_weights(k2, k1, x.device)
        num_row = (h - 1) // k1 + 1
        num_col = (w - 1) // k2 + 1
        self.nr = num_row
        self.nc = num_col

        step_j = k2 if num_col == 1 else math.ceil(k2 * self.overlap)
        step_i = k1 if num_row == 1 else math.ceil(k1 * self.overlap)
        parts = []
        idxes = []
        i = 0
        last_i = False
        while i < h and not last_i:
            j = 0
            if i + k1 >= h:
                i = h - k1
                last_i = True
            last_j = False
            while j < w and not last_j:
                if j + k2 >= w:
                    j = w - k2
                    last_j = True
                parts.append(x[:, :, i : i + k1, j : j + k2])
                idxes.append({"i": i, "j": j})
                j = j + step_j
            i = i + step_i
        self.idxes = idxes
        return torch.cat(parts, dim=0)

    def _gaussian_weights(self, tile_width, tile_height, device=None):
        """Generate a Gaussian weight mask for tile contributions.

        Args:
            tile_width (int): Width of the tile.
            tile_height (int): Height of the tile.
            device (torch.device, optional): Device for the returned mask. Defaults to CUDA,
                as in the original implementation.

        Returns:
            torch.Tensor: Gaussian weight tensor of shape (channels, height, width).
        """
        latent_width = tile_width
        latent_height = tile_height
        var = 0.01
        midpoint = (latent_width - 1) / 2
        x_probs = [
            np.exp(-(x - midpoint) * (x - midpoint) / (latent_width * latent_width) / (2 * var))
            / np.sqrt(2 * np.pi * var)
            for x in range(latent_width)
        ]
        midpoint = latent_height / 2
        y_probs = [
            np.exp(-(y - midpoint) * (y - midpoint) / (latent_height * latent_height) / (2 * var))
            / np.sqrt(2 * np.pi * var)
            for y in range(latent_height)
        ]
        weights = np.outer(y_probs, x_probs)
        if device is None:
            device = torch.device("cuda")
        # Cast the float64 numpy mask to float32 so it can be accumulated into the
        # float32 buffers used in `grids_inverse`.
        return torch.tile(torch.tensor(weights, device=device, dtype=torch.float32), (4, 1, 1))

    def grids_inverse(self, outs):
        """Reconstruct the original tensor from processed grid patches with overlap blending.

        Args:
            outs (torch.Tensor): Processed grid patches.

        Returns:
            torch.Tensor: Reconstructed tensor of original size.
        """
        preds = torch.zeros(self.original_size).to(outs.device)
        b, c, h, w = self.original_size
        count_mt = torch.zeros((b, 4, h, w)).to(outs.device)
        k1, k2 = self.kernel_size

        for cnt, each_idx in enumerate(self.idxes):
            i = each_idx["i"]
            j = each_idx["j"]
            preds[0, :, i : i + k1, j : j + k2] += outs[cnt, :, :, :] * self.tile_weights
            count_mt[0, :, i : i + k1, j : j + k2] += self.tile_weights

        del outs
        torch.cuda.empty_cache()
        return preds / count_mt

    def _pad(self, x):
        """Pad the input tensor to align with kernel size.

        Args:
            x (torch.Tensor): Input tensor of shape (batch, channels, height, width).

        Returns:
            tuple: Padded tensor and padding values.
        """
        b, c, h, w = x.shape
        k1, k2 = self.kernel_size
        mod_pad_h = (k1 - h % k1) % k1
        mod_pad_w = (k2 - w % k2) % k2
        pad = (mod_pad_w // 2, mod_pad_w - mod_pad_w // 2, mod_pad_h // 2, mod_pad_h - mod_pad_h // 2)
        x = F.pad(x, pad, "reflect")
        return x, pad

    def forward(self, x):
        """Apply local attention by splitting into grids and reconstructing.

        Args:
            x (torch.Tensor): Input tensor of shape (batch, channels, height, width).

        Returns:
            torch.Tensor: Processed tensor of original size.
        """
        b, c, h, w = x.shape
        qkv = self.grids(x)
        out = self.grids_inverse(qkv)
        return out
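
# Usage sketch for the tiled helper (illustrative; a 4-channel latent is assumed
# because `grids_inverse` hardcodes 4 channels for its normalization buffer):
#
#     tiler = LocalAttention(kernel_size=(128, 128), overlap=0.5)
#     latents = torch.randn(1, 4, 256, 256, device="cuda")
#     tiles = tiler.grids(latents)         # (num_tiles, 4, 128, 128)
#     # ...process each tile here, e.g. run the denoiser on `tiles`...
#     merged = tiler.grids_inverse(tiles)  # back to (1, 4, 256, 256), Gaussian-blended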

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). See Section 3.4

    Args:
        noise_cfg (torch.Tensor): Noise configuration tensor.
        noise_pred_text (torch.Tensor): Predicted noise from text-conditioned model.
        guidance_rescale (float): Rescaling factor for guidance. Defaults to 0.0.

    Returns:
        torch.Tensor: Rescaled noise configuration.
    """
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg
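
# Typical call site (illustrative): `rescale_noise_cfg` is applied right after the
# classifier-free-guidance combination, with variable names assumed here:
#
#     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
#     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
#     if guidance_rescale > 0.0:
#         noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)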

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Retrieve latents from an encoder output.

    Args:
        encoder_output (torch.Tensor): Output from an encoder (e.g., VAE).
        generator (torch.Generator, optional): Random generator for sampling. Defaults to None.
        sample_mode (str): Sampling mode ("sample" or "argmax"). Defaults to "sample".

    Returns:
        torch.Tensor: Retrieved latent tensor.
    """
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")
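
# Usage sketch (illustrative): works with an `AutoencoderKL.encode(...)` output,
# which carries a `latent_dist`, or with any object exposing `.latents`:
#
#     enc_out = vae.encode(image_tensor)
#     latents = retrieve_latents(enc_out, generator=generator) * vae.config.scaling_factor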

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
            must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
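
# Usage sketch (illustrative):
#
#     timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=20, device="cuda")
#     # or, for schedulers whose `set_timesteps` accepts a custom schedule:
#     # timesteps, num_inference_steps = retrieve_timesteps(scheduler, timesteps=[999, 749, 499, 249], device="cuda")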

class FaithDiffStableDiffusionXLPipeline(
    DiffusionPipeline,
    StableDiffusionMixin,
    FromSingleFileMixin,
    StableDiffusionXLLoraLoaderMixin,
    TextualInversionLoaderMixin,
    IPAdapterMixin,
):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion XL.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    The pipeline also inherits the following loading methods:
        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
        - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
        - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion XL uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        text_encoder_2 ([`CLIPTextModelWithProjection`]):
            Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
            specifically the
            [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
            variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        tokenizer_2 (`CLIPTokenizer`):
            Second Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`):
            Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
            `stabilityai/stable-diffusion-xl-base-1.0`.
        add_watermarker (`bool`, *optional*):
            Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
            watermark output images. If not defined, it will default to `True` if the package is installed; otherwise
            no watermarker will be used.
    """

    unet_model = UNet2DConditionModel
    model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
    _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2", "feature_extractor", "unet"]
    _callback_tensor_inputs = [
        "latents",
        "prompt_embeds",
        "negative_prompt_embeds",
        "add_text_embeds",
        "add_time_ids",
        "negative_pooled_prompt_embeds",
        "negative_add_time_ids",
    ]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        text_encoder_2: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        tokenizer_2: CLIPTokenizer,
        unet: OriginalUNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        force_zeros_for_empty_prompt: bool = True,
        add_watermarker: Optional[bool] = None,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            unet=unet,
            scheduler=scheduler,
        )
        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.DDPMScheduler = DDPMScheduler.from_config(self.scheduler.config, subfolder="scheduler")
        self.default_sample_size = self.unet.config.sample_size if unet is not None else 128
        add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()

        if add_watermarker:
            self.watermark = StableDiffusionXLWatermarker()
        else:
            self.watermark = None
def encode_prompt(
|
|
self,
|
|
prompt: str,
|
|
prompt_2: Optional[str] = None,
|
|
device: Optional[torch.device] = None,
|
|
num_images_per_prompt: int = 1,
|
|
do_classifier_free_guidance: bool = True,
|
|
negative_prompt: Optional[str] = None,
|
|
negative_prompt_2: Optional[str] = None,
|
|
prompt_embeds: Optional[torch.FloatTensor] = None,
|
|
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
|
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
|
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
|
lora_scale: Optional[float] = None,
|
|
clip_skip: Optional[int] = None,
|
|
):
|
|
r"""
|
|
Encodes the prompt into text encoder hidden states.
|
|
|
|
Args:
|
|
prompt (`str` or `List[str]`, *optional*):
|
|
prompt to be encoded
|
|
prompt_2 (`str` or `List[str]`, *optional*):
|
|
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
|
used in both text-encoders
|
|
device: (`torch.device`):
|
|
torch device
|
|
num_images_per_prompt (`int`):
|
|
number of images that should be generated per prompt
|
|
do_classifier_free_guidance (`bool`):
|
|
whether to use classifier free guidance or not
|
|
negative_prompt (`str` or `List[str]`, *optional*):
|
|
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
|
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
|
less than `1`).
|
|
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
|
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
|
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
|
|
prompt_embeds (`torch.FloatTensor`, *optional*):
|
|
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
|
provided, text embeddings will be generated from `prompt` input argument.
|
|
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
|
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
|
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
|
argument.
|
|
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
|
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
|
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
|
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
|
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
|
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
|
input argument.
|
|
lora_scale (`float`, *optional*):
|
|
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
|
clip_skip (`int`, *optional*):
|
|
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
|
the output of the pre-final layer will be used for computing the prompt embeddings.
|
|
"""
|
|
device = "cuda" # device or self._execution_device
|
|
|
|
# set lora scale so that monkey patched LoRA
|
|
# function of text encoder can correctly access it
|
|
if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
|
|
self._lora_scale = lora_scale
|
|
|
|
# dynamically adjust the LoRA scale
|
|
if self.text_encoder is not None:
|
|
if not USE_PEFT_BACKEND:
|
|
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
|
else:
|
|
scale_lora_layers(self.text_encoder, lora_scale)
|
|
|
|
if self.text_encoder_2 is not None:
|
|
if not USE_PEFT_BACKEND:
|
|
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
|
else:
|
|
scale_lora_layers(self.text_encoder_2, lora_scale)
|
|
|
|
prompt = [prompt] if isinstance(prompt, str) else prompt
|
|
|
|
if prompt is not None:
|
|
batch_size = len(prompt)
|
|
else:
|
|
batch_size = prompt_embeds.shape[0]
|
|
|
|
# Define tokenizers and text encoders
|
|
tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
|
|
text_encoders = (
|
|
[self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
|
|
)
        dtype = text_encoders[0].dtype
        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            # textual inversion: process multi-vector tokens if necessary
            prompt_embeds_list = []
            prompts = [prompt, prompt_2]
            for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
                if isinstance(self, TextualInversionLoaderMixin):
                    prompt = self.maybe_convert_prompt(prompt, tokenizer)

                text_inputs = tokenizer(
                    prompt,
                    padding="max_length",
                    max_length=tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                )

                text_input_ids = text_inputs.input_ids
                untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                    text_input_ids, untruncated_ids
                ):
                    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
                    logger.warning(
                        "The following part of your input was truncated because CLIP can only handle sequences up to"
                        f" {tokenizer.model_max_length} tokens: {removed_text}"
                    )
                text_encoder = text_encoder.to(dtype)
                prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

                # We are always interested only in the pooled output of the final text encoder
                pooled_prompt_embeds = prompt_embeds[0]
                if clip_skip is None:
                    prompt_embeds = prompt_embeds.hidden_states[-2]
                else:
                    # "2" because SDXL always indexes from the penultimate layer.
                    prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]

                prompt_embeds_list.append(prompt_embeds)

            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

        # get unconditional embeddings for classifier free guidance
        zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
        if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
            negative_prompt_embeds = torch.zeros_like(prompt_embeds)
            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
        elif do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt_2 = negative_prompt_2 or negative_prompt

            # normalize str to list
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
            negative_prompt_2 = (
                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
            )

            uncond_tokens: List[str]
            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that the passed `negative_prompt`"
                    " matches the batch size of `prompt`."
                )
            else:
                uncond_tokens = [negative_prompt, negative_prompt_2]

            negative_prompt_embeds_list = []
            for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
                if isinstance(self, TextualInversionLoaderMixin):
                    negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)

                max_length = prompt_embeds.shape[1]
                uncond_input = tokenizer(
                    negative_prompt,
                    padding="max_length",
                    max_length=max_length,
                    truncation=True,
                    return_tensors="pt",
                )

                negative_prompt_embeds = text_encoder(
                    uncond_input.input_ids.to(device),
                    output_hidden_states=True,
                )
                # We are always interested only in the pooled output of the final text encoder
                negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

                negative_prompt_embeds_list.append(negative_prompt_embeds)

            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

        if self.text_encoder_2 is not None:
            prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
        else:
            prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            if self.text_encoder_2 is not None:
                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
            else:
                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
            bs_embed * num_images_per_prompt, -1
        )
        if do_classifier_free_guidance:
            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
                bs_embed * num_images_per_prompt, -1
            )

        if self.text_encoder is not None:
            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        if self.text_encoder_2 is not None:
            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder_2, lora_scale)

        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
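
    # NOTE: a minimal usage sketch for `encode_prompt`, assuming a fully loaded pipeline `pipe`
    # (it mirrors the keyword arguments used by `__call__` below):
    #   embeds, neg_embeds, pooled, neg_pooled = pipe.encode_prompt(
    #       "a photo of a cat",
    #       num_images_per_prompt=1,
    #       do_classifier_free_guidance=True,
    #       negative_prompt="blurry",
    #   )
    # `embeds` concatenates the penultimate hidden states of both CLIP text encoders along the
    # feature dimension; `pooled` comes from the final (second) text encoder only.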

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs
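
    # For example, `DDIMScheduler.step` accepts both `eta` and `generator`, so both keys are kept,
    # while schedulers whose `step` accepts neither receive an empty dict. The same introspection,
    # as a standalone sketch:
    #   accepts_eta = "eta" in inspect.signature(scheduler.step).parameters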

    def check_image_size(self, x, padder_size=8):
        # get the image width and height
        width, height = x.size
        # compute how much padding is needed to make each side a multiple of `padder_size`
        mod_pad_h = (padder_size - height % padder_size) % padder_size
        mod_pad_w = (padder_size - width % padder_size) % padder_size
        x_np = np.array(x)
        # pad by replicating the border pixels
        x_padded = cv2.copyMakeBorder(
            x_np, top=0, bottom=mod_pad_h, left=0, right=mod_pad_w, borderType=cv2.BORDER_REPLICATE
        )

        x = PIL.Image.fromarray(x_padded)
        # x = x.resize((width + mod_pad_w, height + mod_pad_h))

        return x, width, height, width + mod_pad_w, height + mod_pad_h
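
    # Worked example: a 500x375 input with padder_size=8 gives
    #   mod_pad_w = (8 - 500 % 8) % 8 = 4 and mod_pad_h = (8 - 375 % 8) % 8 = 1,
    # so the image is edge-padded to 504x376. The original (width, height) is returned alongside
    # the padded size so callers can crop the output back afterwards.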

    def check_inputs(
        self,
        lr_img,
        prompt,
        prompt_2,
        height,
        width,
        callback_steps,
        negative_prompt=None,
        negative_prompt_2=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        pooled_prompt_embeds=None,
        negative_pooled_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
    ):
        if lr_img is None:
            raise ValueError("`lr_img` must be provided!")
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt_2 is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )
        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

        if prompt_embeds is not None and pooled_prompt_embeds is None:
            raise ValueError(
                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
            )

        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
            raise ValueError(
                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
            )

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents
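
    # Shape sketch: for a 1024x1024 target with `vae_scale_factor=8` and 4 latent channels, the
    # initial latents have shape (batch_size, 4, 128, 128) and are pre-scaled by
    # `scheduler.init_noise_sigma` (1.0 for DDPM-style schedulers, the largest sigma for
    # Karras-style ones).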

    def upcast_vae(self):
        deprecate("upcast_vae", "1.0.0", "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`")
        self.vae.to(dtype=torch.float32)

    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
    def get_guidance_scale_embedding(
        self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
    ) -> torch.FloatTensor:
        """
        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

        Args:
            w (`torch.Tensor`):
                Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
            embedding_dim (`int`, *optional*, defaults to 512):
                Dimension of the embeddings to generate.
            dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
                Data type of the generated embeddings.

        Returns:
            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
        """
        assert len(w.shape) == 1
        w = w * 1000.0

        half_dim = embedding_dim // 2
        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
        emb = w.to(dtype)[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if embedding_dim % 2 == 1:  # zero pad
            emb = torch.nn.functional.pad(emb, (0, 1))
        assert emb.shape == (w.shape[0], embedding_dim)
        return emb
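
    # Shape sketch: for `w = torch.tensor([5.0])` and `embedding_dim=256`, the result has shape
    # (1, 256): 128 sine features followed by 128 cosine features over log-spaced frequencies,
    # i.e. the usual sinusoidal timestep-embedding construction applied to 1000 * w.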

    def set_encoder_tile_settings(
        self,
        denoise_encoder_tile_sample_min_size=1024,
        denoise_encoder_sample_overlap_factor=0.25,
        vae_sample_size=1024,
        vae_tile_overlap_factor=0.25,
    ):
        self.unet.denoise_encoder.tile_sample_min_size = denoise_encoder_tile_sample_min_size
        self.unet.denoise_encoder.tile_overlap_factor = denoise_encoder_sample_overlap_factor
        self.vae.config.sample_size = vae_sample_size
        self.vae.tile_overlap_factor = vae_tile_overlap_factor
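
    # Usage sketch (values are illustrative): shrink the tile size to trade speed for memory
    # before processing very large inputs:
    #   pipe.set_encoder_tile_settings(denoise_encoder_tile_sample_min_size=512, vae_sample_size=512)
    #   pipe.enable_vae_tiling()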

    def enable_vae_tiling(self):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and for
        processing larger images.
        """
        depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
        deprecate(
            "enable_vae_tiling",
            "0.40.0",
            depr_message,
        )
        self.vae.enable_tiling()
        self.unet.denoise_encoder.enable_tiling()

    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
        deprecate(
            "disable_vae_tiling",
            "0.40.0",
            depr_message,
        )
        self.vae.disable_tiling()
        self.unet.denoise_encoder.disable_tiling()

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def guidance_rescale(self):
        return self._guidance_rescale

    @property
    def clip_skip(self):
        return self._clip_skip

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None

    @property
    def cross_attention_kwargs(self):
        return self._cross_attention_kwargs

    @property
    def denoising_end(self):
        return self._denoising_end

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    def prepare_image_latents(
        self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None
    ):
        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
            raise ValueError(
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
            )

        image = image.to(device=device, dtype=dtype)

        batch_size = batch_size * num_images_per_prompt

        if image.shape[1] == 4:
            image_latents = image
        else:
            # make sure the VAE is in float32 mode, as it overflows in float16
            # needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
            # if needs_upcasting:
            #     image = image.float()
            #     self.upcast_vae()
            self.unet.denoise_encoder.to(device=image.device, dtype=image.dtype)
            image_latents = self.unet.denoise_encoder(image)
            self.unet.denoise_encoder.to("cpu")
            # cast back to fp16 if needed
            # if needs_upcasting:
            #     self.vae.to(dtype=torch.float16)

        if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
            # expand image_latents for batch_size
            deprecation_message = (
                f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial"
                " images (`image`). Initial images are now duplicated to match the number of text prompts. Note"
                " that this behavior is deprecated and will be removed in version 1.0.0. Please make sure to update"
                " your script to pass as many initial images as text prompts to suppress this warning."
            )
            deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
            additional_image_per_prompt = batch_size // image_latents.shape[0]
            image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
        elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            image_latents = torch.cat([image_latents], dim=0)

        if do_classifier_free_guidance:
            # no expansion is needed here: the image latents are duplicated per tile for the
            # unconditional/conditional pair inside the denoising loop
            image_latents = image_latents

        if image_latents.dtype != self.vae.dtype:
            image_latents = image_latents.to(dtype=self.vae.dtype)

        return image_latents
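
    # Flow sketch: a preprocessed low-resolution tensor of shape (batch, 3, H, W) is mapped by the
    # UNet's denoise encoder to conditioning latents, while a 4-channel input is assumed to already
    # be in latent space and is passed through unchanged.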

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        lr_img: PipelineImageInput = None,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        start_point: Optional[str] = "noise",
        timesteps: List[int] = None,
        denoising_end: Optional[float] = None,
        overlap: float = 0.5,
        guidance_scale: float = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        original_size: Optional[Tuple[int, int]] = None,
        target_size: Optional[Tuple[int, int]] = None,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        add_sample: bool = True,
        **kwargs,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            lr_img (`PipelineImageInput`, *optional*):
                The low-resolution input image used to condition the generation process.
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                used in both text-encoders.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. This is set to 1024 by default for the best results.
                Anything below 512 pixels won't work well for
                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
                and checkpoints that are not specifically fine-tuned on low resolutions.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. This is set to 1024 by default for the best results.
                Anything below 512 pixels won't work well for
                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
                and checkpoints that are not specifically fine-tuned on low resolutions.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            start_point (`str`, *optional*, defaults to `"noise"`):
                The starting point for the generation process. Can be `"noise"` (random noise) or `"lr"` (the noised
                latent of the low-resolution image).
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            denoising_end (`float`, *optional*):
                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
                completed before it is intentionally prematurely terminated. As a result, the returned sample will
                still retain a substantial amount of noise as determined by the discrete timesteps selected by the
                scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
                "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
            overlap (`float`, *optional*, defaults to 0.5):
                Overlap factor for local attention tiling (between 0.0 and 1.0). Controls the overlap between adjacent
                grid patches during processing.
            guidance_scale (`float`, *optional*, defaults to 5.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages the model to generate images that are closely linked to the text
                `prompt`, usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            negative_prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
                of a plain tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            guidance_rescale (`float`, *optional*, defaults to 0.0):
                Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
                Flawed](https://huggingface.co/papers/2305.08891). `guidance_rescale` is defined as `φ` in equation 16.
                of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891).
                Guidance rescale factor should fix overexposure when using zero terminal SNR.
            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
                `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
                explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
                For most cases, `target_size` should be set to the desired height and width of the generated image. If
                not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            callback_on_step_end (`Callable`, *optional*):
                A function that is called at the end of each denoising step during inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            add_sample (`bool`, *optional*, defaults to `True`):
                Whether to include the sample conditioning (e.g., the low-resolution image) in the UNet during
                denoising.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images.
        """

        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )

        # 0. Default height and width to unet
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        original_size = original_size or (height, width)
        target_size = target_size or (height, width)

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            lr_img,
            prompt,
            prompt_2,
            height,
            width,
            callback_steps,
            negative_prompt,
            negative_prompt_2,
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs,
        )
        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._denoising_end = denoising_end
        self._interrupt = False
        self.tlc_vae_latents = LocalAttention((target_size[0] // 8, target_size[1] // 8), overlap)
        self.tlc_vae_img = LocalAttention((target_size[0] // 8, target_size[1] // 8), overlap)

        # 2. Define call parameters
        batch_size = 1
        num_images_per_prompt = 1

        device = self._execution_device

        # 3. Encode input prompt
        lora_scale = (
            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
        )

        num_samples = num_images_per_prompt
        with torch.inference_mode():
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = self.encode_prompt(
                prompt,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
                lora_scale=lora_scale,
            )

        lr_img_list = [lr_img]
        lr_img = self.image_processor.preprocess(lr_img_list, height=height, width=width).to(
            device, dtype=prompt_embeds.dtype
        )

        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
        image_latents = self.prepare_image_latents(
            lr_img, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, self.do_classifier_free_guidance
        )

        image_latents = self.tlc_vae_img.grids(image_latents)

        # 5. Prepare latent variables
        num_channels_latents = self.vae.config.latent_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )
        if start_point == "lr":
            latents_condition_image = self.vae.encode(lr_img * 2 - 1).latent_dist.sample()
            latents_condition_image = latents_condition_image * self.vae.config.scaling_factor
            start_steps_tensor = torch.randint(999, 999 + 1, (latents.shape[0],), device=latents.device)
            start_steps_tensor = start_steps_tensor.long()
            latents = self.DDPMScheduler.add_noise(latents_condition_image[0:1, ...], latents, start_steps_tensor)
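            # i.e., instead of starting from pure noise, start from the VAE-encoded low-resolution
            # image noised to the last training timestep (t=999), which preserves its low-frequency
            # structure in the initial latents.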

        latents = self.tlc_vae_latents.grids(latents)
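        # Split the full latent map into overlapping tiles; each tile is denoised independently
        # below and the grid is stitched back together via `grids_inverse` after every step.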

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
        views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__) for _ in range(image_latents.shape[0])]
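        # Each tile keeps its own snapshot of the scheduler state so that stateful schedulers
        # (e.g., multistep solvers that cache previous model outputs) remain consistent per tile.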

        # 7. Prepare added time ids & embeddings
        add_text_embeds = pooled_prompt_embeds

        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)

        # 8. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        # 8.1 Apply denoising_end
        if (
            self.denoising_end is not None
            and isinstance(self.denoising_end, float)
            and self.denoising_end > 0
            and self.denoising_end < 1
        ):
            discrete_timestep_cutoff = int(
                round(
                    self.scheduler.config.num_train_timesteps
                    - (self.denoising_end * self.scheduler.config.num_train_timesteps)
                )
            )
            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
            timesteps = timesteps[:num_inference_steps]
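            # E.g., with denoising_end=0.8 and num_train_timesteps=1000 the cutoff is
            # 1000 - 0.8 * 1000 = 200, so only timesteps >= 200 (the first 80% of the
            # denoising trajectory) are executed.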

        # 9. Optionally get Guidance Scale Embedding
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

        self._num_timesteps = len(timesteps)
        sub_latents_num = latents.shape[0]
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if i >= 1:
                    latents = self.tlc_vae_latents.grids(latents).to(dtype=latents.dtype)
                if self.interrupt:
                    continue
                concat_grid = []
                for sub_num in range(sub_latents_num):
                    self.scheduler.__dict__.update(views_scheduler_status[sub_num])
                    sub_latents = latents[sub_num, :, :, :].unsqueeze(0)
                    img_sub_latents = image_latents[sub_num, :, :, :].unsqueeze(0)
                    latent_model_input = (
                        torch.cat([sub_latents] * 2) if self.do_classifier_free_guidance else sub_latents
                    )
                    img_sub_latents = (
                        torch.cat([img_sub_latents] * 2) if self.do_classifier_free_guidance else img_sub_latents
                    )
                    scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                    pos_height = self.tlc_vae_latents.idxes[sub_num]["i"]
                    pos_width = self.tlc_vae_latents.idxes[sub_num]["j"]
                    add_time_ids = [
                        torch.tensor([original_size]),
                        torch.tensor([[pos_height, pos_width]]),
                        torch.tensor([target_size]),
                    ]
                    add_time_ids = torch.cat(add_time_ids, dim=1).to(
                        img_sub_latents.device, dtype=img_sub_latents.dtype
                    )
                    add_time_ids = add_time_ids.repeat(2, 1).to(dtype=img_sub_latents.dtype)
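                    # SDXL micro-conditioning: the usual crop coordinates are replaced by the
                    # tile's (top, left) position in the latent grid, so every tile is conditioned
                    # on where it sits inside the full image; the ids are duplicated for the
                    # unconditional/conditional pair.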

                    # predict the noise residual
                    added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
                    with torch.amp.autocast(
                        device.type, dtype=latents.dtype, enabled=latents.dtype != self.unet.dtype
                    ):
                        noise_pred = self.unet(
                            scaled_latent_model_input,
                            t,
                            encoder_hidden_states=prompt_embeds,
                            timestep_cond=timestep_cond,
                            cross_attention_kwargs=self.cross_attention_kwargs,
                            input_embedding=img_sub_latents,
                            add_sample=add_sample,
                            added_cond_kwargs=added_cond_kwargs,
                            return_dict=False,
                        )[0]

                    # perform guidance
                    if self.do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                    if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                        # Based on 3.4. in https://huggingface.co/papers/2305.08891
                        noise_pred = rescale_noise_cfg(
                            noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale
                        )

                    # compute the previous noisy sample x_t -> x_t-1
                    latents_dtype = sub_latents.dtype
                    sub_latents = self.scheduler.step(
                        noise_pred, t, sub_latents, **extra_step_kwargs, return_dict=False
                    )[0]

                    views_scheduler_status[sub_num] = copy.deepcopy(self.scheduler.__dict__)
                    concat_grid.append(sub_latents)
                    if sub_latents.dtype != latents_dtype:
                        if torch.backends.mps.is_available():
                            # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                            sub_latents = sub_latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                    add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
                    negative_pooled_prompt_embeds = callback_outputs.pop(
                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
                    )
                    add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

                if XLA_AVAILABLE:
                    xm.mark_step()

                latents = self.tlc_vae_latents.grids_inverse(torch.cat(concat_grid, dim=0)).to(sub_latents.dtype)

        if not output_type == "latent":
            # make sure the VAE is in float32 mode, as it overflows in float16
            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

            if needs_upcasting:
                self.upcast_vae()
                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
            elif latents.dtype != self.vae.dtype:
                if torch.backends.mps.is_available():
                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                    self.vae = self.vae.to(latents.dtype)

            # unscale/denormalize the latents
            # denormalize with the mean and std if available and not None
            has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
            has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
            if has_latents_mean and has_latents_std:
                latents_mean = (
                    torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
                )
                latents_std = (
                    torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
                )
                latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
            else:
                latents = latents / self.vae.config.scaling_factor

            image = self.vae.decode(latents, return_dict=False)[0]

            # cast back to fp16 if needed
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)
        else:
            image = latents

        if not output_type == "latent":
            # apply watermark if available
            if self.watermark is not None:
                image = self.watermark.apply_watermark(image)

            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return StableDiffusionXLPipelineOutput(images=image)
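
    # A minimal call sketch (assuming a loaded pipeline `pipe` and a low-resolution PIL image
    # `img`; see EXAMPLE_DOC_STRING above for a complete setup):
    #   out = pipe(lr_img=img, prompt="a detailed photo", height=1024, width=1024,
    #              num_inference_steps=50, guidance_scale=5.0, start_point="noise")
    #   out.images[0].save("sr.png")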