
refactor prior_transformer

adding conversion script

add pipeline

add step_index from pipeline, + remove permute

add zero pad token

remove copy from statement for betas_for_alpha_bar function
This commit is contained in:
yiyixuxu
2023-05-11 21:14:30 +00:00
parent 4f14b36329
commit 20e5be74d8
8 changed files with 804 additions and 30 deletions

View File

@@ -0,0 +1,324 @@
import argparse
import tempfile
import torch
from accelerate import load_checkpoint_and_dispatch
from diffusers.models.prior_transformer import PriorTransformer
"""
Example - From the diffusers root directory:
Download weights:
```sh
$ wget "https://openaipublic.azureedge.net/main/shap-e/text_cond.pt"
```
Convert the model:
```sh
$ python scripts/convert_shap_e_to_diffusers.py \
--prior_checkpoint_path /home/yiyi_huggingface_co/shap-e/shap_e_model_cache/text_cond.pt \
--dump_path /home/yiyi_huggingface_co/model_repo/shape \
--debug prior
```
"""
# prior
PRIOR_ORIGINAL_PREFIX = "wrapped"
# Uses default arguments
PRIOR_CONFIG = {
"num_attention_heads": 16,
"attention_head_dim": 1024 // 16,
"num_layers": 24,
"embedding_dim": 1024,
"num_embeddings": 1024,
"additional_embeddings": 0,
"act_fn": "gelu",
"time_embed_dim": 1024 * 4,
"clip_embedding_dim": 768,
"out_dim": 1024 * 2,
"has_pre_norm": True,
"has_encoder_hidden_states_proj": False,
"has_prd_embedding": False,
"has_post_process": False,
}
def prior_model_from_original_config():
model = PriorTransformer(**PRIOR_CONFIG)
return model
def prior_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
diffusers_checkpoint = {}
# <original>.time_embed.c_fc -> <diffusers>.time_embedding.linear_1
diffusers_checkpoint.update(
{
"time_embedding.linear_1.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.c_fc.weight"],
"time_embedding.linear_1.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.c_fc.bias"],
}
)
# <original>.time_embed.c_proj -> <diffusers>.time_embedding.linear_2
diffusers_checkpoint.update(
{
"time_embedding.linear_2.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.c_proj.weight"],
"time_embedding.linear_2.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.c_proj.bias"],
}
)
    # <original>.input_proj -> <diffusers>.proj_in
diffusers_checkpoint.update(
{
"proj_in.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.input_proj.weight"],
"proj_in.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.input_proj.bias"],
}
)
    # <original>.clip_embed -> <diffusers>.embedding_proj
diffusers_checkpoint.update(
{
"embedding_proj.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_embed.weight"],
"embedding_proj.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_embed.bias"],
}
)
    # <original>.pos_emb -> <diffusers>.positional_embedding
diffusers_checkpoint.update({"positional_embedding": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.pos_emb"][None, :]})
# <original>.ln_pre -> <diffusers>.norm_in
diffusers_checkpoint.update(
{
"norm_in.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.ln_pre.weight"],
"norm_in.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.ln_pre.bias"],
}
)
    # <original>.backbone.resblocks.<x> -> <diffusers>.transformer_blocks.<x>
for idx in range(len(model.transformer_blocks)):
diffusers_transformer_prefix = f"transformer_blocks.{idx}"
original_transformer_prefix = f"{PRIOR_ORIGINAL_PREFIX}.backbone.resblocks.{idx}"
# <original>.attn -> <diffusers>.attn1
diffusers_attention_prefix = f"{diffusers_transformer_prefix}.attn1"
original_attention_prefix = f"{original_transformer_prefix}.attn"
diffusers_checkpoint.update(
prior_attention_to_diffusers(
checkpoint,
diffusers_attention_prefix=diffusers_attention_prefix,
original_attention_prefix=original_attention_prefix,
attention_head_dim=model.attention_head_dim,
)
)
# <original>.mlp -> <diffusers>.ff
diffusers_ff_prefix = f"{diffusers_transformer_prefix}.ff"
original_ff_prefix = f"{original_transformer_prefix}.mlp"
diffusers_checkpoint.update(
prior_ff_to_diffusers(
checkpoint, diffusers_ff_prefix=diffusers_ff_prefix, original_ff_prefix=original_ff_prefix
)
)
# <original>.ln_1 -> <diffusers>.norm1
diffusers_checkpoint.update(
{
f"{diffusers_transformer_prefix}.norm1.weight": checkpoint[
f"{original_transformer_prefix}.ln_1.weight"
],
f"{diffusers_transformer_prefix}.norm1.bias": checkpoint[f"{original_transformer_prefix}.ln_1.bias"],
}
)
# <original>.ln_2 -> <diffusers>.norm3
diffusers_checkpoint.update(
{
f"{diffusers_transformer_prefix}.norm3.weight": checkpoint[
f"{original_transformer_prefix}.ln_2.weight"
],
f"{diffusers_transformer_prefix}.norm3.bias": checkpoint[f"{original_transformer_prefix}.ln_2.bias"],
}
)
    # <original>.ln_post -> <diffusers>.norm_out
diffusers_checkpoint.update(
{
"norm_out.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.ln_post.weight"],
"norm_out.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.ln_post.bias"],
}
)
    # <original>.output_proj -> <diffusers>.proj_to_clip_embeddings
diffusers_checkpoint.update(
{
"proj_to_clip_embeddings.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.output_proj.weight"],
"proj_to_clip_embeddings.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.output_proj.bias"],
}
)
return diffusers_checkpoint
def prior_attention_to_diffusers(
checkpoint, *, diffusers_attention_prefix, original_attention_prefix, attention_head_dim
):
diffusers_checkpoint = {}
# <original>.c_qkv -> <diffusers>.{to_q, to_k, to_v}
[q_weight, k_weight, v_weight], [q_bias, k_bias, v_bias] = split_attentions(
weight=checkpoint[f"{original_attention_prefix}.c_qkv.weight"],
bias=checkpoint[f"{original_attention_prefix}.c_qkv.bias"],
split=3,
chunk_size=attention_head_dim,
)
diffusers_checkpoint.update(
{
f"{diffusers_attention_prefix}.to_q.weight": q_weight,
f"{diffusers_attention_prefix}.to_q.bias": q_bias,
f"{diffusers_attention_prefix}.to_k.weight": k_weight,
f"{diffusers_attention_prefix}.to_k.bias": k_bias,
f"{diffusers_attention_prefix}.to_v.weight": v_weight,
f"{diffusers_attention_prefix}.to_v.bias": v_bias,
}
)
# <original>.c_proj -> <diffusers>.to_out.0
diffusers_checkpoint.update(
{
f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{original_attention_prefix}.c_proj.weight"],
f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{original_attention_prefix}.c_proj.bias"],
}
)
return diffusers_checkpoint
def prior_ff_to_diffusers(checkpoint, *, diffusers_ff_prefix, original_ff_prefix):
diffusers_checkpoint = {
# <original>.c_fc -> <diffusers>.net.0.proj
f"{diffusers_ff_prefix}.net.{0}.proj.weight": checkpoint[f"{original_ff_prefix}.c_fc.weight"],
f"{diffusers_ff_prefix}.net.{0}.proj.bias": checkpoint[f"{original_ff_prefix}.c_fc.bias"],
# <original>.c_proj -> <diffusers>.net.2
f"{diffusers_ff_prefix}.net.{2}.weight": checkpoint[f"{original_ff_prefix}.c_proj.weight"],
f"{diffusers_ff_prefix}.net.{2}.bias": checkpoint[f"{original_ff_prefix}.c_proj.bias"],
}
return diffusers_checkpoint
# done prior
# TODO: document and/or make this more efficient (build indices in the loop and extract once per split?)
def split_attentions(*, weight, bias, split, chunk_size):
weights = [None] * split
biases = [None] * split
weights_biases_idx = 0
for starting_row_index in range(0, weight.shape[0], chunk_size):
row_indices = torch.arange(starting_row_index, starting_row_index + chunk_size)
weight_rows = weight[row_indices, :]
bias_rows = bias[row_indices]
        if weights[weights_biases_idx] is None:
            weights[weights_biases_idx] = weight_rows
            biases[weights_biases_idx] = bias_rows
        else:
            weights[weights_biases_idx] = torch.concat([weights[weights_biases_idx], weight_rows])
            biases[weights_biases_idx] = torch.concat([biases[weights_biases_idx], bias_rows])
weights_biases_idx = (weights_biases_idx + 1) % split
return weights, biases
# done conversion utils
# Driver functions
def prior(*, args, checkpoint_map_location):
print("loading prior")
prior_checkpoint = torch.load(args.prior_checkpoint_path, map_location=checkpoint_map_location)
prior_model = prior_model_from_original_config()
prior_diffusers_checkpoint = prior_original_checkpoint_to_diffusers_checkpoint(prior_model, prior_checkpoint)
del prior_checkpoint
load_checkpoint_to_model(prior_diffusers_checkpoint, prior_model, strict=True)
print("done loading prior")
return prior_model
def load_checkpoint_to_model(checkpoint, model, strict=False):
with tempfile.NamedTemporaryFile() as file:
torch.save(checkpoint, file.name)
del checkpoint
if strict:
model.load_state_dict(torch.load(file.name), strict=True)
else:
load_checkpoint_and_dispatch(model, file.name, device_map="auto")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
parser.add_argument(
"--prior_checkpoint_path",
default=None,
type=str,
required=True,
help="Path to the prior checkpoint to convert.",
)
parser.add_argument(
"--checkpoint_load_device",
default="cpu",
type=str,
required=False,
help="The device passed to `map_location` when loading checkpoints.",
)
parser.add_argument(
"--debug",
default=None,
type=str,
required=False,
help="Only run a specific stage of the convert script. Used for debugging",
)
args = parser.parse_args()
print(f"loading checkpoints to {args.checkpoint_load_device}")
checkpoint_map_location = torch.device(args.checkpoint_load_device)
if args.debug is not None:
print(f"debug: only executing {args.debug}")
if args.debug is None:
print("YiYi TO-DO")
elif args.debug == "prior":
prior_model = prior(args=args, checkpoint_map_location=checkpoint_map_location)
prior_model.save_pretrained(args.dump_path)
else:
raise ValueError(f"unknown debug value : {args.debug}")

View File

@@ -137,6 +137,7 @@ else:
LDMTextToImagePipeline,
PaintByExamplePipeline,
SemanticStableDiffusionPipeline,
ShapEPipeline,
StableDiffusionAttendAndExcitePipeline,
StableDiffusionControlNetImg2ImgPipeline,
StableDiffusionControlNetInpaintPipeline,

View File

@@ -1,3 +1,4 @@
import math
from dataclasses import dataclass
from typing import Optional, Union
@@ -58,6 +59,14 @@ class PriorTransformer(ModelMixin, ConfigMixin):
num_embeddings=77,
additional_embeddings=4,
dropout: float = 0.0,
act_fn: str = "silu",
has_pre_norm: bool = False,
has_encoder_hidden_states_proj: bool = True,
has_prd_embedding: bool = True,
has_post_process: bool = True,
time_embed_dim: Optional[int] = None,
clip_embedding_dim: Optional[int] = None,
out_dim: Optional[int] = None,
):
super().__init__()
self.num_attention_heads = num_attention_heads
@@ -65,17 +74,33 @@ class PriorTransformer(ModelMixin, ConfigMixin):
inner_dim = num_attention_heads * attention_head_dim
self.additional_embeddings = additional_embeddings
if time_embed_dim is None:
time_embed_dim = inner_dim
if clip_embedding_dim is None:
clip_embedding_dim = embedding_dim
if out_dim is None:
out_dim = embedding_dim
self.time_proj = Timesteps(inner_dim, True, 0)
self.time_embedding = TimestepEmbedding(inner_dim, inner_dim)
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=act_fn)
self.proj_in = nn.Linear(embedding_dim, inner_dim)
self.embedding_proj = nn.Linear(embedding_dim, inner_dim)
self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
self.embedding_proj = nn.Linear(clip_embedding_dim, inner_dim)
if has_encoder_hidden_states_proj:
self.encoder_hidden_states_proj = nn.Linear(clip_embedding_dim, inner_dim)
else:
self.encoder_hidden_states_proj = None
self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))
self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
if has_prd_embedding:
self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
else:
self.prd_embedding = None
self.transformer_blocks = nn.ModuleList(
[
@@ -91,8 +116,14 @@ class PriorTransformer(ModelMixin, ConfigMixin):
]
)
if has_pre_norm:
self.norm_in = nn.LayerNorm(inner_dim)
else:
self.norm_in = None
self.norm_out = nn.LayerNorm(inner_dim)
self.proj_to_clip_embeddings = nn.Linear(inner_dim, embedding_dim)
self.proj_to_clip_embeddings = nn.Linear(inner_dim, out_dim)
causal_attention_mask = torch.full(
[num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
@@ -100,16 +131,19 @@ class PriorTransformer(ModelMixin, ConfigMixin):
causal_attention_mask.triu_(1)
causal_attention_mask = causal_attention_mask[None, ...]
self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)
self.clip_mean = nn.Parameter(torch.zeros(1, embedding_dim))
self.clip_std = nn.Parameter(torch.zeros(1, embedding_dim))
if has_post_process:
self.clip_mean = nn.Parameter(torch.zeros(1, clip_embedding_dim))
self.clip_std = nn.Parameter(torch.zeros(1, clip_embedding_dim))
else:
self.clip_mean = None
self.clip_std = None
def forward(
self,
hidden_states,
timestep: Union[torch.Tensor, float, int],
proj_embedding: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
return_dict: bool = True,
):
@@ -152,23 +186,49 @@ class PriorTransformer(ModelMixin, ConfigMixin):
timesteps_projected = timesteps_projected.to(dtype=self.dtype)
time_embeddings = self.time_embedding(timesteps_projected)
# Rescale the features to have unit variance
        # YiYi TO-DO: it was already normalized during the encode_prompt step; move this step to the pipeline
if self.clip_mean is None:
proj_embedding = math.sqrt(proj_embedding.shape[1]) * proj_embedding
proj_embeddings = self.embedding_proj(proj_embedding)
encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None:
raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set")
hidden_states = self.proj_in(hidden_states)
prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
positional_embeddings = self.positional_embedding.to(hidden_states.dtype)
tokens = []
if encoder_hidden_states is not None:
tokens.append(encoder_hidden_states)
tokens = tokens + [
proj_embeddings[:, None, :],
time_embeddings[:, None, :],
hidden_states[:, None, :] if len(hidden_states.shape) == 2 else hidden_states,
]
if self.prd_embedding is not None:
prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
tokens.append(prd_embedding)
hidden_states = torch.cat(
[
encoder_hidden_states,
proj_embeddings[:, None, :],
time_embeddings[:, None, :],
hidden_states[:, None, :],
prd_embedding,
],
tokens,
dim=1,
)
        # Allow positional_embedding to not include the `additional_embeddings` and instead pad it with zeros for these additional tokens
additional_embeddings = 2 + (encoder_hidden_states.shape[1] if encoder_hidden_states is not None else 0)
if positional_embeddings.shape[1] < hidden_states.shape[1]:
positional_embeddings = F.pad(
positional_embeddings,
(0, 0, additional_embeddings, self.prd_embedding.shape[1] if self.prd_embedding is not None else 0),
value=0.0,
)
hidden_states = hidden_states + positional_embeddings
if attention_mask is not None:
@@ -177,11 +237,19 @@ class PriorTransformer(ModelMixin, ConfigMixin):
attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
if self.norm_in is not None:
hidden_states = self.norm_in(hidden_states)
for block in self.transformer_blocks:
hidden_states = block(hidden_states, attention_mask=attention_mask)
hidden_states = self.norm_out(hidden_states)
hidden_states = hidden_states[:, -1]
if self.prd_embedding is not None:
hidden_states = hidden_states[:, -1]
else:
hidden_states = hidden_states[:, additional_embeddings:]
predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)
if not return_dict:
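A note on the zero-padding above: it lets a checkpoint ship a `positional_embedding` that covers only the `num_embeddings` base tokens; the prepended embedding/time tokens (and a trailing prd token, when present) then receive zero positional embeddings. A minimal sketch of the shape arithmetic, with hypothetical sizes:

```py
import torch
import torch.nn.functional as F

# Hypothetical sizes: num_embeddings = 4 base tokens, inner_dim = 8,
# two prepended tokens (proj_embeddings, time_embeddings), no prd token.
positional_embeddings = torch.zeros(1, 4, 8)
additional_embeddings = 2
# (0, 0) pads the feature dim; (additional_embeddings, 0) pads the token dim in front
padded = F.pad(positional_embeddings, (0, 0, additional_embeddings, 0), value=0.0)
assert padded.shape == (1, 6, 8)  # matches the assembled token sequence
```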

View File

@@ -66,6 +66,7 @@ else:
from .latent_diffusion import LDMTextToImagePipeline
from .paint_by_example import PaintByExamplePipeline
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
from .shap_e import ShapEPipeline
from .stable_diffusion import (
CycleDiffusionPipeline,
StableDiffusionAttendAndExcitePipeline,

View File

@@ -0,0 +1,15 @@
from ...utils import (
OptionalDependencyNotAvailable,
is_torch_available,
is_transformers_available,
is_transformers_version,
)
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline
else:
from .pipeline_shap_e import ShapEPipeline

View File

@@ -0,0 +1,311 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import List, Optional, Union
import numpy as np
import torch
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from ...models import PriorTransformer
from ...pipelines import DiffusionPipeline
from ...schedulers import HeunDiscreteScheduler
from ...utils import (
BaseOutput,
is_accelerate_available,
logging,
randn_tensor,
replace_example_docstring,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
```
"""
@dataclass
class ShapEPipelineOutput(BaseOutput):
"""
Output class for ShapEPipeline.
Args:
        latents (`torch.FloatTensor` or `np.ndarray`)
            3D latent representation
"""
latents: Union[torch.FloatTensor, np.ndarray]
class ShapEPipeline(DiffusionPipeline):
"""
    Pipeline for generating a latent representation of a 3D asset with Shap-E.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
prior ([`PriorTransformer`]):
            The canonical unCLIP prior to approximate the image embedding from the text embedding.
text_encoder ([`CLIPTextModelWithProjection`]):
Frozen text-encoder.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
scheduler ([`HeunDiscreteScheduler`]):
            A scheduler to be used in combination with the `prior` to generate the image embedding.
"""
def __init__(
self,
prior: PriorTransformer,
text_encoder: CLIPTextModelWithProjection,
tokenizer: CLIPTokenizer,
scheduler: HeunDiscreteScheduler,
):
super().__init__()
self.register_modules(
prior=prior,
text_encoder=text_encoder,
tokenizer=tokenizer,
scheduler=scheduler,
)
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
latents = latents * scheduler.init_noise_sigma
return latents
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
        models have their state dicts saved to CPU and are then moved to a `torch.device('meta')`, loaded to GPU only
        when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device(f"cuda:{gpu_id}")
models = [
self.text_encoder,
]
for cpu_offloaded_model in models:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
@property
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
if self.device != torch.device("meta") or not hasattr(self.text_encoder, "_hf_hook"):
return self.device
for module in self.text_encoder.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return self.device
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
):
        # YiYi Notes: set pad_token_id to 0; not sure why it can't be set in the config file
self.tokenizer.pad_token_id = 0
# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_encoder_output = self.text_encoder(text_input_ids.to(device))
prompt_embeds = text_encoder_output.text_embeds
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        # in Shap-E the prompt_embeds are normalized and then later rescaled; not sure why
# YiYi TO-DO: move rescale out of prior_transformer and apply it here
prompt_embeds = prompt_embeds / torch.linalg.norm(prompt_embeds, dim=-1, keepdim=True)
if do_classifier_free_guidance:
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
return prompt_embeds
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]],
num_images_per_prompt: int = 1,
num_inference_steps: int = 25,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
guidance_scale: float = 4.0,
sigma_min: float = 1e-3,
sigma_max: float = 160.0,
output_type: Optional[str] = "pt", # pt only
return_dict: bool = True,
):
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
            num_inference_steps (`int`, *optional*, defaults to 25):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
guidance_scale (`float`, *optional*, defaults to 4.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
output_type (`str`, *optional*, defaults to `"pt"`):
                The output format of the generated image. Choose between: `"np"` (`np.array`) or `"pt"`
(`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
Examples:
Returns:
[`ShapEPipelineOutput`] or `tuple`
"""
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
device = self._execution_device
batch_size = batch_size * num_images_per_prompt
do_classifier_free_guidance = guidance_scale > 1.0
prompt_embeds = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance)
# prior
self.scheduler.set_timesteps(
num_inference_steps, device=device, sigma_min=sigma_min, sigma_max=sigma_max, use_karras_sigmas=True
)
timesteps = self.scheduler.timesteps
num_embeddings = self.prior.config.num_embeddings
embedding_dim = self.prior.config.embedding_dim
latents = self.prepare_latents(
(batch_size, num_embeddings * embedding_dim),
prompt_embeds.dtype,
device,
generator,
latents,
self.scheduler,
)
        # for testing only to match ldm; we could instead create latents directly with the desired shape (batch_size, num_embeddings, embedding_dim)
latents = latents.reshape(latents.shape[0], num_embeddings, embedding_dim)
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
scaled_model_input = self.scheduler.scale_model_input(latent_model_input, t)
noise_pred = self.prior(
scaled_model_input,
timestep=t,
proj_embedding=prompt_embeds,
).predicted_image_embedding
# remove the variance
noise_pred, _ = noise_pred.split(
scaled_model_input.shape[2], dim=2
) # batch_size, num_embeddings, embedding_dim
# clip between -1 and 1
noise_pred = noise_pred.clamp(-1, 1)
            if do_classifier_free_guidance:
noise_pred_uncond, noise_pred = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
latents = self.scheduler.step(
noise_pred,
timestep=t,
sample=latents,
step_index=i,
).prev_sample
        if output_type not in ["pt", "np"]:
            raise ValueError(f"Only the output types `pt` and `np` are supported, not output_type={output_type}")
if output_type == "np":
latents = latents.cpu().numpy()
        if not return_dict:
            return (latents,)
        return ShapEPipelineOutput(latents=latents)
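The example docstring above is still empty in this commit; as a placeholder, a hedged usage sketch. The model path is hypothetical: it assumes a full pipeline repo (prior + text encoder + tokenizer + scheduler) has been assembled locally, e.g. from the conversion script's `--dump_path` output — no hub checkpoint is referenced in this commit.

```py
from diffusers import ShapEPipeline

# Hypothetical local path to an assembled pipeline repo.
pipe = ShapEPipeline.from_pretrained("/path/to/shap-e-pipeline").to("cuda")

output = pipe("a shark", num_inference_steps=25, guidance_scale=4.0)
latents = output.latents  # (batch_size, num_embeddings, embedding_dim)
```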

View File

@@ -22,8 +22,11 @@ from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
def betas_for_alpha_bar(
num_diffusion_timesteps,
max_beta=0.999,
alpha_bar_fn=None,
) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -44,11 +47,14 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
if alpha_bar_fn is None:
alpha_bar_fn = alpha_bar
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)
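With the new `alpha_bar_fn` hook, the `"exp"` schedule wired up below is just a different cumulative-alpha function. A sketch, assuming `betas_for_alpha_bar` is imported from this module (it is module-level here, not public API):

```py
import math

# Default call keeps the Glide cosine alpha_bar; a custom alpha_bar_fn
# reproduces the new "exp" schedule.
cosine_betas = betas_for_alpha_bar(1000)
exp_betas = betas_for_alpha_bar(1000, alpha_bar_fn=lambda t: math.exp(t * -12.0))
```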
@@ -106,6 +112,8 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
elif beta_schedule == "exp":
self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_bar_fn=lambda t: math.exp(t * -12.0))
else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
@@ -152,6 +160,9 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
num_inference_steps: int,
device: Union[str, torch.device] = None,
num_train_timesteps: Optional[int] = None,
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
        use_karras_sigmas: Optional[bool] = None,  # overrides self.config.use_karras_sigmas
):
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -166,15 +177,25 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
if sigma_min is not None and sigma_max is not None:
sigmas = torch.tensor([sigma_max, sigma_min])
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
log_sigmas = np.log(sigmas)
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
else:
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
if self.use_karras_sigmas:
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
log_sigmas = np.log(sigmas)
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
if use_karras_sigmas is None:
use_karras_sigmas = self.use_karras_sigmas
if use_karras_sigmas:
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
if self.config.beta_schedule == "exp":
timesteps = np.array([self._sigma_to_t_yiyi(sigma) for sigma in sigmas])
else:
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
sigmas = torch.from_numpy(sigmas).to(device=device)
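For reference, this is how the Shap-E pipeline above exercises the new arguments, overriding the config's sigma range and Karras flag per call (`scheduler`, `num_inference_steps`, and `device` come from the pipeline context):

```py
scheduler.set_timesteps(
    num_inference_steps, device=device, sigma_min=1e-3, sigma_max=160.0, use_karras_sigmas=True
)
```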
@@ -220,6 +241,22 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
t = t.reshape(sigma.shape)
return t
    # YiYi Notes: taken from the original repo; will refactor and avoid introducing a dependency on scipy
def _sigma_to_t_yiyi(self, sigma):
alpha_cumprod = 1.0 / (sigma**2 + 1)
if alpha_cumprod > self.alphas_cumprod[0]:
return 0
elif alpha_cumprod <= self.alphas_cumprod[-1]:
return len(self.alphas_cumprod) - 1
else:
from scipy import interpolate
timestep = interpolate.interp1d(self.alphas_cumprod, np.arange(0, len(self.alphas_cumprod)))(
alpha_cumprod
        )  # yiyi testing, original implementation
return int(timestep)
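As the note says, the scipy dependency could be dropped: the lookup is plain linear interpolation over a monotonically decreasing `alphas_cumprod`. A possible `np.interp` equivalent (a sketch, not part of this commit; assumes `alphas_cumprod` is a decreasing 1-D numpy array):

```py
import numpy as np

def sigma_to_t_no_scipy(sigma, alphas_cumprod):
    alpha_cumprod = 1.0 / (sigma**2 + 1)
    if alpha_cumprod > alphas_cumprod[0]:
        return 0
    if alpha_cumprod <= alphas_cumprod[-1]:
        return len(alphas_cumprod) - 1
    # alphas_cumprod is decreasing, so reverse both x and y for np.interp,
    # which expects increasing x
    t = np.interp(alpha_cumprod, alphas_cumprod[::-1], np.arange(len(alphas_cumprod))[::-1])
    return int(t)
```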
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
@@ -244,6 +281,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
timestep: Union[float, torch.FloatTensor],
sample: Union[torch.FloatTensor, np.ndarray],
return_dict: bool = True,
step_index: Optional[int] = None,
) -> Union[SchedulerOutput, Tuple]:
"""
Args:
@@ -258,7 +296,8 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
step_index = self.index_for_timestep(timestep)
if step_index is None:
step_index = self.index_for_timestep(timestep)
if self.state_in_first_order:
sigma = self.sigmas[step_index]
@@ -284,7 +323,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
sample / (sigma_input**2 + 1)
)
elif self.config.prediction_type == "sample":
raise NotImplementedError("prediction_type not implemented yet: sample")
pred_original_sample = model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"

View File

@@ -227,6 +227,21 @@ class KandinskyPriorPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class ShapEPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class LDMTextToImagePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]