[Community] Add SDE Drag pipeline (#6105)
* Add community pipeline: sde_drag.py
* Update README.md
* Update README.md: update example code and visual example
* Update sde_drag.py: update code example
@@ -48,6 +48,7 @@ prompt-to-prompt | change parts of a prompt and retain image structure (see [pap
| Latent Consistency Pipeline | Implementation of [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378) | [Latent Consistency Pipeline](#latent-consistency-pipeline) | - | [Simian Luo](https://github.com/luosiallen) |
| Latent Consistency Img2img Pipeline | Img2img pipeline for Latent Consistency Models | [Latent Consistency Img2Img Pipeline](#latent-consistency-img2img-pipeline) | - | [Logan Zoellner](https://github.com/nagolinc) |
| Latent Consistency Interpolation Pipeline | Interpolate the latent space of Latent Consistency Models with multiple prompts | [Latent Consistency Interpolation Pipeline](#latent-consistency-interpolation-pipeline) | [](https://colab.research.google.com/drive/1pK3NrLWJSiJsBynLns1K1-IDTW9zbPvl?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
| SDE Drag Pipeline | The pipeline supports drag editing of images using stochastic differential equations | [SDE Drag Pipeline](#sde-drag-pipeline) | - | [NieShen](https://github.com/NieShenRuc) [Fengqi Zhu](https://github.com/Monohydroxides) |
| Regional Prompting Pipeline | Assign multiple prompts for different regions | [Regional Prompting Pipeline](#regional-prompting-pipeline) | - | [hako-mikan](https://github.com/hako-mikan) |
| LDM3D-sr (LDM3D upscaler) | Upscale low resolution RGB and depth inputs to high resolution | [StableDiffusionUpscaleLDM3D Pipeline](https://github.com/estelleafl/diffusers/tree/ldm3d_upscaler_community/examples/community#stablediffusionupscaleldm3d-pipeline) | - | [Estelle Aflalo](https://github.com/estelleafl) |
| AnimateDiff ControlNet Pipeline | Combines AnimateDiff with precise motion control using ControlNets | [AnimateDiff ControlNet Pipeline](#animatediff-controlnet-pipeline) | [](https://colab.research.google.com/drive/1SKboYeGjEQmQPWoFC0aLYpBlYdHXkvAu?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) and [Edoardo Botta](https://github.com/EdoardoBotta) |
@@ -2986,3 +2987,42 @@ def image_grid(imgs, save_path=None):
image_grid(images, save_path="./outputs/")
```

### SDE Drag pipeline

This pipeline provides drag-and-drop image editing using stochastic differential equations. It edits an image given a `prompt`, `image`, `mask_image`, `source_points`, and `target_points`.

See the [paper](https://arxiv.org/abs/2311.01410), the [paper page](https://ml-gsai.github.io/SDE-Drag-demo/), and the [original repository](https://github.com/ML-GSAI/SDE-Drag) for more information.

```py
import PIL
import torch

from diffusers import DDIMScheduler, DiffusionPipeline

# Load the pipeline
model_path = "runwayml/stable-diffusion-v1-5"
scheduler = DDIMScheduler.from_pretrained(model_path, subfolder="scheduler")
pipe = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, custom_pipeline="sde_drag")
pipe.to('cuda')

# To save GPU memory, torch.float16 can be used, but it may compromise image quality.
# If you are not training LoRA, avoid using torch.float16.
# pipe.to(torch.float16)

# Provide the prompt, image, mask image, and the source and target points for drag editing.
prompt = "prompt of the image"
image = PIL.Image.open('/path/to/image')
mask_image = PIL.Image.open('/path/to/mask_image')
source_points = [[123, 456]]
target_points = [[234, 567]]

# train_lora is optional; in most cases it helps preserve consistency with the original image.
pipe.train_lora(prompt, image)

output = pipe(prompt, image, mask_image, source_points, target_points)
output_image = PIL.Image.fromarray(output)
output_image.save("./output.png")
```
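If you don't already have a mask image on disk, one option is to build a binary mask in memory. The snippet below is only a minimal sketch, not part of the pipeline API: it assumes you want to make a rectangular region around the drag points editable, and the `make_rect_mask` helper and its coordinates are purely illustrative.

```py
import numpy as np
import PIL.Image


# Illustrative helper (not provided by the pipeline): build a white-on-black rectangular mask.
# White (255) pixels are editable, black (0) pixels are preserved.
def make_rect_mask(size, top_left, bottom_right):
    width, height = size
    mask = np.zeros((height, width), dtype=np.uint8)
    x0, y0 = top_left
    x1, y1 = bottom_right
    mask[y0:y1, x0:x1] = 255
    return PIL.Image.fromarray(mask)


# Example: make the region spanning the source and target points editable.
mask_image = make_rect_mask(image.size, top_left=(100, 400), bottom_right=(300, 600))
```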
examples/community/sde_drag.py (new file, 594 lines)
@@ -0,0 +1,594 @@
import math
import tempfile
from typing import List, Optional

import numpy as np
import PIL.Image
import torch
from accelerate import Accelerator
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from diffusers import AutoencoderKL, DiffusionPipeline, DPMSolverMultistepScheduler, UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
from diffusers.models.attention_processor import (
    AttnAddedKVProcessor,
    AttnAddedKVProcessor2_0,
    LoRAAttnAddedKVProcessor,
    LoRAAttnProcessor,
    LoRAAttnProcessor2_0,
    SlicedAttnAddedKVProcessor,
)
from diffusers.optimization import get_scheduler


class SdeDragPipeline(DiffusionPipeline):
    r"""
    Pipeline for image drag-and-drop editing using stochastic differential equations: https://arxiv.org/abs/2311.01410.
    Please refer to the [official repository](https://github.com/ML-GSAI/SDE-Drag) for more information.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Please use
            [`DDIMScheduler`].
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: DPMSolverMultistepScheduler,
    ):
        super().__init__()

        self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        prompt: str,
        image: PIL.Image.Image,
        mask_image: PIL.Image.Image,
        source_points: List[List[int]],
        target_points: List[List[int]],
        t0: Optional[float] = 0.6,
        steps: Optional[int] = 200,
        step_size: Optional[int] = 2,
        image_scale: Optional[float] = 0.3,
        adapt_radius: Optional[int] = 5,
        min_lora_scale: Optional[float] = 0.5,
        generator: Optional[torch.Generator] = None,
    ):
        r"""
        Function invoked when calling the pipeline for image editing.
        Args:
            prompt (`str`, *required*):
                The prompt to guide the image editing.
            image (`PIL.Image.Image`, *required*):
                The image to edit. Parts of the image will be masked out with `mask_image` and edited
                according to `prompt`.
            mask_image (`PIL.Image.Image`, *required*):
                A mask for `image`. White pixels in the mask will be edited, while black pixels will be preserved.
            source_points (`List[List[int]]`, *required*):
                Used to mark the starting positions of drag editing in the image, with each pixel represented as a
                `List[int]` of length 2.
            target_points (`List[List[int]]`, *required*):
                Used to mark the target positions of drag editing in the image, with each pixel represented as a
                `List[int]` of length 2.
            t0 (`float`, *optional*, defaults to 0.6):
                The time parameter. A higher `t0` improves the fidelity of the edited image while lowering its
                faithfulness to the original, and vice versa.
            steps (`int`, *optional*, defaults to 200):
                The number of sampling iterations.
            step_size (`int`, *optional*, defaults to 2):
                The drag distance of each drag step.
            image_scale (`float`, *optional*, defaults to 0.3):
                To avoid duplicating the content, use `image_scale` to perturb the source region.
            adapt_radius (`int`, *optional*, defaults to 5):
                The size of the region for copy and paste operations during each step of the drag process.
            min_lora_scale (`float`, *optional*, defaults to 0.5):
                A LoRA scale that will be applied to all LoRA layers if LoRA layers are loaded.
                `min_lora_scale` specifies the minimum LoRA scale during the image drag-editing process.
            generator (`torch.Generator`, *optional*, defaults to `None`):
                A [torch.Generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
        Examples:
        ```py
        >>> import PIL
        >>> import torch
        >>> from diffusers import DDIMScheduler, DiffusionPipeline

        >>> # Load the pipeline
        >>> model_path = "runwayml/stable-diffusion-v1-5"
        >>> scheduler = DDIMScheduler.from_pretrained(model_path, subfolder="scheduler")
        >>> pipe = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, custom_pipeline="sde_drag")
        >>> pipe.to('cuda')

        >>> # To save GPU memory, torch.float16 can be used, but it may compromise image quality.
        >>> # If you are not training LoRA, avoid using torch.float16.
        >>> # pipe.to(torch.float16)

        >>> # Provide the prompt, image, mask image, and the source and target points for drag editing.
        >>> prompt = "prompt of the image"
        >>> image = PIL.Image.open('/path/to/image')
        >>> mask_image = PIL.Image.open('/path/to/mask_image')
        >>> source_points = [[123, 456]]
        >>> target_points = [[234, 567]]

        >>> # train_lora is optional; in most cases it helps preserve consistency with the original image.
        >>> pipe.train_lora(prompt, image)

        >>> output = pipe(prompt, image, mask_image, source_points, target_points)
        >>> output_image = PIL.Image.fromarray(output)
        >>> output_image.save("./output.png")
        ```
        """

        self.scheduler.set_timesteps(steps)
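
        # `noise_scale` is chosen so that image_scale**2 + noise_scale**2 == 1; the perturbed
        # source patch in `_copy_and_paste` then keeps (roughly) the variance of the original latent.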
        noise_scale = (1 - image_scale**2) ** (0.5)

        text_embeddings = self._get_text_embed(prompt)
        uncond_embeddings = self._get_text_embed([""])
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        latent = self._get_img_latent(image)

        mask = mask_image.resize((latent.shape[3], latent.shape[2]))
        mask = torch.tensor(np.array(mask))
        mask = mask.unsqueeze(0).expand_as(latent).to(self.device)

        # Convert pixel coordinates to latent coordinates (the VAE downsamples by a factor of 8).
        source_points = torch.tensor(source_points).div(torch.tensor([8]), rounding_mode="trunc")
        target_points = torch.tensor(target_points).div(torch.tensor([8]), rounding_mode="trunc")

        distance = target_points - source_points
        distance_norm_max = torch.norm(distance.float(), dim=1, keepdim=True).max()
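
        # Split the longest drag distance into sub-steps of roughly `step_size` latent pixels,
        # choosing the step count whose per-step distance is closest to `step_size`.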
        if distance_norm_max <= step_size:
            drag_num = 1
        else:
            drag_num = distance_norm_max.div(torch.tensor([step_size]), rounding_mode="trunc")
            if (distance_norm_max / drag_num - step_size).abs() > (
                distance_norm_max / (drag_num + 1) - step_size
            ).abs():
                drag_num += 1

        latents = []
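        # Each drag step: (1) noise the latent up to t0 with the forward SDE, (2) copy the latent
        # patch from the current source position and paste it at the next intermediate target
        # position, (3) denoise again, re-using the recorded noises and restoring the region
        # outside the mask from the cached forward latents.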
        for i in tqdm(range(int(drag_num)), desc="SDE Drag"):
            source_new = source_points + (i / drag_num * distance).to(torch.int)
            target_new = source_points + ((i + 1) / drag_num * distance).to(torch.int)

            latent, noises, hook_latents, lora_scales, cfg_scales = self._forward(
                latent, steps, t0, min_lora_scale, text_embeddings, generator
            )
            latent = self._copy_and_paste(
                latent,
                source_new,
                target_new,
                adapt_radius,
                latent.shape[2] - 1,
                latent.shape[3] - 1,
                image_scale,
                noise_scale,
                generator,
            )
            latent = self._backward(
                latent, mask, steps, t0, noises, hook_latents, lora_scales, cfg_scales, text_embeddings, generator
            )

            latents.append(latent)

        result_image = 1 / 0.18215 * latents[-1]

        with torch.no_grad():
            result_image = self.vae.decode(result_image).sample

        result_image = (result_image / 2 + 0.5).clamp(0, 1)
        result_image = result_image.cpu().permute(0, 2, 3, 1).numpy()[0]
        result_image = (result_image * 255).astype(np.uint8)

        return result_image

    def train_lora(self, prompt, image, lora_step=100, lora_rank=16, generator=None):
        accelerator = Accelerator(gradient_accumulation_steps=1, mixed_precision="fp16")

        self.vae.requires_grad_(False)
        self.text_encoder.requires_grad_(False)
        self.unet.requires_grad_(False)

        unet_lora_attn_procs = {}
        for name, attn_processor in self.unet.attn_processors.items():
            cross_attention_dim = None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = self.unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(self.unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = self.unet.config.block_out_channels[block_id]
            else:
                raise NotImplementedError("name must start with up_blocks, mid_blocks, or down_blocks")

            if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
                lora_attn_processor_class = LoRAAttnAddedKVProcessor
            else:
                lora_attn_processor_class = (
                    LoRAAttnProcessor2_0
                    if hasattr(torch.nn.functional, "scaled_dot_product_attention")
                    else LoRAAttnProcessor
                )
            unet_lora_attn_procs[name] = lora_attn_processor_class(
                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank
            )

        self.unet.set_attn_processor(unet_lora_attn_procs)
        unet_lora_layers = AttnProcsLayers(self.unet.attn_processors)
        params_to_optimize = unet_lora_layers.parameters()

        optimizer = torch.optim.AdamW(
            params_to_optimize,
            lr=2e-4,
            betas=(0.9, 0.999),
            weight_decay=1e-2,
            eps=1e-08,
        )

        lr_scheduler = get_scheduler(
            "constant",
            optimizer=optimizer,
            num_warmup_steps=0,
            num_training_steps=lora_step,
            num_cycles=1,
            power=1.0,
        )

        unet_lora_layers = accelerator.prepare_model(unet_lora_layers)
        optimizer = accelerator.prepare_optimizer(optimizer)
        lr_scheduler = accelerator.prepare_scheduler(lr_scheduler)

        with torch.no_grad():
            text_inputs = self._tokenize_prompt(prompt, tokenizer_max_length=None)
            text_embedding = self._encode_prompt(
                text_inputs.input_ids, text_inputs.attention_mask, text_encoder_use_attention_mask=False
            )

        image_transforms = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

        image = image_transforms(image).to(self.device, dtype=self.vae.dtype)
        image = image.unsqueeze(dim=0)
        latents_dist = self.vae.encode(image).latent_dist

        for _ in tqdm(range(lora_step), desc="Train LoRA"):
            self.unet.train()
            model_input = latents_dist.sample() * self.vae.config.scaling_factor

            # Sample noise that we'll add to the latents
            noise = torch.randn(
                model_input.size(),
                dtype=model_input.dtype,
                layout=model_input.layout,
                device=model_input.device,
                generator=generator,
            )
            bsz, channels, height, width = model_input.shape

            # Sample a random timestep for each image
            timesteps = torch.randint(
                0, self.scheduler.config.num_train_timesteps, (bsz,), device=model_input.device, generator=generator
            )
            timesteps = timesteps.long()

            # Add noise to the model input according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_model_input = self.scheduler.add_noise(model_input, noise, timesteps)

            # Predict the noise residual
            model_pred = self.unet(noisy_model_input, timesteps, text_embedding).sample

            # Get the target for loss depending on the prediction type
            if self.scheduler.config.prediction_type == "epsilon":
                target = noise
            elif self.scheduler.config.prediction_type == "v_prediction":
                target = self.scheduler.get_velocity(model_input, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {self.scheduler.config.prediction_type}")

            loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="mean")
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        with tempfile.TemporaryDirectory() as save_lora_dir:
            LoraLoaderMixin.save_lora_weights(
                save_directory=save_lora_dir,
                unet_lora_layers=unet_lora_layers,
                text_encoder_lora_layers=None,
            )

            self.unet.load_attn_procs(save_lora_dir)

    def _tokenize_prompt(self, prompt, tokenizer_max_length=None):
        if tokenizer_max_length is not None:
            max_length = tokenizer_max_length
        else:
            max_length = self.tokenizer.model_max_length

        text_inputs = self.tokenizer(
            prompt,
            truncation=True,
            padding="max_length",
            max_length=max_length,
            return_tensors="pt",
        )

        return text_inputs

    def _encode_prompt(self, input_ids, attention_mask, text_encoder_use_attention_mask=False):
        text_input_ids = input_ids.to(self.device)

        if text_encoder_use_attention_mask:
            attention_mask = attention_mask.to(self.device)
        else:
            attention_mask = None

        prompt_embeds = self.text_encoder(
            text_input_ids,
            attention_mask=attention_mask,
        )
        prompt_embeds = prompt_embeds[0]

        return prompt_embeds

    @torch.no_grad()
    def _get_text_embed(self, prompt):
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
        return text_embeddings

    def _copy_and_paste(
        self, latent, source_new, target_new, adapt_radius, max_height, max_width, image_scale, noise_scale, generator
    ):
        def adaption_r(source, target, adapt_radius, max_height, max_width):
            # Clip the copy/paste radius so that both the source and the target patches stay inside the latent.
            r_x_lower = min(adapt_radius, source[0], target[0])
            r_x_upper = min(adapt_radius, max_width - source[0], max_width - target[0])
            r_y_lower = min(adapt_radius, source[1], target[1])
            r_y_upper = min(adapt_radius, max_height - source[1], max_height - target[1])
            return r_x_lower, r_x_upper, r_y_lower, r_y_upper

        for source_, target_ in zip(source_new, target_new):
            r_x_lower, r_x_upper, r_y_lower, r_y_upper = adaption_r(
                source_, target_, adapt_radius, max_height, max_width
            )

            source_feature = latent[
                :, :, source_[1] - r_y_lower : source_[1] + r_y_upper, source_[0] - r_x_lower : source_[0] + r_x_upper
            ].clone()

            # Perturb the source region: keep a scaled copy of the original feature and fill the rest with noise.
            latent[
                :, :, source_[1] - r_y_lower : source_[1] + r_y_upper, source_[0] - r_x_lower : source_[0] + r_x_upper
            ] = image_scale * source_feature + noise_scale * torch.randn(
                latent.shape[0],
                4,
                r_y_lower + r_y_upper,
                r_x_lower + r_x_upper,
                device=self.device,
                generator=generator,
            )

            # Paste the (slightly amplified) source feature at the target location.
            latent[
                :, :, target_[1] - r_y_lower : target_[1] + r_y_upper, target_[0] - r_x_lower : target_[0] + r_x_upper
            ] = source_feature * 1.1
        return latent

    @torch.no_grad()
    def _get_img_latent(self, image, height=None, weight=None):
        data = image.convert("RGB")
        if height is not None:
            data = data.resize((weight, height))
        transform = transforms.ToTensor()
        data = transform(data).unsqueeze(0)
        data = (data * 2.0) - 1.0
        data = data.to(self.device, dtype=self.vae.dtype)
        latent = self.vae.encode(data).latent_dist.sample()
        latent = 0.18215 * latent
        return latent

    @torch.no_grad()
    def _get_eps(self, latent, timestep, guidance_scale, text_embeddings, lora_scale=None):
        latent_model_input = torch.cat([latent] * 2) if guidance_scale > 1.0 else latent
        text_embeddings = text_embeddings if guidance_scale > 1.0 else text_embeddings.chunk(2)[1]

        cross_attention_kwargs = None if lora_scale is None else {"scale": lora_scale}

        with torch.no_grad():
            noise_pred = self.unet(
                latent_model_input,
                timestep,
                encoder_hidden_states=text_embeddings,
                cross_attention_kwargs=cross_attention_kwargs,
            ).sample

        if guidance_scale > 1.0:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        elif guidance_scale == 1.0:
            noise_pred_text = noise_pred
            noise_pred_uncond = 0.0
        else:
            raise NotImplementedError(guidance_scale)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        return noise_pred

    def _forward_sde(
        self, timestep, sample, guidance_scale, text_embeddings, steps, eta=1.0, lora_scale=None, generator=None
    ):
        num_train_timesteps = len(self.scheduler)
        alphas_cumprod = self.scheduler.alphas_cumprod
        initial_alpha_cumprod = torch.tensor(1.0)

        prev_timestep = timestep + num_train_timesteps // steps

        alpha_prod_t = alphas_cumprod[timestep] if timestep >= 0 else initial_alpha_cumprod
        alpha_prod_t_prev = alphas_cumprod[prev_timestep]

        beta_prod_t_prev = 1 - alpha_prod_t_prev

        # One noising step of the forward SDE: rescale the sample and add matched Gaussian noise.
        x_prev = (alpha_prod_t_prev / alpha_prod_t) ** (0.5) * sample + (1 - alpha_prod_t_prev / alpha_prod_t) ** (
            0.5
        ) * torch.randn(
            sample.size(), dtype=sample.dtype, layout=sample.layout, device=self.device, generator=generator
        )
        eps = self._get_eps(x_prev, prev_timestep, guidance_scale, text_embeddings, lora_scale)

        sigma_t_prev = (
            eta
            * (1 - alpha_prod_t) ** (0.5)
            * (1 - alpha_prod_t_prev / (1 - alpha_prod_t_prev) * (1 - alpha_prod_t) / alpha_prod_t) ** (0.5)
        )

        pred_original_sample = (x_prev - beta_prod_t_prev ** (0.5) * eps) / alpha_prod_t_prev ** (0.5)
        pred_sample_direction_coeff = (1 - alpha_prod_t - sigma_t_prev**2) ** (0.5)

        # Recover the noise term that a reverse step starting from x_prev would need to land back on `sample`.
        noise = (
            sample - alpha_prod_t ** (0.5) * pred_original_sample - pred_sample_direction_coeff * eps
        ) / sigma_t_prev

        return x_prev, noise

    def _sample(
        self,
        timestep,
        sample,
        guidance_scale,
        text_embeddings,
        steps,
        sde=False,
        noise=None,
        eta=1.0,
        lora_scale=None,
        generator=None,
    ):
        num_train_timesteps = len(self.scheduler)
        alphas_cumprod = self.scheduler.alphas_cumprod
        final_alpha_cumprod = torch.tensor(1.0)

        eps = self._get_eps(sample, timestep, guidance_scale, text_embeddings, lora_scale)

        prev_timestep = timestep - num_train_timesteps // steps

        alpha_prod_t = alphas_cumprod[timestep]
        alpha_prod_t_prev = alphas_cumprod[prev_timestep] if prev_timestep >= 0 else final_alpha_cumprod

        beta_prod_t = 1 - alpha_prod_t

        sigma_t = (
            eta
            * ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** (0.5)
            * (1 - alpha_prod_t / alpha_prod_t_prev) ** (0.5)
            if sde
            else 0
        )

        pred_original_sample = (sample - beta_prod_t ** (0.5) * eps) / alpha_prod_t ** (0.5)
        pred_sample_direction_coeff = (1 - alpha_prod_t_prev - sigma_t**2) ** (0.5)

        noise = (
            torch.randn(
                sample.size(), dtype=sample.dtype, layout=sample.layout, device=self.device, generator=generator
            )
            if noise is None
            else noise
        )
        latent = (
            alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction_coeff * eps + sigma_t * noise
        )

        return latent

    def _forward(self, latent, steps, t0, lora_scale_min, text_embeddings, generator):
        def scale_schedule(begin, end, n, length, type="linear"):
            if type == "constant":
                return end
            elif type == "linear":
                return begin + (end - begin) * n / length
            elif type == "cos":
                factor = (1 - math.cos(n * math.pi / length)) / 2
                return (1 - factor) * begin + factor * end
            else:
                raise NotImplementedError(type)

        noises = []
        latents = []
        lora_scales = []
        cfg_scales = []
        latents.append(latent)
        t0 = int(t0 * steps)
        t_begin = steps - t0

        # Noise the latent up to time t0, recording the per-step noises, intermediate latents,
        # and the LoRA/CFG scales used, so that _backward can replay them.
        length = len(self.scheduler.timesteps[t_begin - 1 : -1]) - 1
        index = 1
        for t in self.scheduler.timesteps[t_begin:].flip(dims=[0]):
            lora_scale = scale_schedule(1, lora_scale_min, index, length, type="cos")
            cfg_scale = scale_schedule(1, 3.0, index, length, type="linear")
            latent, noise = self._forward_sde(
                t, latent, cfg_scale, text_embeddings, steps, lora_scale=lora_scale, generator=generator
            )

            noises.append(noise)
            latents.append(latent)
            lora_scales.append(lora_scale)
            cfg_scales.append(cfg_scale)
            index += 1
        return latent, noises, latents, lora_scales, cfg_scales

    def _backward(
        self, latent, mask, steps, t0, noises, hook_latents, lora_scales, cfg_scales, text_embeddings, generator
    ):
        t0 = int(t0 * steps)
        t_begin = steps - t0

        # Denoise back to t=0, re-using the recorded noises and scales and, at every step,
        # restoring the region outside the mask from the cached forward latents.
        hook_latent = hook_latents.pop()
        latent = torch.where(mask > 128, latent, hook_latent)
        for t in self.scheduler.timesteps[t_begin - 1 : -1]:
            latent = self._sample(
                t,
                latent,
                cfg_scales.pop(),
                text_embeddings,
                steps,
                sde=True,
                noise=noises.pop(),
                lora_scale=lora_scales.pop(),
                generator=generator,
            )
            hook_latent = hook_latents.pop()
            latent = torch.where(mask > 128, latent, hook_latent)
        return latent