From 20e5be74d827f2225259f1c0e4e785cbcc98fb7c Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 11 May 2023 21:14:30 +0000 Subject: [PATCH] refactor prior_transformer adding conversion script add pipeline add step_index from pipeline, + remove permute add zero pad token remove copy from statement for betas_for_alpha_bar function --- scripts/convert_shap_e_to_diffusers.py | 324 ++++++++++++++++++ src/diffusers/__init__.py | 1 + src/diffusers/models/prior_transformer.py | 106 +++++- src/diffusers/pipelines/__init__.py | 1 + src/diffusers/pipelines/shap_e/__init__.py | 15 + .../pipelines/shap_e/pipeline_shap_e.py | 311 +++++++++++++++++ .../schedulers/scheduling_heun_discrete.py | 61 +++- .../dummy_torch_and_transformers_objects.py | 15 + 8 files changed, 804 insertions(+), 30 deletions(-) create mode 100644 scripts/convert_shap_e_to_diffusers.py create mode 100644 src/diffusers/pipelines/shap_e/__init__.py create mode 100644 src/diffusers/pipelines/shap_e/pipeline_shap_e.py diff --git a/scripts/convert_shap_e_to_diffusers.py b/scripts/convert_shap_e_to_diffusers.py new file mode 100644 index 0000000000..4a159fc61f --- /dev/null +++ b/scripts/convert_shap_e_to_diffusers.py @@ -0,0 +1,324 @@ +import argparse +import tempfile + +import torch +from accelerate import load_checkpoint_and_dispatch + +from diffusers.models.prior_transformer import PriorTransformer + + +""" +Example - From the diffusers root directory: + +Download weights: +```sh +$ wget "https://openaipublic.azureedge.net/main/shap-e/text_cond.pt" +``` + +Convert the model: +```sh +$ python scripts/convert_shap_e_to_diffusers.py \ + --prior_checkpoint_path /home/yiyi_huggingface_co/shap-e/shap_e_model_cache/text_cond.pt \ + --dump_path /home/yiyi_huggingface_co/model_repo/shape \ + --debug prior +``` +""" + + +# prior + +PRIOR_ORIGINAL_PREFIX = "wrapped" + +# Uses default arguments +PRIOR_CONFIG = { + "num_attention_heads": 16, + "attention_head_dim": 1024 // 16, + "num_layers": 24, + "embedding_dim": 1024, + "num_embeddings": 1024, + "additional_embeddings": 0, + "act_fn": "gelu", + "time_embed_dim": 1024 * 4, + "clip_embedding_dim": 768, + "out_dim": 1024 * 2, + "has_pre_norm": True, + "has_encoder_hidden_states_proj": False, + "has_prd_embedding": False, + "has_post_process": False, +} + + +def prior_model_from_original_config(): + model = PriorTransformer(**PRIOR_CONFIG) + + return model + + +def prior_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + # .time_embed.c_fc -> .time_embedding.linear_1 + diffusers_checkpoint.update( + { + "time_embedding.linear_1.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.c_fc.weight"], + "time_embedding.linear_1.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.c_fc.bias"], + } + ) + + # .time_embed.c_proj -> .time_embedding.linear_2 + diffusers_checkpoint.update( + { + "time_embedding.linear_2.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.c_proj.weight"], + "time_embedding.linear_2.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.c_proj.bias"], + } + ) + + # .clip_img_proj -> .proj_in + diffusers_checkpoint.update( + { + "proj_in.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.input_proj.weight"], + "proj_in.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.input_proj.bias"], + } + ) + + # .text_emb_proj -> .embedding_proj + diffusers_checkpoint.update( + { + "embedding_proj.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_embed.weight"], + "embedding_proj.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_embed.bias"], + } 
+ ) + + # .positional_embedding -> .positional_embedding + diffusers_checkpoint.update({"positional_embedding": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.pos_emb"][None, :]}) + + # .ln_pre -> .norm_in + diffusers_checkpoint.update( + { + "norm_in.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.ln_pre.weight"], + "norm_in.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.ln_pre.bias"], + } + ) + + # .resblocks. -> .transformer_blocks. + for idx in range(len(model.transformer_blocks)): + diffusers_transformer_prefix = f"transformer_blocks.{idx}" + original_transformer_prefix = f"{PRIOR_ORIGINAL_PREFIX}.backbone.resblocks.{idx}" + + # .attn -> .attn1 + diffusers_attention_prefix = f"{diffusers_transformer_prefix}.attn1" + original_attention_prefix = f"{original_transformer_prefix}.attn" + diffusers_checkpoint.update( + prior_attention_to_diffusers( + checkpoint, + diffusers_attention_prefix=diffusers_attention_prefix, + original_attention_prefix=original_attention_prefix, + attention_head_dim=model.attention_head_dim, + ) + ) + + # .mlp -> .ff + diffusers_ff_prefix = f"{diffusers_transformer_prefix}.ff" + original_ff_prefix = f"{original_transformer_prefix}.mlp" + diffusers_checkpoint.update( + prior_ff_to_diffusers( + checkpoint, diffusers_ff_prefix=diffusers_ff_prefix, original_ff_prefix=original_ff_prefix + ) + ) + + # .ln_1 -> .norm1 + diffusers_checkpoint.update( + { + f"{diffusers_transformer_prefix}.norm1.weight": checkpoint[ + f"{original_transformer_prefix}.ln_1.weight" + ], + f"{diffusers_transformer_prefix}.norm1.bias": checkpoint[f"{original_transformer_prefix}.ln_1.bias"], + } + ) + + # .ln_2 -> .norm3 + diffusers_checkpoint.update( + { + f"{diffusers_transformer_prefix}.norm3.weight": checkpoint[ + f"{original_transformer_prefix}.ln_2.weight" + ], + f"{diffusers_transformer_prefix}.norm3.bias": checkpoint[f"{original_transformer_prefix}.ln_2.bias"], + } + ) + + # .final_ln -> .norm_out + diffusers_checkpoint.update( + { + "norm_out.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.ln_post.weight"], + "norm_out.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.ln_post.bias"], + } + ) + + # .out_proj -> .proj_to_clip_embeddings + diffusers_checkpoint.update( + { + "proj_to_clip_embeddings.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.output_proj.weight"], + "proj_to_clip_embeddings.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.output_proj.bias"], + } + ) + + return diffusers_checkpoint + + +def prior_attention_to_diffusers( + checkpoint, *, diffusers_attention_prefix, original_attention_prefix, attention_head_dim +): + diffusers_checkpoint = {} + + # .c_qkv -> .{to_q, to_k, to_v} + [q_weight, k_weight, v_weight], [q_bias, k_bias, v_bias] = split_attentions( + weight=checkpoint[f"{original_attention_prefix}.c_qkv.weight"], + bias=checkpoint[f"{original_attention_prefix}.c_qkv.bias"], + split=3, + chunk_size=attention_head_dim, + ) + + diffusers_checkpoint.update( + { + f"{diffusers_attention_prefix}.to_q.weight": q_weight, + f"{diffusers_attention_prefix}.to_q.bias": q_bias, + f"{diffusers_attention_prefix}.to_k.weight": k_weight, + f"{diffusers_attention_prefix}.to_k.bias": k_bias, + f"{diffusers_attention_prefix}.to_v.weight": v_weight, + f"{diffusers_attention_prefix}.to_v.bias": v_bias, + } + ) + + # .c_proj -> .to_out.0 + diffusers_checkpoint.update( + { + f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{original_attention_prefix}.c_proj.weight"], + f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{original_attention_prefix}.c_proj.bias"], + } + ) + + return 
diffusers_checkpoint + + +def prior_ff_to_diffusers(checkpoint, *, diffusers_ff_prefix, original_ff_prefix): + diffusers_checkpoint = { + # .c_fc -> .net.0.proj + f"{diffusers_ff_prefix}.net.{0}.proj.weight": checkpoint[f"{original_ff_prefix}.c_fc.weight"], + f"{diffusers_ff_prefix}.net.{0}.proj.bias": checkpoint[f"{original_ff_prefix}.c_fc.bias"], + # .c_proj -> .net.2 + f"{diffusers_ff_prefix}.net.{2}.weight": checkpoint[f"{original_ff_prefix}.c_proj.weight"], + f"{diffusers_ff_prefix}.net.{2}.bias": checkpoint[f"{original_ff_prefix}.c_proj.bias"], + } + + return diffusers_checkpoint + + +# done prior + + +# TODO maybe document and/or can do more efficiently (build indices in for loop and extract once for each split?) +def split_attentions(*, weight, bias, split, chunk_size): + weights = [None] * split + biases = [None] * split + + weights_biases_idx = 0 + + for starting_row_index in range(0, weight.shape[0], chunk_size): + row_indices = torch.arange(starting_row_index, starting_row_index + chunk_size) + + weight_rows = weight[row_indices, :] + bias_rows = bias[row_indices] + + if weights[weights_biases_idx] is None: + assert weights[weights_biases_idx] is None + weights[weights_biases_idx] = weight_rows + biases[weights_biases_idx] = bias_rows + else: + assert weights[weights_biases_idx] is not None + weights[weights_biases_idx] = torch.concat([weights[weights_biases_idx], weight_rows]) + biases[weights_biases_idx] = torch.concat([biases[weights_biases_idx], bias_rows]) + + weights_biases_idx = (weights_biases_idx + 1) % split + + return weights, biases + + +# done unet utils + + +# Driver functions + + +def prior(*, args, checkpoint_map_location): + print("loading prior") + + prior_checkpoint = torch.load(args.prior_checkpoint_path, map_location=checkpoint_map_location) + + prior_model = prior_model_from_original_config() + + prior_diffusers_checkpoint = prior_original_checkpoint_to_diffusers_checkpoint(prior_model, prior_checkpoint) + + del prior_checkpoint + + load_checkpoint_to_model(prior_diffusers_checkpoint, prior_model, strict=True) + + print("done loading prior") + + return prior_model + + +def load_checkpoint_to_model(checkpoint, model, strict=False): + with tempfile.NamedTemporaryFile() as file: + torch.save(checkpoint, file.name) + del checkpoint + if strict: + model.load_state_dict(torch.load(file.name), strict=True) + else: + load_checkpoint_and_dispatch(model, file.name, device_map="auto") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + + parser.add_argument( + "--prior_checkpoint_path", + default=None, + type=str, + required=True, + help="Path to the prior checkpoint to convert.", + ) + + parser.add_argument( + "--checkpoint_load_device", + default="cpu", + type=str, + required=False, + help="The device passed to `map_location` when loading checkpoints.", + ) + + parser.add_argument( + "--debug", + default=None, + type=str, + required=False, + help="Only run a specific stage of the convert script. 
Used for debugging", + ) + + args = parser.parse_args() + + print(f"loading checkpoints to {args.checkpoint_load_device}") + + checkpoint_map_location = torch.device(args.checkpoint_load_device) + + if args.debug is not None: + print(f"debug: only executing {args.debug}") + + if args.debug is None: + print("YiYi TO-DO") + elif args.debug == "prior": + prior_model = prior(args=args, checkpoint_map_location=checkpoint_map_location) + prior_model.save_pretrained(args.dump_path) + else: + raise ValueError(f"unknown debug value : {args.debug}") diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 402f6eaa74..359c058780 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -137,6 +137,7 @@ else: LDMTextToImagePipeline, PaintByExamplePipeline, SemanticStableDiffusionPipeline, + ShapEPipeline, StableDiffusionAttendAndExcitePipeline, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetInpaintPipeline, diff --git a/src/diffusers/models/prior_transformer.py b/src/diffusers/models/prior_transformer.py index b245612e6f..1d712ee672 100644 --- a/src/diffusers/models/prior_transformer.py +++ b/src/diffusers/models/prior_transformer.py @@ -1,3 +1,4 @@ +import math from dataclasses import dataclass from typing import Optional, Union @@ -58,6 +59,14 @@ class PriorTransformer(ModelMixin, ConfigMixin): num_embeddings=77, additional_embeddings=4, dropout: float = 0.0, + act_fn: str = "silu", + has_pre_norm: bool = False, + has_encoder_hidden_states_proj: bool = True, + has_prd_embedding: bool = True, + has_post_process: bool = True, + time_embed_dim: Optional[int] = None, + clip_embedding_dim: Optional[int] = None, + out_dim: Optional[int] = None, ): super().__init__() self.num_attention_heads = num_attention_heads @@ -65,17 +74,33 @@ class PriorTransformer(ModelMixin, ConfigMixin): inner_dim = num_attention_heads * attention_head_dim self.additional_embeddings = additional_embeddings + if time_embed_dim is None: + time_embed_dim = inner_dim + + if clip_embedding_dim is None: + clip_embedding_dim = embedding_dim + + if out_dim is None: + out_dim = embedding_dim + self.time_proj = Timesteps(inner_dim, True, 0) - self.time_embedding = TimestepEmbedding(inner_dim, inner_dim) + self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=act_fn) self.proj_in = nn.Linear(embedding_dim, inner_dim) - self.embedding_proj = nn.Linear(embedding_dim, inner_dim) - self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim) + self.embedding_proj = nn.Linear(clip_embedding_dim, inner_dim) + + if has_encoder_hidden_states_proj: + self.encoder_hidden_states_proj = nn.Linear(clip_embedding_dim, inner_dim) + else: + self.encoder_hidden_states_proj = None self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim)) - self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim)) + if has_prd_embedding: + self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim)) + else: + self.prd_embedding = None self.transformer_blocks = nn.ModuleList( [ @@ -91,8 +116,14 @@ class PriorTransformer(ModelMixin, ConfigMixin): ] ) + if has_pre_norm: + self.norm_in = nn.LayerNorm(inner_dim) + else: + self.norm_in = None + self.norm_out = nn.LayerNorm(inner_dim) - self.proj_to_clip_embeddings = nn.Linear(inner_dim, embedding_dim) + + self.proj_to_clip_embeddings = nn.Linear(inner_dim, out_dim) causal_attention_mask = torch.full( [num_embeddings + additional_embeddings, num_embeddings + 
additional_embeddings], -10000.0 @@ -100,16 +131,19 @@ class PriorTransformer(ModelMixin, ConfigMixin): causal_attention_mask.triu_(1) causal_attention_mask = causal_attention_mask[None, ...] self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False) - - self.clip_mean = nn.Parameter(torch.zeros(1, embedding_dim)) - self.clip_std = nn.Parameter(torch.zeros(1, embedding_dim)) + if has_post_process: + self.clip_mean = nn.Parameter(torch.zeros(1, clip_embedding_dim)) + self.clip_std = nn.Parameter(torch.zeros(1, clip_embedding_dim)) + else: + self.clip_mean = None + self.clip_std = None def forward( self, hidden_states, timestep: Union[torch.Tensor, float, int], proj_embedding: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.BoolTensor] = None, return_dict: bool = True, ): @@ -152,23 +186,49 @@ class PriorTransformer(ModelMixin, ConfigMixin): timesteps_projected = timesteps_projected.to(dtype=self.dtype) time_embeddings = self.time_embedding(timesteps_projected) + # Rescale the features to have unit variance + # YiYi TO-DO: It was normalized before during encode_prompt step, move this step to pipeline + if self.clip_mean is None: + proj_embedding = math.sqrt(proj_embedding.shape[1]) * proj_embedding proj_embeddings = self.embedding_proj(proj_embedding) - encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states) + if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None: + encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states) + elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None: + raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set") + hidden_states = self.proj_in(hidden_states) - prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1) + positional_embeddings = self.positional_embedding.to(hidden_states.dtype) + tokens = [] + + if encoder_hidden_states is not None: + tokens.append(encoder_hidden_states) + + tokens = tokens + [ + proj_embeddings[:, None, :], + time_embeddings[:, None, :], + hidden_states[:, None, :] if len(hidden_states.shape) == 2 else hidden_states, + ] + + if self.prd_embedding is not None: + prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1) + tokens.append(prd_embedding) + hidden_states = torch.cat( - [ - encoder_hidden_states, - proj_embeddings[:, None, :], - time_embeddings[:, None, :], - hidden_states[:, None, :], - prd_embedding, - ], + tokens, dim=1, ) + # Allow positional_embedding to not include the `addtional_embeddings` and instead pad it with zeros for these additional tokens + additional_embeddings = 2 + (encoder_hidden_states.shape[1] if encoder_hidden_states is not None else 0) + if positional_embeddings.shape[1] < hidden_states.shape[1]: + positional_embeddings = F.pad( + positional_embeddings, + (0, 0, additional_embeddings, self.prd_embedding.shape[1] if self.prd_embedding is not None else 0), + value=0.0, + ) + hidden_states = hidden_states + positional_embeddings if attention_mask is not None: @@ -177,11 +237,19 @@ class PriorTransformer(ModelMixin, ConfigMixin): attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype) attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0) + if self.norm_in is not None: + hidden_states = 
self.norm_in(hidden_states) + for block in self.transformer_blocks: hidden_states = block(hidden_states, attention_mask=attention_mask) hidden_states = self.norm_out(hidden_states) - hidden_states = hidden_states[:, -1] + + if self.prd_embedding is not None: + hidden_states = hidden_states[:, -1] + else: + hidden_states = hidden_states[:, additional_embeddings:] + predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states) if not return_dict: diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 9e68538f23..1fd0c505ec 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -66,6 +66,7 @@ else: from .latent_diffusion import LDMTextToImagePipeline from .paint_by_example import PaintByExamplePipeline from .semantic_stable_diffusion import SemanticStableDiffusionPipeline + from .shap_e import ShapEPipeline from .stable_diffusion import ( CycleDiffusionPipeline, StableDiffusionAttendAndExcitePipeline, diff --git a/src/diffusers/pipelines/shap_e/__init__.py b/src/diffusers/pipelines/shap_e/__init__.py new file mode 100644 index 0000000000..bc8c04d50a --- /dev/null +++ b/src/diffusers/pipelines/shap_e/__init__.py @@ -0,0 +1,15 @@ +from ...utils import ( + OptionalDependencyNotAvailable, + is_torch_available, + is_transformers_available, + is_transformers_version, +) + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline +else: + from .pipeline_shap_e import ShapEPipeline diff --git a/src/diffusers/pipelines/shap_e/pipeline_shap_e.py b/src/diffusers/pipelines/shap_e/pipeline_shap_e.py new file mode 100644 index 0000000000..182ca699bd --- /dev/null +++ b/src/diffusers/pipelines/shap_e/pipeline_shap_e.py @@ -0,0 +1,311 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import torch +from transformers import CLIPTextModelWithProjection, CLIPTokenizer + +from ...models import PriorTransformer +from ...pipelines import DiffusionPipeline +from ...schedulers import HeunDiscreteScheduler +from ...utils import ( + BaseOutput, + is_accelerate_available, + logging, + randn_tensor, + replace_example_docstring, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + + ``` +""" + + +@dataclass +class ShapEPipelineOutput(BaseOutput): + """ + Output class for ShapEPipeline. + + Args: + images (`torch.FloatTensor`) + 3D latent representation + """ + + latents: Union[torch.FloatTensor, np.ndarray] + + +class ShapEPipeline(DiffusionPipeline): + """ + Pipeline for generating latent representation of a 3D asset with Shap.E + + This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + prior ([`PriorTransformer`]): + The canonical unCLIP prior to approximate the 3D latent representation from the text embedding. + text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + scheduler ([`HeunDiscreteScheduler`]): + A scheduler to be used in combination with the `prior` to generate the latent representation. + """ + + def __init__( + self, + prior: PriorTransformer, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + scheduler: HeunDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + prior=prior, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + ) + + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + latents = latents * scheduler.init_noise_sigma + return latents + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's + models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU only + when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + models = [ + self.text_encoder, + ] + for cpu_offloaded_model in models: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. 
+ """ + if self.device != torch.device("meta") or not hasattr(self.text_encoder, "_hf_hook"): + return self.device + for module in self.text_encoder.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + ): + len(prompt) if isinstance(prompt, list) else 1 + + # YiYi Notes: set pad_token_id to be 0, not sure why I can't set in the config file + self.tokenizer.pad_token_id = 0 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_encoder_output = self.text_encoder(text_input_ids.to(device)) + prompt_embeds = text_encoder_output.text_embeds + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + # in Shap-E it normalize the prompt_embeds and then later rescale it, not sure why + # YiYi TO-DO: move rescale out of prior_transformer and apply it here + prompt_embeds = prompt_embeds / torch.linalg.norm(prompt_embeds, dim=-1, keepdim=True) + + if do_classifier_free_guidance: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + num_inference_steps: int = 25, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + guidance_scale: float = 4.0, + sigma_min: float = 1e-3, + sigma_max: float = 160.0, + output_type: Optional[str] = "pt", # pt only + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. 
Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + output_type (`str`, *optional*, defaults to `"pt"`): + The output format of the generated image. Choose between: `"np"` (`np.array`) or `"pt"` + (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`ShapEPipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`ShapEPipelineOutput`] or `tuple` + """ + + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = guidance_scale > 1.0 + prompt_embeds = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance) + # prior + + self.scheduler.set_timesteps( + num_inference_steps, device=device, sigma_min=sigma_min, sigma_max=sigma_max, use_karras_sigmas=True + ) + timesteps = self.scheduler.timesteps + + num_embeddings = self.prior.config.num_embeddings + embedding_dim = self.prior.config.embedding_dim + + latents = self.prepare_latents( + (batch_size, num_embeddings * embedding_dim), + prompt_embeds.dtype, + device, + generator, + latents, + self.scheduler, + ) + # for testing only to match ldm, we can directly create a latents with desired shape: batch_size, num_embeddings, embedding_dim + latents = latents.reshape(latents.shape[0], num_embeddings, embedding_dim) + + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + scaled_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + noise_pred = self.prior( + scaled_model_input, + timestep=t, + proj_embedding=prompt_embeds, + ).predicted_image_embedding + + # remove the variance + noise_pred, _ = noise_pred.split( + scaled_model_input.shape[2], dim=2 + ) # batch_size, num_embeddings, embedding_dim + + # clip between -1 and 1 + noise_pred = noise_pred.clamp(-1, 1) + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond) + + latents = self.scheduler.step( + noise_pred, + timestep=t, + sample=latents, + step_index=i, + ).prev_sample + + if output_type not in ["pt", "np"]: + raise ValueError(f"Only the output types `pt` and `np` are supported, not output_type={output_type}") + + if output_type == "np": + latents = latents.cpu().numpy() + + if not return_dict: + return latents + + return ShapEPipelineOutput(latents) diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py index 100e2012ea..93465b2d63 100644 --- 
a/src/diffusers/schedulers/scheduling_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -22,8 +22,11 @@ from ..configuration_utils import ConfigMixin, register_to_config from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput -# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_bar_fn=None, +) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -44,11 +47,14 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor def alpha_bar(time_step): return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + if alpha_bar_fn is None: + alpha_bar_fn = alpha_bar + betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) @@ -106,6 +112,8 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) + elif beta_schedule == "exp": + self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_bar_fn=lambda t: math.exp(t * -12.0)) else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") @@ -152,6 +160,9 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): num_inference_steps: int, device: Union[str, torch.device] = None, num_train_timesteps: Optional[int] = None, + sigma_min: Optional[float] = None, + sigma_max: Optional[float] = None, + use_karras_sigmas: Optional[bool] = None, # overwrite the self.config.use_karras_sigma ): """ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. 
@@ -166,15 +177,25 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps - timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + if sigma_min is not None and sigma_max is not None: + sigmas = torch.tensor([sigma_max, sigma_min]) - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - log_sigmas = np.log(sigmas) - sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + else: + timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() - if self.use_karras_sigmas: + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + + if use_karras_sigmas is None: + use_karras_sigmas = self.use_karras_sigmas + + if use_karras_sigmas: sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) - timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + if self.config.beta_schedule == "exp": + timesteps = np.array([self._sigma_to_t_yiyi(sigma) for sigma in sigmas]) + else: + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) sigmas = torch.from_numpy(sigmas).to(device=device) @@ -220,6 +241,22 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): t = t.reshape(sigma.shape) return t + # YiYi Notes: Taken from the original repo, will refactor so it does not introduce a dependency on scipy + def _sigma_to_t_yiyi(self, sigma): + alpha_cumprod = 1.0 / (sigma**2 + 1) + + if alpha_cumprod > self.alphas_cumprod[0]: + return 0 + elif alpha_cumprod <= self.alphas_cumprod[-1]: + return len(self.alphas_cumprod) - 1 + else: + from scipy import interpolate + + timestep = interpolate.interp1d(self.alphas_cumprod, np.arange(0, len(self.alphas_cumprod)))( + alpha_cumprod + ) # yiyi testing, original implementation + return int(timestep) + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: """Constructs the noise schedule of Karras et al. (2022).""" @@ -244,6 +281,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): timestep: Union[float, torch.FloatTensor], sample: Union[torch.FloatTensor, np.ndarray], return_dict: bool = True, + step_index: Optional[int] = None, ) -> Union[SchedulerOutput, Tuple]: """ Args: @@ -258,7 +296,8 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. 
""" - step_index = self.index_for_timestep(timestep) + if step_index is None: + step_index = self.index_for_timestep(timestep) if self.state_in_first_order: sigma = self.sigmas[step_index] @@ -284,7 +323,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): sample / (sigma_input**2 + 1) ) elif self.config.prediction_type == "sample": - raise NotImplementedError("prediction_type not implemented yet: sample") + pred_original_sample = model_output else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 95d07c081c..cc060b5572 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -227,6 +227,21 @@ class KandinskyPriorPipeline(metaclass=DummyObject): requires_backends(cls, ["torch", "transformers"]) +class ShapEPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class LDMTextToImagePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"]