Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
refactor prior_transformer
Add conversion script, add pipeline, pass `step_index` from the pipeline, remove permute, add zero pad token, and remove the "Copied from" statement for the `betas_for_alpha_bar` function.
324  scripts/convert_shap_e_to_diffusers.py  Normal file
@@ -0,0 +1,324 @@
import argparse
import tempfile

import torch
from accelerate import load_checkpoint_and_dispatch

from diffusers.models.prior_transformer import PriorTransformer


"""
Example - From the diffusers root directory:

Download weights:
```sh
$ wget "https://openaipublic.azureedge.net/main/shap-e/text_cond.pt"
```

Convert the model:
```sh
$ python scripts/convert_shap_e_to_diffusers.py \
    --prior_checkpoint_path /home/yiyi_huggingface_co/shap-e/shap_e_model_cache/text_cond.pt \
    --dump_path /home/yiyi_huggingface_co/model_repo/shape \
    --debug prior
```
"""


# prior

PRIOR_ORIGINAL_PREFIX = "wrapped"

# Uses default arguments
PRIOR_CONFIG = {
    "num_attention_heads": 16,
    "attention_head_dim": 1024 // 16,
    "num_layers": 24,
    "embedding_dim": 1024,
    "num_embeddings": 1024,
    "additional_embeddings": 0,
    "act_fn": "gelu",
    "time_embed_dim": 1024 * 4,
    "clip_embedding_dim": 768,
    "out_dim": 1024 * 2,
    "has_pre_norm": True,
    "has_encoder_hidden_states_proj": False,
    "has_prd_embedding": False,
    "has_post_process": False,
}


def prior_model_from_original_config():
    model = PriorTransformer(**PRIOR_CONFIG)

    return model


def prior_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
    diffusers_checkpoint = {}

    # <original>.time_embed.c_fc -> <diffusers>.time_embedding.linear_1
    diffusers_checkpoint.update(
        {
            "time_embedding.linear_1.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.c_fc.weight"],
            "time_embedding.linear_1.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.c_fc.bias"],
        }
    )

    # <original>.time_embed.c_proj -> <diffusers>.time_embedding.linear_2
    diffusers_checkpoint.update(
        {
            "time_embedding.linear_2.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.c_proj.weight"],
            "time_embedding.linear_2.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.c_proj.bias"],
        }
    )

    # <original>.clip_img_proj -> <diffusers>.proj_in
    diffusers_checkpoint.update(
        {
            "proj_in.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.input_proj.weight"],
            "proj_in.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.input_proj.bias"],
        }
    )

    # <original>.text_emb_proj -> <diffusers>.embedding_proj
    diffusers_checkpoint.update(
        {
            "embedding_proj.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_embed.weight"],
            "embedding_proj.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_embed.bias"],
        }
    )

    # <original>.positional_embedding -> <diffusers>.positional_embedding
    diffusers_checkpoint.update({"positional_embedding": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.pos_emb"][None, :]})

    # <original>.ln_pre -> <diffusers>.norm_in
    diffusers_checkpoint.update(
        {
            "norm_in.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.ln_pre.weight"],
            "norm_in.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.ln_pre.bias"],
        }
    )

    # <original>.resblocks.<x> -> <diffusers>.transformer_blocks.<x>
    for idx in range(len(model.transformer_blocks)):
        diffusers_transformer_prefix = f"transformer_blocks.{idx}"
        original_transformer_prefix = f"{PRIOR_ORIGINAL_PREFIX}.backbone.resblocks.{idx}"

        # <original>.attn -> <diffusers>.attn1
        diffusers_attention_prefix = f"{diffusers_transformer_prefix}.attn1"
        original_attention_prefix = f"{original_transformer_prefix}.attn"
        diffusers_checkpoint.update(
            prior_attention_to_diffusers(
                checkpoint,
                diffusers_attention_prefix=diffusers_attention_prefix,
                original_attention_prefix=original_attention_prefix,
                attention_head_dim=model.attention_head_dim,
            )
        )

        # <original>.mlp -> <diffusers>.ff
        diffusers_ff_prefix = f"{diffusers_transformer_prefix}.ff"
        original_ff_prefix = f"{original_transformer_prefix}.mlp"
        diffusers_checkpoint.update(
            prior_ff_to_diffusers(
                checkpoint, diffusers_ff_prefix=diffusers_ff_prefix, original_ff_prefix=original_ff_prefix
            )
        )

        # <original>.ln_1 -> <diffusers>.norm1
        diffusers_checkpoint.update(
            {
                f"{diffusers_transformer_prefix}.norm1.weight": checkpoint[
                    f"{original_transformer_prefix}.ln_1.weight"
                ],
                f"{diffusers_transformer_prefix}.norm1.bias": checkpoint[f"{original_transformer_prefix}.ln_1.bias"],
            }
        )

        # <original>.ln_2 -> <diffusers>.norm3
        diffusers_checkpoint.update(
            {
                f"{diffusers_transformer_prefix}.norm3.weight": checkpoint[
                    f"{original_transformer_prefix}.ln_2.weight"
                ],
                f"{diffusers_transformer_prefix}.norm3.bias": checkpoint[f"{original_transformer_prefix}.ln_2.bias"],
            }
        )

    # <original>.final_ln -> <diffusers>.norm_out
    diffusers_checkpoint.update(
        {
            "norm_out.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.ln_post.weight"],
            "norm_out.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.ln_post.bias"],
        }
    )

    # <original>.out_proj -> <diffusers>.proj_to_clip_embeddings
    diffusers_checkpoint.update(
        {
            "proj_to_clip_embeddings.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.output_proj.weight"],
            "proj_to_clip_embeddings.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.output_proj.bias"],
        }
    )

    return diffusers_checkpoint


def prior_attention_to_diffusers(
    checkpoint, *, diffusers_attention_prefix, original_attention_prefix, attention_head_dim
):
    diffusers_checkpoint = {}

    # <original>.c_qkv -> <diffusers>.{to_q, to_k, to_v}
    [q_weight, k_weight, v_weight], [q_bias, k_bias, v_bias] = split_attentions(
        weight=checkpoint[f"{original_attention_prefix}.c_qkv.weight"],
        bias=checkpoint[f"{original_attention_prefix}.c_qkv.bias"],
        split=3,
        chunk_size=attention_head_dim,
    )

    diffusers_checkpoint.update(
        {
            f"{diffusers_attention_prefix}.to_q.weight": q_weight,
            f"{diffusers_attention_prefix}.to_q.bias": q_bias,
            f"{diffusers_attention_prefix}.to_k.weight": k_weight,
            f"{diffusers_attention_prefix}.to_k.bias": k_bias,
            f"{diffusers_attention_prefix}.to_v.weight": v_weight,
            f"{diffusers_attention_prefix}.to_v.bias": v_bias,
        }
    )

    # <original>.c_proj -> <diffusers>.to_out.0
    diffusers_checkpoint.update(
        {
            f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{original_attention_prefix}.c_proj.weight"],
            f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{original_attention_prefix}.c_proj.bias"],
        }
    )

    return diffusers_checkpoint


def prior_ff_to_diffusers(checkpoint, *, diffusers_ff_prefix, original_ff_prefix):
    diffusers_checkpoint = {
        # <original>.c_fc -> <diffusers>.net.0.proj
        f"{diffusers_ff_prefix}.net.{0}.proj.weight": checkpoint[f"{original_ff_prefix}.c_fc.weight"],
        f"{diffusers_ff_prefix}.net.{0}.proj.bias": checkpoint[f"{original_ff_prefix}.c_fc.bias"],
        # <original>.c_proj -> <diffusers>.net.2
        f"{diffusers_ff_prefix}.net.{2}.weight": checkpoint[f"{original_ff_prefix}.c_proj.weight"],
        f"{diffusers_ff_prefix}.net.{2}.bias": checkpoint[f"{original_ff_prefix}.c_proj.bias"],
    }

    return diffusers_checkpoint


# done prior


# TODO maybe document and/or can do more efficiently (build indices in for loop and extract once for each split?)
def split_attentions(*, weight, bias, split, chunk_size):
    weights = [None] * split
    biases = [None] * split

    weights_biases_idx = 0

    for starting_row_index in range(0, weight.shape[0], chunk_size):
        row_indices = torch.arange(starting_row_index, starting_row_index + chunk_size)

        weight_rows = weight[row_indices, :]
        bias_rows = bias[row_indices]

        if weights[weights_biases_idx] is None:
            assert weights[weights_biases_idx] is None
            weights[weights_biases_idx] = weight_rows
            biases[weights_biases_idx] = bias_rows
        else:
            assert weights[weights_biases_idx] is not None
            weights[weights_biases_idx] = torch.concat([weights[weights_biases_idx], weight_rows])
            biases[weights_biases_idx] = torch.concat([biases[weights_biases_idx], bias_rows])

        weights_biases_idx = (weights_biases_idx + 1) % split

    return weights, biases


# done utils


# Driver functions


def prior(*, args, checkpoint_map_location):
    print("loading prior")

    prior_checkpoint = torch.load(args.prior_checkpoint_path, map_location=checkpoint_map_location)

    prior_model = prior_model_from_original_config()

    prior_diffusers_checkpoint = prior_original_checkpoint_to_diffusers_checkpoint(prior_model, prior_checkpoint)

    del prior_checkpoint

    load_checkpoint_to_model(prior_diffusers_checkpoint, prior_model, strict=True)

    print("done loading prior")

    return prior_model


def load_checkpoint_to_model(checkpoint, model, strict=False):
    with tempfile.NamedTemporaryFile() as file:
        torch.save(checkpoint, file.name)
        del checkpoint
        if strict:
            model.load_state_dict(torch.load(file.name), strict=True)
        else:
            load_checkpoint_and_dispatch(model, file.name, device_map="auto")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

    parser.add_argument(
        "--prior_checkpoint_path",
        default=None,
        type=str,
        required=True,
        help="Path to the prior checkpoint to convert.",
    )

    parser.add_argument(
        "--checkpoint_load_device",
        default="cpu",
        type=str,
        required=False,
        help="The device passed to `map_location` when loading checkpoints.",
    )

    parser.add_argument(
        "--debug",
        default=None,
        type=str,
        required=False,
        help="Only run a specific stage of the convert script. Used for debugging.",
    )

    args = parser.parse_args()

    print(f"loading checkpoints to {args.checkpoint_load_device}")

    checkpoint_map_location = torch.device(args.checkpoint_load_device)

    if args.debug is not None:
        print(f"debug: only executing {args.debug}")

    if args.debug is None:
        print("YiYi TO-DO")
    elif args.debug == "prior":
        prior_model = prior(args=args, checkpoint_map_location=checkpoint_map_location)
        prior_model.save_pretrained(args.dump_path)
    else:
        raise ValueError(f"unknown debug value : {args.debug}")
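As an aside (not part of the commit): the TODO on `split_attentions` above suggests a more efficient split. A minimal sketch of one way to do it, assuming the `c_qkv` rows are laid out head-major as `[head0_q, head0_k, head0_v, head1_q, ...]`, which is what the round-robin loop above implies; the helper name `split_attentions_reshaped` is hypothetical:

```py
import torch


def split_attentions_reshaped(*, weight, bias, split, chunk_size):
    # weight: [split * num_heads * chunk_size, in_dim], rows grouped per head as q, k, v
    in_dim = weight.shape[-1]
    w = weight.reshape(-1, split, chunk_size, in_dim)  # [num_heads, split, chunk_size, in_dim]
    b = bias.reshape(-1, split, chunk_size)
    # gather the i-th chunk of every head, then flatten back to [num_heads * chunk_size, in_dim]
    weights = [w[:, i].reshape(-1, in_dim) for i in range(split)]
    biases = [b[:, i].reshape(-1) for i in range(split)]
    return weights, biases
```

For the Shap-E prior this would be called the same way as above, with `split=3` and `chunk_size=1024 // 16 = 64`, and yields the same q/k/v ordering as the concatenation loop.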
src/diffusers/__init__.py

@@ -137,6 +137,7 @@ else:
        LDMTextToImagePipeline,
        PaintByExamplePipeline,
        SemanticStableDiffusionPipeline,
        ShapEPipeline,
        StableDiffusionAttendAndExcitePipeline,
        StableDiffusionControlNetImg2ImgPipeline,
        StableDiffusionControlNetInpaintPipeline,
src/diffusers/models/prior_transformer.py

@@ -1,3 +1,4 @@
import math
from dataclasses import dataclass
from typing import Optional, Union

@@ -58,6 +59,14 @@ class PriorTransformer(ModelMixin, ConfigMixin):
        num_embeddings=77,
        additional_embeddings=4,
        dropout: float = 0.0,
        act_fn: str = "silu",
        has_pre_norm: bool = False,
        has_encoder_hidden_states_proj: bool = True,
        has_prd_embedding: bool = True,
        has_post_process: bool = True,
        time_embed_dim: Optional[int] = None,
        clip_embedding_dim: Optional[int] = None,
        out_dim: Optional[int] = None,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads

@@ -65,17 +74,33 @@ class PriorTransformer(ModelMixin, ConfigMixin):
        inner_dim = num_attention_heads * attention_head_dim
        self.additional_embeddings = additional_embeddings

        if time_embed_dim is None:
            time_embed_dim = inner_dim

        if clip_embedding_dim is None:
            clip_embedding_dim = embedding_dim

        if out_dim is None:
            out_dim = embedding_dim

        self.time_proj = Timesteps(inner_dim, True, 0)
        self.time_embedding = TimestepEmbedding(inner_dim, inner_dim)
        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=act_fn)

        self.proj_in = nn.Linear(embedding_dim, inner_dim)

        self.embedding_proj = nn.Linear(embedding_dim, inner_dim)
        self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
        self.embedding_proj = nn.Linear(clip_embedding_dim, inner_dim)

        if has_encoder_hidden_states_proj:
            self.encoder_hidden_states_proj = nn.Linear(clip_embedding_dim, inner_dim)
        else:
            self.encoder_hidden_states_proj = None

        self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))

        self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
        if has_prd_embedding:
            self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
        else:
            self.prd_embedding = None

        self.transformer_blocks = nn.ModuleList(
            [

@@ -91,8 +116,14 @@ class PriorTransformer(ModelMixin, ConfigMixin):
            ]
        )

        if has_pre_norm:
            self.norm_in = nn.LayerNorm(inner_dim)
        else:
            self.norm_in = None

        self.norm_out = nn.LayerNorm(inner_dim)
        self.proj_to_clip_embeddings = nn.Linear(inner_dim, embedding_dim)

        self.proj_to_clip_embeddings = nn.Linear(inner_dim, out_dim)

        causal_attention_mask = torch.full(
            [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0

@@ -100,16 +131,19 @@ class PriorTransformer(ModelMixin, ConfigMixin):
        causal_attention_mask.triu_(1)
        causal_attention_mask = causal_attention_mask[None, ...]
        self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)

        self.clip_mean = nn.Parameter(torch.zeros(1, embedding_dim))
        self.clip_std = nn.Parameter(torch.zeros(1, embedding_dim))
        if has_post_process:
            self.clip_mean = nn.Parameter(torch.zeros(1, clip_embedding_dim))
            self.clip_std = nn.Parameter(torch.zeros(1, clip_embedding_dim))
        else:
            self.clip_mean = None
            self.clip_std = None

    def forward(
        self,
        hidden_states,
        timestep: Union[torch.Tensor, float, int],
        proj_embedding: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        return_dict: bool = True,
    ):

@@ -152,23 +186,49 @@ class PriorTransformer(ModelMixin, ConfigMixin):
        timesteps_projected = timesteps_projected.to(dtype=self.dtype)
        time_embeddings = self.time_embedding(timesteps_projected)

        # Rescale the features to have unit variance
        # YiYi TO-DO: it was normalized before during the encode_prompt step, move this step to the pipeline
        if self.clip_mean is None:
            proj_embedding = math.sqrt(proj_embedding.shape[1]) * proj_embedding
        proj_embeddings = self.embedding_proj(proj_embedding)
        encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
        if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
            encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
        elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None:
            raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set")

        hidden_states = self.proj_in(hidden_states)
        prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)

        positional_embeddings = self.positional_embedding.to(hidden_states.dtype)

        tokens = []

        if encoder_hidden_states is not None:
            tokens.append(encoder_hidden_states)

        tokens = tokens + [
            proj_embeddings[:, None, :],
            time_embeddings[:, None, :],
            hidden_states[:, None, :] if len(hidden_states.shape) == 2 else hidden_states,
        ]

        if self.prd_embedding is not None:
            prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
            tokens.append(prd_embedding)

        hidden_states = torch.cat(
            [
                encoder_hidden_states,
                proj_embeddings[:, None, :],
                time_embeddings[:, None, :],
                hidden_states[:, None, :],
                prd_embedding,
            ],
            tokens,
            dim=1,
        )

        # Allow positional_embedding to not include the `additional_embeddings` and instead pad it with zeros for these additional tokens
        additional_embeddings = 2 + (encoder_hidden_states.shape[1] if encoder_hidden_states is not None else 0)
        if positional_embeddings.shape[1] < hidden_states.shape[1]:
            positional_embeddings = F.pad(
                positional_embeddings,
                (0, 0, additional_embeddings, self.prd_embedding.shape[1] if self.prd_embedding is not None else 0),
                value=0.0,
            )

        hidden_states = hidden_states + positional_embeddings

        if attention_mask is not None:

@@ -177,11 +237,19 @@ class PriorTransformer(ModelMixin, ConfigMixin):
            attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
            attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)

        if self.norm_in is not None:
            hidden_states = self.norm_in(hidden_states)

        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, attention_mask=attention_mask)

        hidden_states = self.norm_out(hidden_states)
        hidden_states = hidden_states[:, -1]

        if self.prd_embedding is not None:
            hidden_states = hidden_states[:, -1]
        else:
            hidden_states = hidden_states[:, additional_embeddings:]

        predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)

        if not return_dict:
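For illustration (not part of the commit): with the Shap-E config from the conversion script above (`num_embeddings=1024`, `additional_embeddings=0`, no `encoder_hidden_states_proj`, no `prd_embedding`), the refactored `forward` builds the token sequence `[proj_embeddings, time_embeddings, latents]`, so the learned `positional_embedding` of shape `(1, 1024, 1024)` is zero-padded in front for the two conditioning tokens. A minimal sketch of that shape arithmetic:

```py
import torch
import torch.nn.functional as F

num_embeddings, inner_dim = 1024, 1024               # Shap-E prior config used by the conversion script
pos_emb = torch.zeros(1, num_embeddings, inner_dim)  # learned table covers only the 1024 latent tokens

additional = 2                                       # proj_embeddings + time_embeddings (no text tokens, no prd token)
total_tokens = num_embeddings + additional           # token sequence: [proj, time, latents...]

# pad spec (0, 0, additional, 0): last dim untouched, token dim gets `additional` zero rows in front
padded = F.pad(pos_emb, (0, 0, additional, 0), value=0.0)
assert padded.shape == (1, total_tokens, inner_dim)  # (1, 1026, 1024)
```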
src/diffusers/pipelines/__init__.py

@@ -66,6 +66,7 @@ else:
    from .latent_diffusion import LDMTextToImagePipeline
    from .paint_by_example import PaintByExamplePipeline
    from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
    from .shap_e import ShapEPipeline
    from .stable_diffusion import (
        CycleDiffusionPipeline,
        StableDiffusionAttendAndExcitePipeline,
15  src/diffusers/pipelines/shap_e/__init__.py  Normal file
@@ -0,0 +1,15 @@
from ...utils import (
    OptionalDependencyNotAvailable,
    is_torch_available,
    is_transformers_available,
    is_transformers_version,
)


try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline
else:
    from .pipeline_shap_e import ShapEPipeline
311  src/diffusers/pipelines/shap_e/pipeline_shap_e.py  Normal file
@@ -0,0 +1,311 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import List, Optional, Union

import numpy as np
import torch
from transformers import CLIPTextModelWithProjection, CLIPTokenizer

from ...models import PriorTransformer
from ...pipelines import DiffusionPipeline
from ...schedulers import HeunDiscreteScheduler
from ...utils import (
    BaseOutput,
    is_accelerate_available,
    logging,
    randn_tensor,
    replace_example_docstring,
)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py

        ```
"""


@dataclass
class ShapEPipelineOutput(BaseOutput):
    """
    Output class for ShapEPipeline.

    Args:
        latents (`torch.FloatTensor` or `np.ndarray`):
            3D latent representation
    """

    latents: Union[torch.FloatTensor, np.ndarray]


class ShapEPipeline(DiffusionPipeline):
    """
    Pipeline for generating the latent representation of a 3D asset with Shap-E.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        prior ([`PriorTransformer`]):
            The canonical unCLIP prior to approximate the image embedding from the text embedding.
        text_encoder ([`CLIPTextModelWithProjection`]):
            Frozen text-encoder.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        scheduler ([`HeunDiscreteScheduler`]):
            A scheduler to be used in combination with `prior` to generate the image embedding.
    """

    def __init__(
        self,
        prior: PriorTransformer,
        text_encoder: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        scheduler: HeunDiscreteScheduler,
    ):
        super().__init__()

        self.register_modules(
            prior=prior,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            scheduler=scheduler,
        )

    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)
        latents = latents * scheduler.init_noise_sigma
        return latents

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
        models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to the
        GPU only when their specific submodule has its `forward` method called.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        models = [
            self.text_encoder,
        ]
        for cpu_offloaded_model in models:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    @property
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.text_encoder, "_hf_hook"):
            return self.device
        for module in self.text_encoder.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
    ):
        len(prompt) if isinstance(prompt, list) else 1

        # YiYi Notes: set pad_token_id to be 0, not sure why I can't set it in the config file
        self.tokenizer.pad_token_id = 0
        # get prompt text embeddings
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        text_encoder_output = self.text_encoder(text_input_ids.to(device))
        prompt_embeds = text_encoder_output.text_embeds

        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        # in Shap-E the prompt_embeds are normalized here and then later rescaled, not sure why
        # YiYi TO-DO: move the rescale out of prior_transformer and apply it here
        prompt_embeds = prompt_embeds / torch.linalg.norm(prompt_embeds, dim=-1, keepdim=True)

        if do_classifier_free_guidance:
            negative_prompt_embeds = torch.zeros_like(prompt_embeds)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        return prompt_embeds

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int = 1,
        num_inference_steps: int = 25,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        guidance_scale: float = 4.0,
        sigma_min: float = 1e-3,
        sigma_max: float = 160.0,
        output_type: Optional[str] = "pt",  # pt only
        return_dict: bool = True,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor will be generated by sampling using the supplied random `generator`.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
                text `prompt`, usually at the expense of lower image quality.
            output_type (`str`, *optional*, defaults to `"pt"`):
                The output format of the generated image. Choose between: `"np"` (`np.array`) or `"pt"`
                (`torch.Tensor`).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`ShapEPipelineOutput`] instead of a plain tuple.

        Examples:

        Returns:
            [`ShapEPipelineOutput`] or `tuple`
        """

        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        device = self._execution_device

        batch_size = batch_size * num_images_per_prompt

        do_classifier_free_guidance = guidance_scale > 1.0
        prompt_embeds = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance)
        # prior

        self.scheduler.set_timesteps(
            num_inference_steps, device=device, sigma_min=sigma_min, sigma_max=sigma_max, use_karras_sigmas=True
        )
        timesteps = self.scheduler.timesteps

        num_embeddings = self.prior.config.num_embeddings
        embedding_dim = self.prior.config.embedding_dim

        latents = self.prepare_latents(
            (batch_size, num_embeddings * embedding_dim),
            prompt_embeds.dtype,
            device,
            generator,
            latents,
            self.scheduler,
        )
        # for testing only; to match ldm we could directly create latents with the desired shape: batch_size, num_embeddings, embedding_dim
        latents = latents.reshape(latents.shape[0], num_embeddings, embedding_dim)

        for i, t in enumerate(self.progress_bar(timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            scaled_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            noise_pred = self.prior(
                scaled_model_input,
                timestep=t,
                proj_embedding=prompt_embeds,
            ).predicted_image_embedding

            # remove the variance
            noise_pred, _ = noise_pred.split(
                scaled_model_input.shape[2], dim=2
            )  # batch_size, num_embeddings, embedding_dim

            # clip between -1 and 1
            noise_pred = noise_pred.clamp(-1, 1)

            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)

            latents = self.scheduler.step(
                noise_pred,
                timestep=t,
                sample=latents,
                step_index=i,
            ).prev_sample

        if output_type not in ["pt", "np"]:
            raise ValueError(f"Only the output types `pt` and `np` are supported, not output_type={output_type}")

        if output_type == "np":
            latents = latents.cpu().numpy()

        if not return_dict:
            return latents

        return ShapEPipelineOutput(latents)
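For illustration (not part of the commit, and note that `EXAMPLE_DOC_STRING` is still empty at this point): a minimal usage sketch, assuming a full pipeline repository (the converted prior plus a CLIP text encoder, tokenizer, and a `HeunDiscreteScheduler` config) has been assembled at a hypothetical local path:

```py
import torch

from diffusers import ShapEPipeline

pipe = ShapEPipeline.from_pretrained("/path/to/shap-e-diffusers")  # hypothetical local repo
pipe = pipe.to("cuda")

output = pipe(
    "a shark",  # example prompt
    num_inference_steps=64,
    guidance_scale=4.0,
    generator=torch.Generator(device="cuda").manual_seed(0),
)
latents = output.latents  # (batch_size, num_embeddings, embedding_dim) 3D latents
```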
src/diffusers/schedulers/scheduling_heun_discrete.py

@@ -22,8 +22,11 @@ from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_bar_fn=None,
) -> torch.Tensor:
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

@@ -44,11 +47,14 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor
    def alpha_bar(time_step):
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    if alpha_bar_fn is None:
        alpha_bar_fn = alpha_bar

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)

@@ -106,6 +112,8 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        elif beta_schedule == "exp":
            self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_bar_fn=lambda t: math.exp(t * -12.0))
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
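For illustration (not part of the commit): with the new `alpha_bar_fn` hook and the "exp" schedule, `alpha_bar(t) = exp(-12 t)`, so every beta collapses to the same constant `1 - alpha_bar(t2) / alpha_bar(t1) = 1 - exp(-12 / N)`, well below `max_beta`. A minimal check, assuming `betas_for_alpha_bar` is importable from the scheduler module at this commit:

```py
import math

import torch

from diffusers.schedulers.scheduling_heun_discrete import betas_for_alpha_bar

N = 1000
betas = betas_for_alpha_bar(N, alpha_bar_fn=lambda t: math.exp(t * -12.0))

# every beta is the same constant 1 - exp(-12 / N) ~ 0.0119 for N = 1000
expected = 1 - math.exp(-12.0 / N)
assert torch.allclose(betas, torch.full_like(betas, expected))
```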
@@ -152,6 +160,9 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
        num_inference_steps: int,
        device: Union[str, torch.device] = None,
        num_train_timesteps: Optional[int] = None,
        sigma_min: Optional[float] = None,
        sigma_max: Optional[float] = None,
        use_karras_sigmas: Optional[bool] = None,  # overwrite self.config.use_karras_sigmas
    ):
        """
        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

@@ -166,15 +177,25 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):

        num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps

        timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
        if sigma_min is not None and sigma_max is not None:
            sigmas = torch.tensor([sigma_max, sigma_min])

        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        log_sigmas = np.log(sigmas)
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
        else:
            timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()

        if self.use_karras_sigmas:
            sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
            log_sigmas = np.log(sigmas)
            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

        if use_karras_sigmas is None:
            use_karras_sigmas = self.use_karras_sigmas

        if use_karras_sigmas:
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
            if self.config.beta_schedule == "exp":
                timesteps = np.array([self._sigma_to_t_yiyi(sigma) for sigma in sigmas])
            else:
                timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
        sigmas = torch.from_numpy(sigmas).to(device=device)

@@ -220,6 +241,22 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
        t = t.reshape(sigma.shape)
        return t

    # YiYi Notes: taken from the original repo, will refactor so it does not introduce a dependency on scipy
    def _sigma_to_t_yiyi(self, sigma):
        alpha_cumprod = 1.0 / (sigma**2 + 1)

        if alpha_cumprod > self.alphas_cumprod[0]:
            return 0
        elif alpha_cumprod <= self.alphas_cumprod[-1]:
            return len(self.alphas_cumprod) - 1
        else:
            from scipy import interpolate

            timestep = interpolate.interp1d(self.alphas_cumprod, np.arange(0, len(self.alphas_cumprod)))(
                alpha_cumprod
            )  # yiyi testing, original implementation
            return int(timestep)

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
        """Constructs the noise schedule of Karras et al. (2022)."""
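As an aside (not part of the commit): the note above says the scipy-based lookup is temporary. A scipy-free variant of the same linear interpolation could use `np.interp`, flipping the monotonically decreasing `alphas_cumprod` so the x-grid is increasing; the function name `sigma_to_t_no_scipy` is hypothetical and it expects a NumPy array (e.g. `self.alphas_cumprod.numpy()`):

```py
import numpy as np


def sigma_to_t_no_scipy(sigma, alphas_cumprod):
    # same lookup as `_sigma_to_t_yiyi`, but with np.interp instead of scipy.interpolate
    alpha_cumprod = 1.0 / (sigma**2 + 1)
    if alpha_cumprod > alphas_cumprod[0]:
        return 0
    if alpha_cumprod <= alphas_cumprod[-1]:
        return len(alphas_cumprod) - 1
    # alphas_cumprod is decreasing, so reverse it to give np.interp an increasing x-grid
    indices = np.arange(len(alphas_cumprod), dtype=np.float64)
    t = np.interp(alpha_cumprod, alphas_cumprod[::-1], indices[::-1])
    return int(t)
```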
@@ -244,6 +281,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
        timestep: Union[float, torch.FloatTensor],
        sample: Union[torch.FloatTensor, np.ndarray],
        return_dict: bool = True,
        step_index: Optional[int] = None,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Args:

@@ -258,7 +296,8 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
            [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.
        """
        step_index = self.index_for_timestep(timestep)
        if step_index is None:
            step_index = self.index_for_timestep(timestep)

        if self.state_in_first_order:
            sigma = self.sigmas[step_index]

@@ -284,7 +323,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
                sample / (sigma_input**2 + 1)
            )
        elif self.config.prediction_type == "sample":
            raise NotImplementedError("prediction_type not implemented yet: sample")
            pred_original_sample = model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
src/diffusers/utils/dummy_torch_and_transformers_objects.py

@@ -227,6 +227,21 @@ class KandinskyPriorPipeline(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])


class ShapEPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class LDMTextToImagePipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]