mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
This commit is contained in:
anton-l
2022-06-08 11:53:12 +02:00
parent bb98a5b709
commit 07ffe73f79
11 changed files with 91 additions and 96 deletions

View File

@@ -1,9 +1,10 @@
import torch
from torch import nn
from transformers import CLIPTextConfig, GPT2Tokenizer
from diffusers import UNetGLIDEModel, ClassifierFreeGuidanceScheduler, CLIPTextModel
from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, UNetGLIDEModel
from modeling_glide import GLIDE
from transformers import CLIPTextConfig, GPT2Tokenizer
# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt
state_dict = torch.load("base.pt", map_location="cpu")
@@ -22,7 +23,7 @@ config = CLIPTextConfig(
)
model = CLIPTextModel(config).eval()
tokenizer = GPT2Tokenizer("./glide-base/vocab.json", "./glide-base/merges.txt", pad_token="<|endoftext|>")
#tokenizer.save_pretrained("./glide-base")
# tokenizer.save_pretrained("./glide-base")
hf_encoder = model.text_model
@@ -51,11 +52,11 @@ for layer_idx in range(config.num_hidden_layers):
hf_layer.mlp.fc2.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.weight"]
hf_layer.mlp.fc2.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.bias"]
#inputs = tokenizer(["an oil painting of a corgi", ""], padding="max_length", max_length=128, return_tensors="pt")
#with torch.no_grad():
# inputs = tokenizer(["an oil painting of a corgi", ""], padding="max_length", max_length=128, return_tensors="pt")
# with torch.no_grad():
# outputs = model(**inputs)
#model.save_pretrained("./glide-base")
# model.save_pretrained("./glide-base")
### Convert the UNet
@@ -80,4 +81,4 @@ scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squar
glide = GLIDE(unet=unet_model, noise_scheduler=scheduler, text_encoder=model, tokenizer=tokenizer)
glide.save_pretrained("./glide-base")
glide.save_pretrained("./glide-base")
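
The script above ends by assembling the GLIDE pipeline and writing it to ./glide-base. A minimal usage sketch for loading it back, assuming DiffusionPipeline exposes a from_pretrained counterpart to the save_pretrained call shown (not verified against this commit):

# Hedged usage sketch, not part of this commit.
from modeling_glide import GLIDE

glide = GLIDE.from_pretrained("./glide-base")  # assumed counterpart to glide.save_pretrained above
print(type(glide.unet).__name__)       # UNetGLIDEModel
print(type(glide.tokenizer).__name__)  # GPT2Tokenizer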

View File

@@ -14,12 +14,12 @@
# limitations under the License.
from diffusers import DiffusionPipeline, UNetGLIDEModel, ClassifierFreeGuidanceScheduler, CLIPTextModel
from transformers import GPT2Tokenizer
import numpy as np
import torch
import tqdm
import torch
import numpy as np
from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, DiffusionPipeline, UNetGLIDEModel
from transformers import GPT2Tokenizer
def _extract_into_tensor(arr, timesteps, broadcast_shape):
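
Only the signature of _extract_into_tensor survives in this hunk. A sketch of what a helper with this signature conventionally does in OpenAI's guided-diffusion code (the body below is assumed, not copied from this commit):

import numpy as np
import torch

def _extract_into_tensor_sketch(arr, timesteps, broadcast_shape):
    # gather the per-timestep entries of a 1-D numpy schedule for this batch of timesteps ...
    res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    # ... then add trailing singleton dims until it broadcasts against the image tensor
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res.expand(broadcast_shape)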
@@ -40,14 +40,16 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
class GLIDE(DiffusionPipeline):
def __init__(
self,
unet: UNetGLIDEModel,
noise_scheduler: ClassifierFreeGuidanceScheduler,
text_encoder: CLIPTextModel,
tokenizer: GPT2Tokenizer
self,
unet: UNetGLIDEModel,
noise_scheduler: ClassifierFreeGuidanceScheduler,
text_encoder: CLIPTextModel,
tokenizer: GPT2Tokenizer,
):
super().__init__()
self.register_modules(unet=unet, noise_scheduler=noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer)
self.register_modules(
unet=unet, noise_scheduler=noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer
)
def q_posterior_mean_variance(self, x_start, x_t, t):
"""
@@ -129,7 +131,9 @@ class GLIDE(DiffusionPipeline):
self.text_encoder.to(torch_device)
# 1. Sample gaussian noise
image = self.noise_scheduler.sample_noise((1, self.unet.in_channels, 64, 64), device=torch_device, generator=generator)
image = self.noise_scheduler.sample_noise(
(1, self.unet.in_channels, 64, 64), device=torch_device, generator=generator
)
# 2. Encode tokens
# an empty input is needed to guide the model away from (
@@ -141,9 +145,7 @@ class GLIDE(DiffusionPipeline):
t = torch.tensor([i] * image.shape[0], device=torch_device)
mean, variance, log_variance, pred_xstart = self.p_mean_variance(self.unet, transformer_out, image, t)
noise = self.noise_scheduler.sample_noise(image.shape)
nonzero_mask = (
(t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))
) # no noise when t == 0
nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1))) # no noise when t == 0
image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise
return image
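
The nonzero_mask line above implements the usual DDPM convention of adding no noise on the final step (t == 0): the mask is reshaped to (batch, 1, 1, 1) so it broadcasts over the image and zeroes the noise term where t == 0. A self-contained check of the broadcasting trick:

# Illustration only, not from this commit.
import torch

image = torch.zeros(3, 4, 8, 8)
t = torch.tensor([5, 1, 0])
nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))
assert nonzero_mask.shape == (3, 1, 1, 1)
assert nonzero_mask[:, 0, 0, 0].tolist() == [1.0, 1.0, 0.0]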

View File

@@ -1,6 +1,8 @@
import torch
from modeling_glide import GLIDE
generator = torch.Generator()
generator = generator.manual_seed(0)
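
Seeding the generator up front is what makes the sampling script reproducible: two generators with the same seed produce identical noise. A quick illustration:

# Illustration only, not from this commit.
import torch

g1 = torch.Generator().manual_seed(0)
g2 = torch.Generator().manual_seed(0)
assert torch.equal(torch.randn(2, 3, generator=g1), torch.randn(2, 3, generator=g2))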

View File

@@ -5,10 +5,10 @@
__version__ = "0.0.1"
from .modeling_utils import ModelMixin
from .models.clip_text_transformer import CLIPTextModel
from .models.unet import UNetModel
from .models.unet_glide import UNetGLIDEModel
from .models.unet_ldm import UNetLDMModel
from .models.clip_text_transformer import CLIPTextModel
from .pipeline_utils import DiffusionPipeline
from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
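
After the re-sort, the package root still re-exports the same names, so all of the following imports resolve directly from diffusers:

from diffusers import (
    ClassifierFreeGuidanceScheduler,
    CLIPTextModel,
    DiffusionPipeline,
    GaussianDDPMScheduler,
    ModelMixin,
    UNetGLIDEModel,
    UNetLDMModel,
    UNetModel,
)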

View File

@@ -89,7 +89,6 @@ class ConfigMixin:
self.to_json_file(output_config_file)
logger.info(f"Configuration saved in {output_config_file}")
@classmethod
def get_config_dict(
@@ -183,7 +182,7 @@ class ConfigMixin:
logger.info(f"loading configuration file {config_file}")
else:
logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}")
return config_dict
@classmethod
@@ -199,9 +198,8 @@ class ConfigMixin:
# use value from config dict
init_dict[key] = config_dict.pop(key)
config_dict.update(kwargs)
unused_kwargs = config_dict  # dict.update returns None; the leftover config entries plus extra kwargs are the unused ones
passed_keys = set(init_dict.keys())
if len(expected_keys - passed_keys) > 0:
logger.warn(
@@ -212,9 +210,7 @@ class ConfigMixin:
@classmethod
def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
config_dict = cls.get_config_dict(
pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
)
config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
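
from_config first fetches the config dict and then extract_init_dict splits it into the kwargs the class __init__ expects and the unused remainder. A toy sketch of that split (the helper and names are illustrative, not the actual ConfigMixin code):

import inspect

def split_config(cls, config_dict):
    # keep only the keys that match __init__ parameters; everything else is "unused"
    expected = set(inspect.signature(cls.__init__).parameters) - {"self"}
    init_dict = {k: v for k, v in config_dict.items() if k in expected}
    unused = {k: v for k, v in config_dict.items() if k not in expected}
    return init_dict, unused

class Toy:
    def __init__(self, timesteps, beta_schedule):
        self.timesteps, self.beta_schedule = timesteps, beta_schedule

init_dict, unused = split_config(Toy, {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2", "_module": "toy.py"})
assert init_dict == {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}
assert unused == {"_module": "toy.py"}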

View File

@@ -16,7 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .clip_text_transformer import CLIPTextModel
from .unet import UNetModel
from .unet_glide import UNetGLIDEModel
from .unet_ldm import UNetLDMModel
from .clip_text_transformer import CLIPTextModel

View File

@@ -14,14 +14,15 @@
# limitations under the License.
""" PyTorch CLIP model."""
from dataclasses import dataclass
import math
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.modeling_utils import PreTrainedModel
@@ -32,7 +33,7 @@ from transformers.utils import (
logging,
replace_return_docstrings,
)
from transformers import CLIPModel, CLIPConfig, CLIPVisionConfig, CLIPTextConfig
logger = logging.get_logger(__name__)
@@ -153,11 +154,11 @@ class CLIPTextEmbeddings(nn.Module):
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
self,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
@@ -193,16 +194,15 @@ class CLIPAttention(nn.Module):
)
self.scale = 1 / math.sqrt(math.sqrt(self.head_dim))
self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim*3)
self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
@@ -212,9 +212,7 @@ class CLIPAttention(nn.Module):
qkv_states = qkv_states.view(bsz, tgt_len, self.num_heads, -1)
query_states, key_states, value_states = torch.split(qkv_states, self.head_dim, dim=-1)
attn_weights = torch.einsum(
"bthc,bshc->bhts", query_states * self.scale, key_states * self.scale
)
attn_weights = torch.einsum("bthc,bshc->bhts", query_states * self.scale, key_states * self.scale)
wdtype = attn_weights.dtype
attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1).type(wdtype)
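
The attention block uses a fused qkv projection and scales both queries and keys by head_dim ** -0.25, which gives the same logits as the more common single head_dim ** -0.5 scaling of the q·k product. A small numerical check (illustration only):

import math
import torch

bsz, tgt_len, heads, head_dim = 2, 5, 4, 16
q = torch.randn(bsz, tgt_len, heads, head_dim)
k = torch.randn(bsz, tgt_len, heads, head_dim)

scale = 1 / math.sqrt(math.sqrt(head_dim))
split_scaled = torch.einsum("bthc,bshc->bhts", q * scale, k * scale)
standard = torch.einsum("bthc,bshc->bhts", q, k) / math.sqrt(head_dim)
assert torch.allclose(split_scaled, standard, atol=1e-6)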
@@ -252,11 +250,11 @@ class CLIPEncoderLayer(nn.Module):
self.layer_norm2 = nn.LayerNorm(self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
causal_attention_mask: torch.Tensor,
output_attentions: Optional[bool] = False,
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
causal_attention_mask: torch.Tensor,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.FloatTensor]:
"""
Args:
@@ -313,19 +311,19 @@ class CLIPPreTrainedModel(PreTrainedModel):
module.padding_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
elif isinstance(module, CLIPVisionEmbeddings):
factor = self.config.initializer_factor
nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim ** -0.5 * factor)
nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
elif isinstance(module, CLIPAttention):
factor = self.config.initializer_factor
in_proj_std = (module.embed_dim ** -0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
out_proj_std = (module.embed_dim ** -0.5) * factor
in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
out_proj_std = (module.embed_dim**-0.5) * factor
nn.init.normal_(module.qkv_proj.weight, std=in_proj_std)
nn.init.normal_(module.out_proj.weight, std=out_proj_std)
elif isinstance(module, CLIPMLP):
factor = self.config.initializer_factor
in_proj_std = (
(module.config.hidden_size ** -0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
(module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
)
fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
nn.init.normal_(module.fc1.weight, std=fc_std)
@@ -333,11 +331,11 @@ class CLIPPreTrainedModel(PreTrainedModel):
elif isinstance(module, CLIPModel):
nn.init.normal_(
module.text_projection.weight,
std=module.text_embed_dim ** -0.5 * self.config.initializer_factor,
std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
)
nn.init.normal_(
module.visual_projection.weight,
std=module.vision_embed_dim ** -0.5 * self.config.initializer_factor,
std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
)
if isinstance(module, nn.LayerNorm):
@@ -463,13 +461,13 @@ class CLIPEncoder(nn.Module):
self.gradient_checkpointing = False
def forward(
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutput]:
r"""
Args:
@@ -562,13 +560,13 @@ class CLIPTextTransformer(nn.Module):
@add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
@@ -652,13 +650,13 @@ class CLIPTextModel(CLIPPreTrainedModel):
@add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
@@ -684,4 +682,4 @@ class CLIPTextModel(CLIPPreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
)
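
The converted text encoder is driven with GPT2Tokenizer-padded prompts, and classifier-free guidance pairs the real prompt with an empty string (see the commented-out lines in the conversion script above). A hedged usage sketch, assuming the commented-out tokenizer.save_pretrained / model.save_pretrained calls were run against ./glide-base:

import torch
from transformers import GPT2Tokenizer
from diffusers import CLIPTextModel

tokenizer = GPT2Tokenizer.from_pretrained("./glide-base")     # pad_token was set to <|endoftext|> before saving
model = CLIPTextModel.from_pretrained("./glide-base").eval()  # assumes the converted encoder weights were saved here

inputs = tokenizer(["an oil painting of a corgi", ""], padding="max_length", max_length=128, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (2, 128, hidden_size)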

View File

@@ -470,7 +470,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.use_checkpoint = use_checkpoint
#self.dtype = torch.float16 if use_fp16 else torch.float32
# self.dtype = torch.float16 if use_fp16 else torch.float32
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample

View File

@@ -17,6 +17,7 @@
import importlib
import os
from typing import Optional, Union
from huggingface_hub import snapshot_download
# CHANGE to diffusers.utils
@@ -64,7 +65,7 @@ class DiffusionPipeline(ConfigMixin):
# set models
setattr(self, name, module)
register_dict = {"_module" : self.__module__.split(".")[-1] + ".py"}
register_dict = {"_module": self.__module__.split(".")[-1] + ".py"}
self.register(**register_dict)
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
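
The "_module" entry records the file name of the Python module that defines the concrete pipeline class, which is what lets a saved config point back at e.g. modeling_glide.py. A toy illustration of the string handling (the value of the module path is illustrative):

module_path = "examples.glide.modeling_glide"  # illustrative stand-in for self.__module__
register_dict = {"_module": module_path.split(".")[-1] + ".py"}
assert register_dict == {"_module": "modeling_glide.py"}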

View File

@@ -16,5 +16,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .gaussian_ddpm import GaussianDDPMScheduler
from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
from .gaussian_ddpm import GaussianDDPMScheduler

View File

@@ -11,10 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import math
from torch import nn
import numpy as np
import torch
from torch import nn
from ..configuration_utils import ConfigMixin
@@ -80,19 +81,13 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
# calculations for posterior q(x_{t-1} | x_t, x_0)
self.posterior_variance = (
betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
)
self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.posterior_log_variance_clipped = np.log(
np.append(self.posterior_variance[1], self.posterior_variance[1:])
)
self.posterior_mean_coef1 = (
betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
)
self.posterior_mean_coef2 = (
(1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
np.append(self.posterior_variance[1], self.posterior_variance[1:])
)
self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
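
These are the standard DDPM posterior quantities for q(x_{t-1} | x_t, x_0). A standalone numpy sketch with a toy linear beta schedule (the real scheduler uses squaredcos_cap_v2) that also shows why the log-variance is clipped:

import numpy as np

betas = np.linspace(1e-4, 0.02, 1000)
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
posterior_mean_coef1 = betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)

# the posterior variance never exceeds beta_t, and it is exactly 0 at t == 0,
# which is why the log above is clipped by reusing the t == 1 value.
assert np.all(posterior_variance <= betas + 1e-12)
assert posterior_variance[0] == 0.0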
def sample_noise(self, shape, device, generator=None):
# always sample on CPU to be deterministic
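
The comment refers to the usual pattern of drawing noise on CPU with the seeded generator and only then moving it to the target device, so results do not depend on the device's RNG. A sketch of that pattern (the body of sample_noise is not shown in this hunk):

import torch

def sample_noise_sketch(shape, device, generator=None):
    # draw on CPU so the same seed gives the same noise regardless of the target device
    return torch.randn(shape, generator=generator).to(device)

g = torch.Generator().manual_seed(0)
noise = sample_noise_sketch((1, 3, 64, 64), device="cpu", generator=g)
print(noise.shape)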