1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Community] Support StableDiffusionCanvasPipeline (#3590)

* added StableDiffusionCanvasPipeline pipeline

* Added utils codes to pipe_utils file.

* make style

* delete mixture.py and Text2ImageRegion class

* make style

* Added the codes to the readme.md file.

* Moved functions from pipeline_utils to mix_canvas
This commit is contained in:
Kadir Nar
2023-06-07 19:43:33 +03:00
committed by GitHub
parent 803d653748
commit cd6186907c
3 changed files with 539 additions and 403 deletions

View File

@@ -1601,7 +1601,7 @@ pipe_images = mixing_pipeline(
![image_mixing_result](https://huggingface.co/datasets/TheDenk/images_mixing/resolve/main/boromir_gigachad.png)
### Stable Diffusion Mixture
### Stable Diffusion Mixture Tiling
This pipeline uses the Mixture. Refer to the [Mixture](https://arxiv.org/abs/2302.02412) paper for more details.
@@ -1672,4 +1672,38 @@ mask_image = Image.open(BytesIO(response.content)).convert("RGB")
prompt = "a mecha robot sitting on a bench"
image = pipe(prompt, image=input_image, mask_image=mask_image, strength=0.75,).images[0]
image.save('tensorrt_inpaint_mecha_robot.png')
```
```
### Stable Diffusion Mixture Canvas
This pipeline uses the Mixture. Refer to the [Mixture](https://arxiv.org/abs/2302.02412) paper for more details.
```python
from PIL import Image
from diffusers import LMSDiscreteScheduler, DiffusionPipeline
from diffusers.pipelines.pipeline_utils import Image2ImageRegion, Text2ImageRegion, preprocess_image
# Load and preprocess guide image
iic_image = preprocess_image(Image.open("input_image.png").convert("RGB"))
# Creater scheduler and model (similar to StableDiffusionPipeline)
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=scheduler).to("cuda:0", custom_pipeline="mixture_canvas")
pipeline.to("cuda")
# Mixture of Diffusers generation
output = pipeline(
canvas_height=800,
canvas_width=352,
regions=[
Text2ImageRegion(0, 800, 0, 352, guidance_scale=8,
prompt=f"best quality, masterpiece, WLOP, sakimichan, art contest winner on pixiv, 8K, intricate details, wet effects, rain drops, ethereal, mysterious, futuristic, UHD, HDR, cinematic lighting, in a beautiful forest, rainy day, award winning, trending on artstation, beautiful confident cheerful young woman, wearing a futuristic sleeveless dress, ultra beautiful detailed eyes, hyper-detailed face, complex, perfect, model,  textured, chiaroscuro, professional make-up, realistic, figure in frame, "),
Image2ImageRegion(352-800, 352, 0, 352, reference_image=iic_image, strength=1.0),
],
num_inference_steps=100,
seed=5525475061,
)["images"][0]
```
![Input_Image](https://huggingface.co/datasets/kadirnar/diffusers_readme_images/resolve/main/input_image.png)
![mixture_canvas_results](https://huggingface.co/datasets/kadirnar/diffusers_readme_images/resolve/main/canvas.png)

View File

@@ -1,401 +0,0 @@
import inspect
from copy import deepcopy
from enum import Enum
from typing import List, Optional, Tuple, Union
import torch
from ligo.segments import segment
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> from diffusers import LMSDiscreteScheduler
>>> from mixdiff import StableDiffusionTilingPipeline
>>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
>>> pipeline = StableDiffusionTilingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=scheduler)
>>> pipeline.to("cuda:0")
>>> image = pipeline(
>>> prompt=[[
>>> "A charming house in the countryside, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece",
>>> "A dirt road in the countryside crossing pastures, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece",
>>> "An old and rusty giant robot lying on a dirt road, by jakub rozalski, dark sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece"
>>> ]],
>>> tile_height=640,
>>> tile_width=640,
>>> tile_row_overlap=0,
>>> tile_col_overlap=256,
>>> guidance_scale=8,
>>> seed=7178915308,
>>> num_inference_steps=50,
>>> )["images"][0]
```
"""
def _tile2pixel_indices(tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap):
"""Given a tile row and column numbers returns the range of pixels affected by that tiles in the overall image
Returns a tuple with:
- Starting coordinates of rows in pixel space
- Ending coordinates of rows in pixel space
- Starting coordinates of columns in pixel space
- Ending coordinates of columns in pixel space
"""
px_row_init = 0 if tile_row == 0 else tile_row * (tile_height - tile_row_overlap)
px_row_end = px_row_init + tile_height
px_col_init = 0 if tile_col == 0 else tile_col * (tile_width - tile_col_overlap)
px_col_end = px_col_init + tile_width
return px_row_init, px_row_end, px_col_init, px_col_end
def _pixel2latent_indices(px_row_init, px_row_end, px_col_init, px_col_end):
"""Translates coordinates in pixel space to coordinates in latent space"""
return px_row_init // 8, px_row_end // 8, px_col_init // 8, px_col_end // 8
def _tile2latent_indices(tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap):
"""Given a tile row and column numbers returns the range of latents affected by that tiles in the overall image
Returns a tuple with:
- Starting coordinates of rows in latent space
- Ending coordinates of rows in latent space
- Starting coordinates of columns in latent space
- Ending coordinates of columns in latent space
"""
px_row_init, px_row_end, px_col_init, px_col_end = _tile2pixel_indices(
tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap
)
return _pixel2latent_indices(px_row_init, px_row_end, px_col_init, px_col_end)
def _tile2latent_exclusive_indices(
tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap, rows, columns
):
"""Given a tile row and column numbers returns the range of latents affected only by that tile in the overall image
Returns a tuple with:
- Starting coordinates of rows in latent space
- Ending coordinates of rows in latent space
- Starting coordinates of columns in latent space
- Ending coordinates of columns in latent space
"""
row_init, row_end, col_init, col_end = _tile2latent_indices(
tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap
)
row_segment = segment(row_init, row_end)
col_segment = segment(col_init, col_end)
# Iterate over the rest of tiles, clipping the region for the current tile
for row in range(rows):
for column in range(columns):
if row != tile_row and column != tile_col:
clip_row_init, clip_row_end, clip_col_init, clip_col_end = _tile2latent_indices(
row, column, tile_width, tile_height, tile_row_overlap, tile_col_overlap
)
row_segment = row_segment - segment(clip_row_init, clip_row_end)
col_segment = col_segment - segment(clip_col_init, clip_col_end)
# return row_init, row_end, col_init, col_end
return row_segment[0], row_segment[1], col_segment[0], col_segment[1]
class StableDiffusionExtrasMixin:
"""Mixin providing additional convenience method to Stable Diffusion pipelines"""
def decode_latents(self, latents, cpu_vae=False):
"""Decodes a given array of latents into pixel space"""
# scale and decode the image latents with vae
if cpu_vae:
lat = deepcopy(latents).cpu()
vae = deepcopy(self.vae).cpu()
else:
lat = latents
vae = self.vae
lat = 1 / 0.18215 * lat
image = vae.decode(lat).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
return self.numpy_to_pil(image)
class StableDiffusionTilingPipeline(DiffusionPipeline, StableDiffusionExtrasMixin):
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
class SeedTilesMode(Enum):
"""Modes in which the latents of a particular tile can be re-seeded"""
FULL = "full"
EXCLUSIVE = "exclusive"
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[List[str]]],
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
eta: Optional[float] = 0.0,
seed: Optional[int] = None,
tile_height: Optional[int] = 512,
tile_width: Optional[int] = 512,
tile_row_overlap: Optional[int] = 256,
tile_col_overlap: Optional[int] = 256,
guidance_scale_tiles: Optional[List[List[float]]] = None,
seed_tiles: Optional[List[List[int]]] = None,
seed_tiles_mode: Optional[Union[str, List[List[str]]]] = "full",
seed_reroll_regions: Optional[List[Tuple[int, int, int, int, int]]] = None,
cpu_vae: Optional[bool] = False,
):
r"""
Function to run the diffusion pipeline with tiling support.
Args:
prompt: either a single string (no tiling) or a list of lists with all the prompts to use (one list for each row of tiles). This will also define the tiling structure.
num_inference_steps: number of diffusions steps.
guidance_scale: classifier-free guidance.
seed: general random seed to initialize latents.
tile_height: height in pixels of each grid tile.
tile_width: width in pixels of each grid tile.
tile_row_overlap: number of overlap pixels between tiles in consecutive rows.
tile_col_overlap: number of overlap pixels between tiles in consecutive columns.
guidance_scale_tiles: specific weights for classifier-free guidance in each tile.
guidance_scale_tiles: specific weights for classifier-free guidance in each tile. If None, the value provided in guidance_scale will be used.
seed_tiles: specific seeds for the initialization latents in each tile. These will override the latents generated for the whole canvas using the standard seed parameter.
seed_tiles_mode: either "full" "exclusive". If "full", all the latents affected by the tile be overriden. If "exclusive", only the latents that are affected exclusively by this tile (and no other tiles) will be overrriden.
seed_reroll_regions: a list of tuples in the form (start row, end row, start column, end column, seed) defining regions in pixel space for which the latents will be overriden using the given seed. Takes priority over seed_tiles.
cpu_vae: the decoder from latent space to pixel space can require too mucho GPU RAM for large images. If you find out of memory errors at the end of the generation process, try setting this parameter to True to run the decoder in CPU. Slower, but should run without memory issues.
Examples:
Returns:
A PIL image with the generated image.
"""
if not isinstance(prompt, list) or not all(isinstance(row, list) for row in prompt):
raise ValueError(f"`prompt` has to be a list of lists but is {type(prompt)}")
grid_rows = len(prompt)
grid_cols = len(prompt[0])
if not all(len(row) == grid_cols for row in prompt):
raise ValueError("All prompt rows must have the same number of prompt columns")
if not isinstance(seed_tiles_mode, str) and (
not isinstance(seed_tiles_mode, list) or not all(isinstance(row, list) for row in seed_tiles_mode)
):
raise ValueError(f"`seed_tiles_mode` has to be a string or list of lists but is {type(prompt)}")
if isinstance(seed_tiles_mode, str):
seed_tiles_mode = [[seed_tiles_mode for _ in range(len(row))] for row in prompt]
modes = [mode.value for mode in self.SeedTilesMode]
if any(mode not in modes for row in seed_tiles_mode for mode in row):
raise ValueError(f"Seed tiles mode must be one of {modes}")
if seed_reroll_regions is None:
seed_reroll_regions = []
batch_size = 1
# create original noisy latents using the timesteps
height = tile_height + (grid_rows - 1) * (tile_height - tile_row_overlap)
width = tile_width + (grid_cols - 1) * (tile_width - tile_col_overlap)
latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8)
generator = torch.Generator("cuda").manual_seed(seed)
latents = torch.randn(latents_shape, generator=generator, device=self.device)
# overwrite latents for specific tiles if provided
if seed_tiles is not None:
for row in range(grid_rows):
for col in range(grid_cols):
if (seed_tile := seed_tiles[row][col]) is not None:
mode = seed_tiles_mode[row][col]
if mode == self.SeedTilesMode.FULL.value:
row_init, row_end, col_init, col_end = _tile2latent_indices(
row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap
)
else:
row_init, row_end, col_init, col_end = _tile2latent_exclusive_indices(
row,
col,
tile_width,
tile_height,
tile_row_overlap,
tile_col_overlap,
grid_rows,
grid_cols,
)
tile_generator = torch.Generator("cuda").manual_seed(seed_tile)
tile_shape = (latents_shape[0], latents_shape[1], row_end - row_init, col_end - col_init)
latents[:, :, row_init:row_end, col_init:col_end] = torch.randn(
tile_shape, generator=tile_generator, device=self.device
)
# overwrite again for seed reroll regions
for row_init, row_end, col_init, col_end, seed_reroll in seed_reroll_regions:
row_init, row_end, col_init, col_end = _pixel2latent_indices(
row_init, row_end, col_init, col_end
) # to latent space coordinates
reroll_generator = torch.Generator("cuda").manual_seed(seed_reroll)
region_shape = (latents_shape[0], latents_shape[1], row_end - row_init, col_end - col_init)
latents[:, :, row_init:row_end, col_init:col_end] = torch.randn(
region_shape, generator=reroll_generator, device=self.device
)
# Prepare scheduler
accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
extra_set_kwargs = {}
if accepts_offset:
extra_set_kwargs["offset"] = 1
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
# if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
if isinstance(self.scheduler, LMSDiscreteScheduler):
latents = latents * self.scheduler.sigmas[0]
# get prompts text embeddings
text_input = [
[
self.tokenizer(
col,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
for col in row
]
for row in prompt
]
text_embeddings = [[self.text_encoder(col.input_ids.to(self.device))[0] for col in row] for row in text_input]
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0 # TODO: also active if any tile has guidance scale
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
for i in range(grid_rows):
for j in range(grid_cols):
max_length = text_input[i][j].input_ids.shape[-1]
uncond_input = self.tokenizer(
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings[i][j] = torch.cat([uncond_embeddings, text_embeddings[i][j]])
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# Mask for tile weights strenght
tile_weights = self._gaussian_weights(tile_width, tile_height, batch_size)
# Diffusion timesteps
for i, t in tqdm(enumerate(self.scheduler.timesteps)):
# Diffuse each tile
noise_preds = []
for row in range(grid_rows):
noise_preds_row = []
for col in range(grid_cols):
px_row_init, px_row_end, px_col_init, px_col_end = _tile2latent_indices(
row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap
)
tile_latents = latents[:, :, px_row_init:px_row_end, px_col_init:px_col_end]
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([tile_latents] * 2) if do_classifier_free_guidance else tile_latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings[row][col])[
"sample"
]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guidance = (
guidance_scale
if guidance_scale_tiles is None or guidance_scale_tiles[row][col] is None
else guidance_scale_tiles[row][col]
)
noise_pred_tile = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond)
noise_preds_row.append(noise_pred_tile)
noise_preds.append(noise_preds_row)
# Stitch noise predictions for all tiles
noise_pred = torch.zeros(latents.shape, device=self.device)
contributors = torch.zeros(latents.shape, device=self.device)
# Add each tile contribution to overall latents
for row in range(grid_rows):
for col in range(grid_cols):
px_row_init, px_row_end, px_col_init, px_col_end = _tile2latent_indices(
row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap
)
noise_pred[:, :, px_row_init:px_row_end, px_col_init:px_col_end] += (
noise_preds[row][col] * tile_weights
)
contributors[:, :, px_row_init:px_row_end, px_col_init:px_col_end] += tile_weights
# Average overlapping areas with more than 1 contributor
noise_pred /= contributors
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
# scale and decode the image latents with vae
image = self.decode_latents(latents, cpu_vae)
return {"images": image}
def _gaussian_weights(self, tile_width, tile_height, nbatches):
"""Generates a gaussian mask of weights for tile contributions"""
import numpy as np
from numpy import exp, pi, sqrt
latent_width = tile_width // 8
latent_height = tile_height // 8
var = 0.01
midpoint = (latent_width - 1) / 2 # -1 because index goes from 0 to latent_width - 1
x_probs = [
exp(-(x - midpoint) * (x - midpoint) / (latent_width * latent_width) / (2 * var)) / sqrt(2 * pi * var)
for x in range(latent_width)
]
midpoint = latent_height / 2
y_probs = [
exp(-(y - midpoint) * (y - midpoint) / (latent_height * latent_height) / (2 * var)) / sqrt(2 * pi * var)
for y in range(latent_height)
]
weights = np.outer(y_probs, x_probs)
return torch.tile(torch.tensor(weights, device=self.device), (nbatches, self.unet.config.in_channels, 1, 1))

View File

@@ -0,0 +1,503 @@
import re
from copy import deepcopy
from dataclasses import asdict, dataclass
from enum import Enum
from typing import List, Optional, Union
import numpy as np
import torch
from numpy import exp, pi, sqrt
from torchvision.transforms.functional import resize
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
def preprocess_image(image):
from PIL import Image
"""Preprocess an input image
Same as
https://github.com/huggingface/diffusers/blob/1138d63b519e37f0ce04e027b9f4a3261d27c628/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py#L44
"""
w, h = image.size
w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=Image.LANCZOS)
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return 2.0 * image - 1.0
@dataclass
class CanvasRegion:
"""Class defining a rectangular region in the canvas"""
row_init: int # Region starting row in pixel space (included)
row_end: int # Region end row in pixel space (not included)
col_init: int # Region starting column in pixel space (included)
col_end: int # Region end column in pixel space (not included)
region_seed: int = None # Seed for random operations in this region
noise_eps: float = 0.0 # Deviation of a zero-mean gaussian noise to be applied over the latents in this region. Useful for slightly "rerolling" latents
def __post_init__(self):
# Initialize arguments if not specified
if self.region_seed is None:
self.region_seed = np.random.randint(9999999999)
# Check coordinates are non-negative
for coord in [self.row_init, self.row_end, self.col_init, self.col_end]:
if coord < 0:
raise ValueError(
f"A CanvasRegion must be defined with non-negative indices, found ({self.row_init}, {self.row_end}, {self.col_init}, {self.col_end})"
)
# Check coordinates are divisible by 8, else we end up with nasty rounding error when mapping to latent space
for coord in [self.row_init, self.row_end, self.col_init, self.col_end]:
if coord // 8 != coord / 8:
raise ValueError(
f"A CanvasRegion must be defined with locations divisible by 8, found ({self.row_init}-{self.row_end}, {self.col_init}-{self.col_end})"
)
# Check noise eps is non-negative
if self.noise_eps < 0:
raise ValueError(f"A CanvasRegion must be defined noises eps non-negative, found {self.noise_eps}")
# Compute coordinates for this region in latent space
self.latent_row_init = self.row_init // 8
self.latent_row_end = self.row_end // 8
self.latent_col_init = self.col_init // 8
self.latent_col_end = self.col_end // 8
@property
def width(self):
return self.col_end - self.col_init
@property
def height(self):
return self.row_end - self.row_init
def get_region_generator(self, device="cpu"):
"""Creates a torch.Generator based on the random seed of this region"""
# Initialize region generator
return torch.Generator(device).manual_seed(self.region_seed)
@property
def __dict__(self):
return asdict(self)
class MaskModes(Enum):
"""Modes in which the influence of diffuser is masked"""
CONSTANT = "constant"
GAUSSIAN = "gaussian"
QUARTIC = "quartic" # See https://en.wikipedia.org/wiki/Kernel_(statistics)
@dataclass
class DiffusionRegion(CanvasRegion):
"""Abstract class defining a region where some class of diffusion process is acting"""
pass
@dataclass
class Text2ImageRegion(DiffusionRegion):
"""Class defining a region where a text guided diffusion process is acting"""
prompt: str = "" # Text prompt guiding the diffuser in this region
guidance_scale: float = 7.5 # Guidance scale of the diffuser in this region. If None, randomize
mask_type: MaskModes = MaskModes.GAUSSIAN.value # Kind of weight mask applied to this region
mask_weight: float = 1.0 # Global weights multiplier of the mask
tokenized_prompt = None # Tokenized prompt
encoded_prompt = None # Encoded prompt
def __post_init__(self):
super().__post_init__()
# Mask weight cannot be negative
if self.mask_weight < 0:
raise ValueError(
f"A Text2ImageRegion must be defined with non-negative mask weight, found {self.mask_weight}"
)
# Mask type must be an actual known mask
if self.mask_type not in [e.value for e in MaskModes]:
raise ValueError(
f"A Text2ImageRegion was defined with mask {self.mask_type}, which is not an accepted mask ({[e.value for e in MaskModes]})"
)
# Randomize arguments if given as None
if self.guidance_scale is None:
self.guidance_scale = np.random.randint(5, 30)
# Clean prompt
self.prompt = re.sub(" +", " ", self.prompt).replace("\n", " ")
def tokenize_prompt(self, tokenizer):
"""Tokenizes the prompt for this diffusion region using a given tokenizer"""
self.tokenized_prompt = tokenizer(
self.prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
def encode_prompt(self, text_encoder, device):
"""Encodes the previously tokenized prompt for this diffusion region using a given encoder"""
assert self.tokenized_prompt is not None, ValueError(
"Prompt in diffusion region must be tokenized before encoding"
)
self.encoded_prompt = text_encoder(self.tokenized_prompt.input_ids.to(device))[0]
@dataclass
class Image2ImageRegion(DiffusionRegion):
"""Class defining a region where an image guided diffusion process is acting"""
reference_image: torch.FloatTensor = None
strength: float = 0.8 # Strength of the image
def __post_init__(self):
super().__post_init__()
if self.reference_image is None:
raise ValueError("Must provide a reference image when creating an Image2ImageRegion")
if self.strength < 0 or self.strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {self.strength}")
# Rescale image to region shape
self.reference_image = resize(self.reference_image, size=[self.height, self.width])
def encode_reference_image(self, encoder, device, generator, cpu_vae=False):
"""Encodes the reference image for this Image2Image region into the latent space"""
# Place encoder in CPU or not following the parameter cpu_vae
if cpu_vae:
# Note here we use mean instead of sample, to avoid moving also generator to CPU, which is troublesome
self.reference_latents = encoder.cpu().encode(self.reference_image).latent_dist.mean.to(device)
else:
self.reference_latents = encoder.encode(self.reference_image.to(device)).latent_dist.sample(
generator=generator
)
self.reference_latents = 0.18215 * self.reference_latents
@property
def __dict__(self):
# This class requires special casting to dict because of the reference_image tensor. Otherwise it cannot be casted to JSON
# Get all basic fields from parent class
super_fields = {key: getattr(self, key) for key in DiffusionRegion.__dataclass_fields__.keys()}
# Pack other fields
return {**super_fields, "reference_image": self.reference_image.cpu().tolist(), "strength": self.strength}
class RerollModes(Enum):
"""Modes in which the reroll regions operate"""
RESET = "reset" # Completely reset the random noise in the region
EPSILON = "epsilon" # Alter slightly the latents in the region
@dataclass
class RerollRegion(CanvasRegion):
"""Class defining a rectangular canvas region in which initial latent noise will be rerolled"""
reroll_mode: RerollModes = RerollModes.RESET.value
@dataclass
class MaskWeightsBuilder:
"""Auxiliary class to compute a tensor of weights for a given diffusion region"""
latent_space_dim: int # Size of the U-net latent space
nbatch: int = 1 # Batch size in the U-net
def compute_mask_weights(self, region: DiffusionRegion) -> torch.tensor:
"""Computes a tensor of weights for a given diffusion region"""
MASK_BUILDERS = {
MaskModes.CONSTANT.value: self._constant_weights,
MaskModes.GAUSSIAN.value: self._gaussian_weights,
MaskModes.QUARTIC.value: self._quartic_weights,
}
return MASK_BUILDERS[region.mask_type](region)
def _constant_weights(self, region: DiffusionRegion) -> torch.tensor:
"""Computes a tensor of constant for a given diffusion region"""
latent_width = region.latent_col_end - region.latent_col_init
latent_height = region.latent_row_end - region.latent_row_init
return torch.ones(self.nbatch, self.latent_space_dim, latent_height, latent_width) * region.mask_weight
def _gaussian_weights(self, region: DiffusionRegion) -> torch.tensor:
"""Generates a gaussian mask of weights for tile contributions"""
latent_width = region.latent_col_end - region.latent_col_init
latent_height = region.latent_row_end - region.latent_row_init
var = 0.01
midpoint = (latent_width - 1) / 2 # -1 because index goes from 0 to latent_width - 1
x_probs = [
exp(-(x - midpoint) * (x - midpoint) / (latent_width * latent_width) / (2 * var)) / sqrt(2 * pi * var)
for x in range(latent_width)
]
midpoint = (latent_height - 1) / 2
y_probs = [
exp(-(y - midpoint) * (y - midpoint) / (latent_height * latent_height) / (2 * var)) / sqrt(2 * pi * var)
for y in range(latent_height)
]
weights = np.outer(y_probs, x_probs) * region.mask_weight
return torch.tile(torch.tensor(weights), (self.nbatch, self.latent_space_dim, 1, 1))
def _quartic_weights(self, region: DiffusionRegion) -> torch.tensor:
"""Generates a quartic mask of weights for tile contributions
The quartic kernel has bounded support over the diffusion region, and a smooth decay to the region limits.
"""
quartic_constant = 15.0 / 16.0
support = (np.array(range(region.latent_col_init, region.latent_col_end)) - region.latent_col_init) / (
region.latent_col_end - region.latent_col_init - 1
) * 1.99 - (1.99 / 2.0)
x_probs = quartic_constant * np.square(1 - np.square(support))
support = (np.array(range(region.latent_row_init, region.latent_row_end)) - region.latent_row_init) / (
region.latent_row_end - region.latent_row_init - 1
) * 1.99 - (1.99 / 2.0)
y_probs = quartic_constant * np.square(1 - np.square(support))
weights = np.outer(y_probs, x_probs) * region.mask_weight
return torch.tile(torch.tensor(weights), (self.nbatch, self.latent_space_dim, 1, 1))
class StableDiffusionCanvasPipeline(DiffusionPipeline):
"""Stable Diffusion pipeline that mixes several diffusers in the same canvas"""
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
def decode_latents(self, latents, cpu_vae=False):
"""Decodes a given array of latents into pixel space"""
# scale and decode the image latents with vae
if cpu_vae:
lat = deepcopy(latents).cpu()
vae = deepcopy(self.vae).cpu()
else:
lat = latents
vae = self.vae
lat = 1 / 0.18215 * lat
image = vae.decode(lat).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
return self.numpy_to_pil(image)
def get_latest_timestep_img2img(self, num_inference_steps, strength):
"""Finds the latest timesteps where an img2img strength does not impose latents anymore"""
# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * (1 - strength)) + offset
init_timestep = min(init_timestep, num_inference_steps)
t_start = min(max(num_inference_steps - init_timestep + offset, 0), num_inference_steps - 1)
latest_timestep = self.scheduler.timesteps[t_start]
return latest_timestep
@torch.no_grad()
def __call__(
self,
canvas_height: int,
canvas_width: int,
regions: List[DiffusionRegion],
num_inference_steps: Optional[int] = 50,
seed: Optional[int] = 12345,
reroll_regions: Optional[List[RerollRegion]] = None,
cpu_vae: Optional[bool] = False,
decode_steps: Optional[bool] = False,
):
if reroll_regions is None:
reroll_regions = []
batch_size = 1
if decode_steps:
steps_images = []
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, device=self.device)
# Split diffusion regions by their kind
text2image_regions = [region for region in regions if isinstance(region, Text2ImageRegion)]
image2image_regions = [region for region in regions if isinstance(region, Image2ImageRegion)]
# Prepare text embeddings
for region in text2image_regions:
region.tokenize_prompt(self.tokenizer)
region.encode_prompt(self.text_encoder, self.device)
# Create original noisy latents using the timesteps
latents_shape = (batch_size, self.unet.config.in_channels, canvas_height // 8, canvas_width // 8)
generator = torch.Generator(self.device).manual_seed(seed)
init_noise = torch.randn(latents_shape, generator=generator, device=self.device)
# Reset latents in seed reroll regions, if requested
for region in reroll_regions:
if region.reroll_mode == RerollModes.RESET.value:
region_shape = (
latents_shape[0],
latents_shape[1],
region.latent_row_end - region.latent_row_init,
region.latent_col_end - region.latent_col_init,
)
init_noise[
:,
:,
region.latent_row_init : region.latent_row_end,
region.latent_col_init : region.latent_col_end,
] = torch.randn(region_shape, generator=region.get_region_generator(self.device), device=self.device)
# Apply epsilon noise to regions: first diffusion regions, then reroll regions
all_eps_rerolls = regions + [r for r in reroll_regions if r.reroll_mode == RerollModes.EPSILON.value]
for region in all_eps_rerolls:
if region.noise_eps > 0:
region_noise = init_noise[
:,
:,
region.latent_row_init : region.latent_row_end,
region.latent_col_init : region.latent_col_end,
]
eps_noise = (
torch.randn(
region_noise.shape, generator=region.get_region_generator(self.device), device=self.device
)
* region.noise_eps
)
init_noise[
:,
:,
region.latent_row_init : region.latent_row_end,
region.latent_col_init : region.latent_col_end,
] += eps_noise
# scale the initial noise by the standard deviation required by the scheduler
latents = init_noise * self.scheduler.init_noise_sigma
# Get unconditional embeddings for classifier free guidance in text2image regions
for region in text2image_regions:
max_length = region.tokenized_prompt.input_ids.shape[-1]
uncond_input = self.tokenizer(
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
region.encoded_prompt = torch.cat([uncond_embeddings, region.encoded_prompt])
# Prepare image latents
for region in image2image_regions:
region.encode_reference_image(self.vae, device=self.device, generator=generator)
# Prepare mask of weights for each region
mask_builder = MaskWeightsBuilder(latent_space_dim=self.unet.config.in_channels, nbatch=batch_size)
mask_weights = [mask_builder.compute_mask_weights(region).to(self.device) for region in text2image_regions]
# Diffusion timesteps
for i, t in tqdm(enumerate(self.scheduler.timesteps)):
# Diffuse each region
noise_preds_regions = []
# text2image regions
for region in text2image_regions:
region_latents = latents[
:,
:,
region.latent_row_init : region.latent_row_end,
region.latent_col_init : region.latent_col_end,
]
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([region_latents] * 2)
# scale model input following scheduler rules
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=region.encoded_prompt)["sample"]
# perform guidance
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred_region = noise_pred_uncond + region.guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_preds_regions.append(noise_pred_region)
# Merge noise predictions for all tiles
noise_pred = torch.zeros(latents.shape, device=self.device)
contributors = torch.zeros(latents.shape, device=self.device)
# Add each tile contribution to overall latents
for region, noise_pred_region, mask_weights_region in zip(
text2image_regions, noise_preds_regions, mask_weights
):
noise_pred[
:,
:,
region.latent_row_init : region.latent_row_end,
region.latent_col_init : region.latent_col_end,
] += (
noise_pred_region * mask_weights_region
)
contributors[
:,
:,
region.latent_row_init : region.latent_row_end,
region.latent_col_init : region.latent_col_end,
] += mask_weights_region
# Average overlapping areas with more than 1 contributor
noise_pred /= contributors
noise_pred = torch.nan_to_num(
noise_pred
) # Replace NaNs by zeros: NaN can appear if a position is not covered by any DiffusionRegion
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
# Image2Image regions: override latents generated by the scheduler
for region in image2image_regions:
influence_step = self.get_latest_timestep_img2img(num_inference_steps, region.strength)
# Only override in the timesteps before the last influence step of the image (given by its strength)
if t > influence_step:
timestep = t.repeat(batch_size)
region_init_noise = init_noise[
:,
:,
region.latent_row_init : region.latent_row_end,
region.latent_col_init : region.latent_col_end,
]
region_latents = self.scheduler.add_noise(region.reference_latents, region_init_noise, timestep)
latents[
:,
:,
region.latent_row_init : region.latent_row_end,
region.latent_col_init : region.latent_col_end,
] = region_latents
if decode_steps:
steps_images.append(self.decode_latents(latents, cpu_vae))
# scale and decode the image latents with vae
image = self.decode_latents(latents, cpu_vae)
output = {"images": image}
if decode_steps:
output = {**output, "steps_images": steps_images}
return output