Mirror of https://github.com/huggingface/diffusers.git

Merge branch 'main' of github.com:huggingface/diffusers

Makefile | 4
@@ -74,9 +74,9 @@ fixup: modified_only_fixup extra_style_checks autogenerate_code repo-consistency

# Make marked copies of snippets of codes conform to the original

fix-copies:
	python utils/check_copies.py --fix_and_overwrite
	python utils/check_table.py --fix_and_overwrite
	python utils/check_dummies.py --fix_and_overwrite
	python utils/check_table.py --fix_and_overwrite
	python utils/check_copies.py --fix_and_overwrite

# Run tests for the library
README.md | 31
@@ -30,20 +30,32 @@ More precisely, 🤗 Diffusers offers:

**Models**: A neural network that models $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$ (see image below) and is trained end-to-end to *denoise* a noisy input into an image.
*Examples*: UNet, Conditioned UNet, 3D UNet, Transformer UNet

<p align="center">
    <img src="https://user-images.githubusercontent.com/10695622/174349667-04e9e485-793b-429a-affe-096e8199ad5b.png" width="800"/>
    <br>
    <em> Figure from DDPM paper (https://arxiv.org/abs/2006.11239). </em>
</p>

**Schedulers**: Algorithm class for both **inference** and **training**.
The class provides functionality to compute the previous image according to an alpha/beta schedule, as well as to predict noise for training.
*Examples*: [DDPM](https://arxiv.org/abs/2006.11239), [DDIM](https://arxiv.org/abs/2010.02502), [PNDM](https://arxiv.org/abs/2202.09778), [DEIS](https://arxiv.org/abs/2204.13902)

<p align="center">
    <img src="https://user-images.githubusercontent.com/10695622/174349706-53d58acc-a4d1-4cda-b3e8-432d9dc7ad38.png" width="800"/>
    <br>
    <em> Sampling and training algorithms. Figure from DDPM paper (https://arxiv.org/abs/2006.11239). </em>
</p>
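To make the scheduler idea concrete, here is a minimal sketch of one DDPM-style denoising step. It is illustrative only: the function and argument names, and the simple choice $\sigma_t^2 = \beta_t$, are assumptions and not the library's API.

```python
import torch

# Minimal sketch of one DDPM-style inference step (illustrative, not the library API).
# Given the model's noise prediction `eps` at timestep t, estimate x_{t-1}.
def ddpm_step(x_t, eps, t, alphas, alphas_cumprod, betas):
    alpha_t = alphas[t]
    alpha_bar_t = alphas_cumprod[t]
    # Posterior mean of p(x_{t-1} | x_t) as parameterized in the DDPM paper.
    mean = (x_t - betas[t] / torch.sqrt(1 - alpha_bar_t) * eps) / torch.sqrt(alpha_t)
    if t > 0:
        noise = torch.randn_like(x_t)
        return mean + torch.sqrt(betas[t]) * noise  # simple variance choice sigma_t^2 = beta_t
    return mean
```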
**Diffusion Pipeline**: End-to-end pipeline that includes multiple diffusion models, possibly text encoders, ...
*Examples*: GLIDE, Latent-Diffusion, Imagen, DALL-E 2

<p align="center">
    <img src="https://user-images.githubusercontent.com/10695622/174348898-481bd7c2-5457-4830-89bc-f0907756f64c.jpeg" width="550"/>
    <br>
    <em> Figure from Imagen (https://imagen.research.google/). </em>
</p>

## Philosophy

- Readability and clarity are preferred over highly optimized code. Strong emphasis is placed on readable, intuitive and elementary code design. *E.g.*, the provided [schedulers](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers) are separated from the provided [models](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models) and provide well-commented code that can be read alongside the original paper.
@@ -147,7 +159,8 @@ eta = 0.0  # <- deterministic sampling

for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
    # 1. predict noise residual
    orig_t = noise_scheduler.get_orig_t(t, num_inference_steps)
    orig_t = len(noise_scheduler) // num_inference_steps * t

    with torch.inference_mode():
        residual = unet(image, orig_t)
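As a quick illustration of the timestep mapping used above (assuming, say, a scheduler defined over 1000 training timesteps and 50 inference steps):

```python
# Illustrative only: map an inference step index onto the original training timesteps.
num_train_timesteps = 1000   # i.e. len(noise_scheduler) in the snippet above
num_inference_steps = 50

for t in [0, 1, 49]:
    orig_t = num_train_timesteps // num_inference_steps * t
    print(t, "->", orig_t)   # 0 -> 0, 1 -> 20, 49 -> 980
```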
@@ -173,6 +186,10 @@ image_pil = PIL.Image.fromarray(image_processed[0])
image_pil.save("test.png")
```

#### **Examples for other modalities:**

[Diffuser](https://diffusion-planning.github.io/) for planning in reinforcement learning: [Open In Colab](https://colab.research.google.com/drive/1TmBmlYeKUZSkUZoJqfBmaicVTKx6nN1R?usp=sharing)

### 2. `diffusers` as a collection of popular Diffusion systems (GLIDE, DALL-E, ...)

For more examples see [pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines).
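As a rough sketch of what using such a system looks like, mirroring the DDIM example scripts elsewhere in this commit (the `fusing/ddim-celeba-hq` checkpoint name is taken from those scripts and may change):

```python
import numpy as np
import PIL.Image
from diffusers import DDIM  # unconditional DDIM pipeline exported by this commit's __init__

# Load a pretrained pipeline and sample a batch of images.
ddim = DDIM.from_pretrained("fusing/ddim-celeba-hq")
image = ddim(batch_size=4)

# Convert the [-1, 1] tensors to uint8 PIL images.
image_processed = (image.cpu().permute(0, 2, 3, 1) + 1.0) * 127.5
image_processed = image_processed.numpy().astype(np.uint8)
PIL.Image.fromarray(image_processed[0]).save("sample.png")
```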
src/diffusers/__init__.py

@@ -1,15 +1,24 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from .utils import is_transformers_available


__version__ = "0.0.4"

from .modeling_utils import ModelMixin
from .models.unet import UNetModel
from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel
from .models.unet_grad_tts import UNetGradTTSModel
from .models.unet_ldm import UNetLDMModel
from .models.unet_rl import TemporalUNet
from .pipeline_utils import DiffusionPipeline
from .pipelines import BDDM, DDIM, DDPM, GLIDE, PNDM, GradTTS, LatentDiffusion
from .pipelines import BDDM, DDIM, DDPM, PNDM
from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin
from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler


if is_transformers_available():
    from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel
    from .models.unet_grad_tts import UNetGradTTSModel
    from .pipelines import GLIDE, GradTTS, LatentDiffusion
else:
    from .utils.dummy_transformers_objects import *
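For context, the `dummy_transformers_objects` module referenced above is not shown in this diff; a placeholder of this kind usually looks roughly like the following sketch (names illustrative), so that importing `diffusers` without `transformers` installed only fails when the guarded object is actually used:

```python
# Hypothetical sketch of a dummy placeholder (the real file is not part of this diff).
class GLIDE:
    def __init__(self, *args, **kwargs):
        raise ImportError(
            "GLIDE requires the `transformers` library. Install it with `pip install transformers`."
        )
```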
src/diffusers/configuration_utils.py

@@ -241,7 +241,7 @@ class ConfigMixin:
        Returns:
            `str`: String containing all the attributes that make up this configuration instance in JSON format.
        """
        config_dict = self._internal_dict
        config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path: Union[str, os.PathLike]):
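As a small usage sketch of the JSON serialization shown above (assuming a scheduler that mixes in `ConfigMixin`, e.g. `DDPMScheduler`; treat the constructor arguments as illustrative):

```python
from diffusers import DDPMScheduler

# Serialize the scheduler's registered configuration to JSON.
scheduler = DDPMScheduler(timesteps=1000)
print(scheduler.to_json_string())            # JSON dump of the registered config
scheduler.to_json_file("scheduler_config.json")
```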
@@ -258,10 +258,6 @@ class ConfigMixin:

class FrozenDict(OrderedDict):
    def __init__(self, *args, **kwargs):
        # remove `None`
        args = (a for a in args if a is not None)
        kwargs = {k: v for k, v in kwargs if v is not None}

        super().__init__(*args, **kwargs)

        for key, value in self.items():
src/diffusers/modeling_utils.py

@@ -490,7 +490,7 @@ class ModelMixin(torch.nn.Module):
            raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")

        if len(unexpected_keys) > 0:
            logger.warninging(
            logger.warning(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
                f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
                f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
@@ -502,7 +502,7 @@ class ModelMixin(torch.nn.Module):
        else:
            logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
        if len(missing_keys) > 0:
            logger.warninging(
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
                " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
@@ -521,7 +521,7 @@ class ModelMixin(torch.nn.Module):
                    for key, shape1, shape2 in mismatched_keys
                ]
            )
            logger.warninging(
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
                f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
src/diffusers/models/__init__.py

@@ -20,3 +20,4 @@ from .unet import UNetModel
from .unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel
from .unet_grad_tts import UNetGradTTSModel
from .unet_ldm import UNetLDMModel
from .unet_rl import TemporalUNet
src/diffusers/models/unet.py

@@ -287,14 +287,14 @@ class UNetModel(ModelMixin, ConfigMixin):
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, x, t):
    def forward(self, x, timesteps):
        assert x.shape[2] == x.shape[3] == self.resolution

        if not torch.is_tensor(t):
            t = torch.tensor([t], dtype=torch.long, device=x.device)
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device)

        # timestep embedding
        temb = get_timestep_embedding(t, self.ch)
        temb = get_timestep_embedding(timesteps, self.ch)
        temb = self.temb.dense[0](temb)
        temb = nonlinearity(temb)
        temb = self.temb.dense[1](temb)
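The `get_timestep_embedding` helper is not shown in this hunk; the standard sinusoidal timestep embedding it refers to looks roughly like the sketch below (a generic implementation, not necessarily the library's exact code):

```python
import math
import torch

def sinusoidal_timestep_embedding(timesteps: torch.Tensor, dim: int) -> torch.Tensor:
    """Generic sinusoidal embedding: maps timesteps to `dim`-dimensional vectors."""
    half_dim = dim // 2
    freqs = torch.exp(
        -math.log(10000)
        * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
        / (half_dim - 1)
    )
    args = timesteps.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)  # shape: (len(timesteps), dim)
```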
src/diffusers/models/unet_grad_tts.py

@@ -190,7 +190,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
        self.final_block = Block(dim, dim)
        self.final_conv = torch.nn.Conv2d(dim, 1, 1)

    def forward(self, x, mask, mu, t, spk=None):
    def forward(self, x, timesteps, mu, mask, spk=None):
        if self.n_spks > 1:
            # Get speaker embedding
            spk = self.spk_emb(spk)
@@ -198,7 +198,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
        if not isinstance(spk, type(None)):
            s = self.spk_mlp(spk)

        t = self.time_pos_emb(t, scale=self.pe_scale)
        t = self.time_pos_emb(timesteps, scale=self.pe_scale)
        t = self.mlp(t)

        if self.n_spks < 2:
src/diffusers/models/unet_rl.py | 268 (new file)
@@ -0,0 +1,268 @@
# model adapted from diffuser https://github.com/jannerm/diffuser/blob/main/diffuser/models/temporal.py

import math

import torch
import torch.nn as nn

import einops
from einops.layers.torch import Rearrange

from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class Downsample1d(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv1d(dim, dim, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Upsample1d(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
            Rearrange("batch channels horizon -> batch channels 1 horizon"),
            nn.GroupNorm(n_groups, out_channels),
            Rearrange("batch channels 1 horizon -> batch channels horizon"),
            nn.Mish(),
        )

    def forward(self, x):
        return self.block(x)


class ResidualTemporalBlock(nn.Module):
    def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
        super().__init__()

        self.blocks = nn.ModuleList(
            [
                Conv1dBlock(inp_channels, out_channels, kernel_size),
                Conv1dBlock(out_channels, out_channels, kernel_size),
            ]
        )

        self.time_mlp = nn.Sequential(
            nn.Mish(),
            nn.Linear(embed_dim, out_channels),
            Rearrange("batch t -> batch t 1"),
        )

        self.residual_conv = (
            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
        )

    def forward(self, x, t):
        """
        x : [ batch_size x inp_channels x horizon ]
        t : [ batch_size x embed_dim ]
        returns:
            out : [ batch_size x out_channels x horizon ]
        """
        out = self.blocks[0](x) + self.time_mlp(t)
        out = self.blocks[1](out)
        return out + self.residual_conv(x)


class TemporalUNet(ModelMixin, ConfigMixin):  # (nn.Module):
    def __init__(
        self,
        horizon,
        transition_dim,
        cond_dim,
        dim=32,
        dim_mults=(1, 2, 4, 8),
    ):
        super().__init__()

        dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        # print(f'[ models/temporal ] Channel dimensions: {in_out}')

        time_dim = dim
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.Mish(),
            nn.Linear(dim * 4, dim),
        )

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        print(in_out)
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=horizon),
                        ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=horizon),
                        Downsample1d(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

            if not is_last:
                horizon = horizon // 2

        mid_dim = dims[-1]
        self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon)
        self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(
                nn.ModuleList(
                    [
                        ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=horizon),
                        ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=horizon),
                        Upsample1d(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )

            if not is_last:
                horizon = horizon * 2

        self.final_conv = nn.Sequential(
            Conv1dBlock(dim, dim, kernel_size=5),
            nn.Conv1d(dim, transition_dim, 1),
        )

    def forward(self, x, cond, time):
        """
        x : [ batch x horizon x transition ]
        """

        x = einops.rearrange(x, "b h t -> b t h")

        t = self.time_mlp(time)
        h = []

        for resnet, resnet2, downsample in self.downs:
            x = resnet(x, t)
            x = resnet2(x, t)
            h.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_block2(x, t)

        for resnet, resnet2, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, t)
            x = resnet2(x, t)
            x = upsample(x)

        x = self.final_conv(x)

        x = einops.rearrange(x, "b t h -> b h t")
        return x


class TemporalValue(nn.Module):
    def __init__(
        self,
        horizon,
        transition_dim,
        cond_dim,
        dim=32,
        time_dim=None,
        out_dim=1,
        dim_mults=(1, 2, 4, 8),
    ):
        super().__init__()

        dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        time_dim = time_dim or dim
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.Mish(),
            nn.Linear(dim * 4, dim),
        )

        self.blocks = nn.ModuleList([])

        print(in_out)
        for dim_in, dim_out in in_out:
            self.blocks.append(
                nn.ModuleList(
                    [
                        ResidualTemporalBlock(dim_in, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
                        ResidualTemporalBlock(dim_out, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
                        Downsample1d(dim_out),
                    ]
                )
            )

            horizon = horizon // 2

        fc_dim = dims[-1] * max(horizon, 1)

        self.final_block = nn.Sequential(
            nn.Linear(fc_dim + time_dim, fc_dim // 2),
            nn.Mish(),
            nn.Linear(fc_dim // 2, out_dim),
        )

    def forward(self, x, cond, time, *args):
        """
        x : [ batch x horizon x transition ]
        """

        x = einops.rearrange(x, "b h t -> b t h")

        t = self.time_mlp(time)

        for resnet, resnet2, downsample in self.blocks:
            x = resnet(x, t)
            x = resnet2(x, t)
            x = downsample(x)

        x = x.view(len(x), -1)
        out = self.final_block(torch.cat([x, t], dim=-1))
        return out
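As a rough usage sketch for the new `TemporalUNet` (shapes follow the docstrings above; the concrete sizes are made up for illustration, and `cond` is accepted but unused by the forward pass shown here):

```python
import torch
from diffusers import TemporalUNet

# Illustrative sizes: 32-step planning horizon, 14-dimensional transitions.
model = TemporalUNet(horizon=32, transition_dim=14, cond_dim=0, dim=32, dim_mults=(1, 2, 4, 8))

x = torch.randn(8, 32, 14)                       # [batch x horizon x transition]
timesteps = torch.randint(0, 1000, (8,)).float() # diffusion timesteps fed to SinusoidalPosEmb
out = model(x, cond=None, time=timesteps)
print(out.shape)                                 # expected: torch.Size([8, 32, 14])
```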
@@ -1,19 +0,0 @@
# Pipelines

- Pipelines are a collection of end-to-end diffusion systems that can be used out-of-the-box
- Pipelines should stay as close as possible to their original implementation
- Pipelines can include components of other libraries, such as text-encoders.

## API

TODO(Patrick, Anton, Suraj)

## Examples

- DDPM for unconditional image generation in [pipeline_ddpm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_ddpm.py).
- DDIM for unconditional image generation in [pipeline_ddim](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_ddim.py).
- PNDM for unconditional image generation in [pipeline_pndm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py).
- Latent diffusion for text to image generation / conditional image generation in [pipeline_ddpm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_bddm.py).
- Glide for text to image generation / conditional image generation in [pipeline_ddpm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_bddm.py).
- BDDM for spectrogram-to-sound vocoding in [pipeline_ddpm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_bddm.py).
- Grad-TTS for text to audio generation / conditional audio generation in [pipeline_ddpm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_bddm.py).
src/diffusers/pipelines/__init__.py

@@ -1,16 +1,11 @@
from ..utils import is_transformers_available
from .pipeline_bddm import BDDM
from .pipeline_ddim import DDIM
from .pipeline_ddpm import DDPM
from .pipeline_grad_tts import GradTTS


try:
    from .pipeline_glide import GLIDE
except (NameError, ImportError):

    class GLIDE:
        pass


from .pipeline_latent_diffusion import LatentDiffusion
from .pipeline_pndm import PNDM


if is_transformers_available():
    from .pipeline_glide import GLIDE
    from .pipeline_grad_tts import GradTTS
    from .pipeline_latent_diffusion import LatentDiffusion
@@ -6,11 +6,8 @@ from shutil import copyfile

import torch

from transformers import PreTrainedTokenizer

try:
    from transformers import PreTrainedTokenizer
except:
    print("transformers is not installed")

try:
    from unidecode import unidecode
@@ -237,7 +234,12 @@ def english_cleaners(text):
    return text


_inflect = inflect.engine()
try:
    _inflect = inflect.engine()
except:
    print("inflect is not installed")
    _inflect = None

_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
@@ -1,28 +0,0 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Denoising Diffusion Implicit Models (DDIM)

## Overview

DDIM was proposed in [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) by *Jiaming Song, Chenlin Meng, Stefano Ermon*.

The abstract from the paper is the following:

*Denoising diffusion probabilistic models (DDPMs) have achieved high quality image generation without adversarial training, yet they require simulating a Markov chain for many steps to produce a sample. To accelerate sampling, we present denoising diffusion implicit models (DDIMs), a more efficient class of iterative implicit probabilistic models with the same training procedure as DDPMs. In DDPMs, the generative process is defined as the reverse of a Markovian diffusion process. We construct a class of non-Markovian diffusion processes that lead to the same training objective, but whose reverse process can be much faster to sample from. We empirically demonstrate that DDIMs can produce high quality samples 10× to 50× faster in terms of wall-clock time compared to DDPMs, allow us to trade off computation for sample quality, and can perform semantically meaningful image interpolation directly in the latent space.*

Tips:

- ...
- ...

This model was contributed by [???](https://huggingface.co/???). The original code can be found [here](https://github.com/hojonathanho/diffusion).
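For reference, the DDIM update rule from the paper (in the paper's notation, where $\alpha_t$ denotes the cumulative product that DDPM calls $\bar\alpha_t$); this is background, not part of the diff:

$$x_{t-1} = \sqrt{\alpha_{t-1}}\left(\frac{x_t - \sqrt{1-\alpha_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\alpha_t}}\right) + \sqrt{1-\alpha_{t-1}-\sigma_t^2}\,\epsilon_\theta(x_t, t) + \sigma_t\,\epsilon_t,
\qquad
\sigma_t = \eta\,\sqrt{\frac{1-\alpha_{t-1}}{1-\alpha_t}}\sqrt{1-\frac{\alpha_t}{\alpha_{t-1}}}$$

Setting $\eta = 0$ gives deterministic DDIM sampling (matching `eta = 0.0` in the README snippet above), while $\eta = 1$ recovers DDPM-style stochastic sampling.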
@@ -1 +0,0 @@
from .pipeline_ddim import DDIM
@@ -1,26 +0,0 @@
#!/usr/bin/env python3
import os
import pathlib

import numpy as np

import PIL.Image
from modeling_ddim import DDIM


model_ids = ["ddim-celeba-hq", "ddim-lsun-church", "ddim-lsun-bedroom"]

for model_id in model_ids:
    path = os.path.join("/home/patrick/images/hf", model_id)
    pathlib.Path(path).mkdir(parents=True, exist_ok=True)

    ddpm = DDIM.from_pretrained("fusing/" + model_id)
    image = ddpm(batch_size=4)

    image_processed = image.cpu().permute(0, 2, 3, 1)
    image_processed = (image_processed + 1.0) * 127.5
    image_processed = image_processed.numpy().astype(np.uint8)

    for i in range(image_processed.shape[0]):
        image_pil = PIL.Image.fromarray(image_processed[i])
        image_pil.save(os.path.join(path, f"image_{i}.png"))
@@ -1,17 +0,0 @@
#!/usr/bin/env python3
import torch

from diffusers import DDPMScheduler, UNetModel


model = UNetModel(dim=64, dim_mults=(1, 2, 4, 8))

diffusion = DDPMScheduler(model, image_size=128, timesteps=1000, loss_type="l1")  # number of steps # L1 or L2

training_images = torch.randn(8, 3, 128, 128)  # your images need to be normalized from a range of -1 to +1
loss = diffusion(training_images)
loss.backward()
# after a lot of training

sampled_images = diffusion.sample(batch_size=4)
sampled_images.shape  # (4, 3, 128, 128)
@@ -1,25 +0,0 @@
#!/usr/bin/env python3
# !pip install diffusers
import numpy as np

import PIL.Image
from modeling_ddim import DDIM


model_id = "fusing/ddpm-cifar10"
model_id = "fusing/ddpm-lsun-bedroom"

# load model and scheduler
ddpm = DDIM.from_pretrained(model_id)

# run pipeline in inference (sample random noise and denoise)
image = ddpm()

# process image to PIL
image_processed = image.cpu().permute(0, 2, 3, 1)
image_processed = (image_processed + 1.0) * 127.5
image_processed = image_processed.numpy().astype(np.uint8)
image_pil = PIL.Image.fromarray(image_processed[0])

# save image
image_pil.save("/home/patrick/images/show.png")
@@ -1,30 +0,0 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Denoising Diffusion Probabilistic Models (DDPM)

## Overview

DDPM was proposed in [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) by *Jonathan Ho, Ajay Jain, Pieter Abbeel*.

The abstract from the paper is the following:

*We present high quality image synthesis results using diffusion probabilistic models, a class of latent variable models inspired by considerations from nonequilibrium thermodynamics. Our best results are obtained by training on a weighted variational bound designed according to a novel connection between diffusion probabilistic models and denoising score matching with Langevin dynamics, and our models naturally admit a progressive lossy decompression scheme that can be interpreted as a generalization of autoregressive decoding. On the unconditional CIFAR10 dataset, we obtain an Inception score of 9.46 and a state-of-the-art FID score of 3.17. On 256x256 LSUN, we obtain sample quality similar to ProgressiveGAN. Our implementation is available at this https URL*

Tips:

- ...
- ...

This model was contributed by [???](https://huggingface.co/???). The original code can be found [here](https://github.com/hojonathanho/diffusion).
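For reference, the simplified training objective from the DDPM paper (background, not part of the diff):

$$L_{\text{simple}}(\theta) = \mathbb{E}_{t,\,x_0,\,\epsilon}\left[\left\lVert \epsilon - \epsilon_\theta\!\left(\sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon,\; t\right)\right\rVert^2\right]$$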
@@ -1,37 +0,0 @@
#!/usr/bin/env python3
import os
import pathlib

import numpy as np

import PIL.Image
from modeling_ddpm import DDPM


model_ids = [
    "ddpm-lsun-cat",
    "ddpm-lsun-cat-ema",
    "ddpm-lsun-church-ema",
    "ddpm-lsun-church",
    "ddpm-lsun-bedroom",
    "ddpm-lsun-bedroom-ema",
    "ddpm-cifar10-ema",
    "ddpm-cifar10",
    "ddpm-celeba-hq",
    "ddpm-celeba-hq-ema",
]

for model_id in model_ids:
    path = os.path.join("/home/patrick/images/hf", model_id)
    pathlib.Path(path).mkdir(parents=True, exist_ok=True)

    ddpm = DDPM.from_pretrained("fusing/" + model_id)
    image = ddpm(batch_size=4)

    image_processed = image.cpu().permute(0, 2, 3, 1)
    image_processed = (image_processed + 1.0) * 127.5
    image_processed = image_processed.numpy().astype(np.uint8)

    for i in range(image_processed.shape[0]):
        image_pil = PIL.Image.fromarray(image_processed[i])
        image_pil.save(os.path.join(path, f"image_{i}.png"))
@@ -1,17 +0,0 @@
#!/usr/bin/env python3
import torch

from diffusers import DDPMScheduler, UNetModel


model = UNetModel(dim=64, dim_mults=(1, 2, 4, 8))

diffusion = DDPMScheduler(model, image_size=128, timesteps=1000, loss_type="l1")  # number of steps # L1 or L2

training_images = torch.randn(8, 3, 128, 128)  # your images need to be normalized from a range of -1 to +1
loss = diffusion(training_images)
loss.backward()
# after a lot of training

sampled_images = diffusion.sample(batch_size=4)
sampled_images.shape  # (4, 3, 128, 128)
@@ -1,4 +0,0 @@
# References

[GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models](https://arxiv.org/pdf/2112.10741.pdf)
[Diffusion Models Beat GANs on Image Synthesis](https://arxiv.org/pdf/2105.05233.pdf)
@@ -1,111 +0,0 @@
import torch
from torch import nn

from diffusers import ClassifierFreeGuidanceScheduler, GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from modeling_glide import GLIDE, CLIPTextModel
from transformers import CLIPTextConfig, GPT2Tokenizer


# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt
state_dict = torch.load("base.pt", map_location="cpu")
state_dict = {k: nn.Parameter(v) for k, v in state_dict.items()}

### Convert the text encoder

config = CLIPTextConfig(
    vocab_size=50257,
    max_position_embeddings=128,
    hidden_size=512,
    intermediate_size=2048,
    num_hidden_layers=16,
    num_attention_heads=8,
    use_padding_embeddings=True,
)
model = CLIPTextModel(config).eval()
tokenizer = GPT2Tokenizer(
    "./glide-base/tokenizer/vocab.json", "./glide-base/tokenizer/merges.txt", pad_token="<|endoftext|>"
)

hf_encoder = model.text_model

hf_encoder.embeddings.token_embedding.weight = state_dict["token_embedding.weight"]
hf_encoder.embeddings.position_embedding.weight.data = state_dict["positional_embedding"]
hf_encoder.embeddings.padding_embedding.weight.data = state_dict["padding_embedding"]

hf_encoder.final_layer_norm.weight = state_dict["final_ln.weight"]
hf_encoder.final_layer_norm.bias = state_dict["final_ln.bias"]

for layer_idx in range(config.num_hidden_layers):
    hf_layer = hf_encoder.encoder.layers[layer_idx]
    hf_layer.self_attn.qkv_proj.weight = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.weight"]
    hf_layer.self_attn.qkv_proj.bias = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.bias"]

    hf_layer.self_attn.out_proj.weight = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_proj.weight"]
    hf_layer.self_attn.out_proj.bias = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_proj.bias"]

    hf_layer.layer_norm1.weight = state_dict[f"transformer.resblocks.{layer_idx}.ln_1.weight"]
    hf_layer.layer_norm1.bias = state_dict[f"transformer.resblocks.{layer_idx}.ln_1.bias"]
    hf_layer.layer_norm2.weight = state_dict[f"transformer.resblocks.{layer_idx}.ln_2.weight"]
    hf_layer.layer_norm2.bias = state_dict[f"transformer.resblocks.{layer_idx}.ln_2.bias"]

    hf_layer.mlp.fc1.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_fc.weight"]
    hf_layer.mlp.fc1.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_fc.bias"]
    hf_layer.mlp.fc2.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.weight"]
    hf_layer.mlp.fc2.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.bias"]

### Convert the Text-to-Image UNet

text2im_model = GLIDETextToImageUNetModel(
    in_channels=3,
    model_channels=192,
    out_channels=6,
    num_res_blocks=3,
    attention_resolutions=(2, 4, 8),
    dropout=0.1,
    channel_mult=(1, 2, 3, 4),
    num_heads=1,
    num_head_channels=64,
    num_heads_upsample=1,
    use_scale_shift_norm=True,
    resblock_updown=True,
    transformer_dim=512,
)

text2im_model.load_state_dict(state_dict, strict=False)

text_scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2")

### Convert the Super-Resolution UNet

# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample.pt
ups_state_dict = torch.load("upsample.pt", map_location="cpu")

superres_model = GLIDESuperResUNetModel(
    in_channels=6,
    model_channels=192,
    out_channels=6,
    num_res_blocks=2,
    attention_resolutions=(8, 16, 32),
    dropout=0.1,
    channel_mult=(1, 1, 2, 2, 4, 4),
    num_heads=1,
    num_head_channels=64,
    num_heads_upsample=1,
    use_scale_shift_norm=True,
    resblock_updown=True,
)

superres_model.load_state_dict(ups_state_dict, strict=False)

upscale_scheduler = DDIMScheduler(timesteps=1000, beta_schedule="linear")

glide = GLIDE(
    text_unet=text2im_model,
    text_noise_scheduler=text_scheduler,
    text_encoder=model,
    tokenizer=tokenizer,
    upscale_unet=superres_model,
    upscale_noise_scheduler=upscale_scheduler,
)

glide.save_pretrained("./glide-base")
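As a brief, hedged usage note: once saved this way, the pipeline can presumably be reloaded with the `from_pretrained` counterpart used throughout this commit (illustrative only; requires `transformers` to be installed, see the import guard added to `src/diffusers/__init__.py` above):

```python
from diffusers import GLIDE

# Reload the converted checkpoint saved by the script above.
glide = GLIDE.from_pretrained("./glide-base")
```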
@@ -1,923 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch CLIP model."""
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
import tqdm
|
||||
from diffusers import (
|
||||
ClassifierFreeGuidanceScheduler,
|
||||
DDIMScheduler,
|
||||
DiffusionPipeline,
|
||||
GLIDESuperResUNetModel,
|
||||
GLIDETextToImageUNetModel,
|
||||
)
|
||||
from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig, GPT2Tokenizer
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.utils import (
|
||||
ModelOutput,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
|
||||
|
||||
#####################
|
||||
# START OF THE CLIP MODEL COPY-PASTE (with a modified attention module)
|
||||
#####################
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "fusing/glide-base"
|
||||
|
||||
CLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"fusing/glide-base",
|
||||
# See all CLIP models at https://huggingface.co/models?filter=clip
|
||||
]
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart._expand_mask
|
||||
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
"""
|
||||
bsz, src_len = mask.size()
|
||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||
|
||||
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||
|
||||
|
||||
# contrastive loss function, adapted from
|
||||
# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
|
||||
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
|
||||
return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
|
||||
|
||||
|
||||
def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
|
||||
caption_loss = contrastive_loss(similarity)
|
||||
image_loss = contrastive_loss(similarity.T)
|
||||
return (caption_loss + image_loss) / 2.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class CLIPOutput(ModelOutput):
|
||||
"""
|
||||
Args:
|
||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
|
||||
Contrastive loss for image-text similarity.
|
||||
logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
|
||||
The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
|
||||
similarity scores.
|
||||
logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
|
||||
The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
|
||||
similarity scores.
|
||||
text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
|
||||
The text embeddings obtained by applying the projection layer to the pooled output of [`CLIPTextModel`].
|
||||
image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
|
||||
The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPVisionModel`].
|
||||
text_model_output(`BaseModelOutputWithPooling`):
|
||||
The output of the [`CLIPTextModel`].
|
||||
vision_model_output(`BaseModelOutputWithPooling`):
|
||||
The output of the [`CLIPVisionModel`].
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
logits_per_image: torch.FloatTensor = None
|
||||
logits_per_text: torch.FloatTensor = None
|
||||
text_embeds: torch.FloatTensor = None
|
||||
image_embeds: torch.FloatTensor = None
|
||||
text_model_output: BaseModelOutputWithPooling = None
|
||||
vision_model_output: BaseModelOutputWithPooling = None
|
||||
|
||||
def to_tuple(self) -> Tuple[Any]:
|
||||
return tuple(
|
||||
self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
|
||||
for k in self.keys()
|
||||
)
|
||||
|
||||
|
||||
class CLIPVisionEmbeddings(nn.Module):
|
||||
def __init__(self, config: CLIPVisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.image_size = config.image_size
|
||||
self.patch_size = config.patch_size
|
||||
|
||||
self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
|
||||
|
||||
self.patch_embedding = nn.Conv2d(
|
||||
in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False
|
||||
)
|
||||
|
||||
self.num_patches = (self.image_size // self.patch_size) ** 2
|
||||
self.num_positions = self.num_patches + 1
|
||||
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
||||
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
|
||||
|
||||
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
||||
batch_size = pixel_values.shape[0]
|
||||
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
|
||||
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
||||
|
||||
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
|
||||
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
||||
embeddings = embeddings + self.position_embedding(self.position_ids)
|
||||
return embeddings
|
||||
|
||||
|
||||
class CLIPTextEmbeddings(nn.Module):
|
||||
def __init__(self, config: CLIPTextConfig):
|
||||
super().__init__()
|
||||
embed_dim = config.hidden_size
|
||||
|
||||
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
|
||||
self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
|
||||
self.use_padding_embeddings = config.use_padding_embeddings
|
||||
if self.use_padding_embeddings:
|
||||
self.padding_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
|
||||
|
||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = self.position_ids[:, :seq_length]
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.token_embedding(input_ids)
|
||||
|
||||
position_embeddings = self.position_embedding(position_ids)
|
||||
embeddings = inputs_embeds + position_embeddings
|
||||
|
||||
if self.use_padding_embeddings and attention_mask is not None:
|
||||
padding_embeddings = self.padding_embedding(position_ids)
|
||||
embeddings = torch.where(attention_mask.bool().unsqueeze(-1), embeddings, padding_embeddings)
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
class CLIPAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
if self.head_dim * self.num_heads != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
||||
f" {self.num_heads})."
|
||||
)
|
||||
self.scale = 1 / math.sqrt(math.sqrt(self.head_dim))
|
||||
|
||||
self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3)
|
||||
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
causal_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
bsz, tgt_len, embed_dim = hidden_states.size()
|
||||
|
||||
qkv_states = self.qkv_proj(hidden_states)
|
||||
qkv_states = qkv_states.view(bsz, tgt_len, self.num_heads, -1)
|
||||
query_states, key_states, value_states = torch.split(qkv_states, self.head_dim, dim=-1)
|
||||
|
||||
attn_weights = torch.einsum("bthc,bshc->bhts", query_states * self.scale, key_states * self.scale)
|
||||
|
||||
wdtype = attn_weights.dtype
|
||||
attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1).type(wdtype)
|
||||
|
||||
attn_output = torch.einsum("bhts,bshc->bthc", attn_weights, value_states)
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, -1)
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class CLIPMLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.activation_fn = ACT2FN[config.hidden_act]
|
||||
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
hidden_states = self.activation_fn(hidden_states)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CLIPEncoderLayer(nn.Module):
|
||||
def __init__(self, config: CLIPConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
self.self_attn = CLIPAttention(config)
|
||||
self.layer_norm1 = nn.LayerNorm(self.embed_dim)
|
||||
self.mlp = CLIPMLP(config)
|
||||
self.layer_norm2 = nn.LayerNorm(self.embed_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
causal_attention_mask: torch.Tensor,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.FloatTensor]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
`(config.encoder_attention_heads,)`.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
"""
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.layer_norm1(hidden_states)
|
||||
hidden_states, attn_weights = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
causal_attention_mask=causal_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm2(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (attn_weights,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class CLIPPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
"""
|
||||
|
||||
config_class = CLIPConfig
|
||||
base_model_prefix = "clip"
|
||||
supports_gradient_checkpointing = True
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
factor = self.config.initializer_factor
|
||||
if isinstance(module, CLIPTextEmbeddings):
|
||||
module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
|
||||
module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
|
||||
if hasattr(module, "padding_embedding"):
|
||||
module.padding_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
|
||||
elif isinstance(module, CLIPVisionEmbeddings):
|
||||
factor = self.config.initializer_factor
|
||||
nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
|
||||
nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
|
||||
nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
|
||||
elif isinstance(module, CLIPAttention):
|
||||
factor = self.config.initializer_factor
|
||||
in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
|
||||
out_proj_std = (module.embed_dim**-0.5) * factor
|
||||
nn.init.normal_(module.qkv_proj.weight, std=in_proj_std)
|
||||
nn.init.normal_(module.out_proj.weight, std=out_proj_std)
|
||||
elif isinstance(module, CLIPMLP):
|
||||
factor = self.config.initializer_factor
|
||||
in_proj_std = (
|
||||
(module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
|
||||
)
|
||||
fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
|
||||
nn.init.normal_(module.fc1.weight, std=fc_std)
|
||||
nn.init.normal_(module.fc2.weight, std=in_proj_std)
|
||||
elif isinstance(module, CLIPModel):
|
||||
nn.init.normal_(
|
||||
module.text_projection.weight,
|
||||
std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
|
||||
)
|
||||
nn.init.normal_(
|
||||
module.visual_projection.weight,
|
||||
std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
|
||||
)
|
||||
|
||||
if isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, CLIPEncoder):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
|
||||
CLIP_START_DOCSTRING = r"""
|
||||
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
|
||||
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
||||
behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
CLIP_TEXT_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||
it.
|
||||
|
||||
Indices can be obtained using [`CLIPTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||
config.max_position_embeddings - 1]`.
|
||||
|
||||
[What are position IDs?](../glossary#position-ids)
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
|
||||
CLIP_VISION_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
||||
[`CLIPFeatureExtractor`]. See [`CLIPFeatureExtractor.__call__`] for details.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
|
||||
CLIP_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||
it.
|
||||
|
||||
Indices can be obtained using [`CLIPTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||
config.max_position_embeddings - 1]`.
|
||||
|
||||
[What are position IDs?](../glossary#position-ids)
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
||||
[`CLIPFeatureExtractor`]. See [`CLIPFeatureExtractor.__call__`] for details.
|
||||
return_loss (`bool`, *optional*):
|
||||
Whether or not to return the contrastive loss.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
class CLIPEncoder(nn.Module):
|
||||
"""
|
||||
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
|
||||
[`CLIPEncoderLayer`].
|
||||
|
||||
Args:
|
||||
config: CLIPConfig
|
||||
"""
|
||||
|
||||
def __init__(self, config: CLIPConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
causal_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutput]:
|
||||
r"""
|
||||
Args:
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Causal mask for the text model. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
||||
for more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
encoder_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
for idx, encoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
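# Gradient checkpointing trades compute for memory: activations inside each
# encoder layer are recomputed during the backward pass instead of being stored,
# at the cost of an extra forward pass per checkpointed layer.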
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs, output_attentions)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(encoder_layer),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
causal_attention_mask,
|
||||
)
|
||||
else:
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
causal_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_attentions = all_attentions + (layer_outputs[1],)
|
||||
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
||||
)
|
||||
|
||||
|
||||
class CLIPTextTransformer(nn.Module):
|
||||
def __init__(self, config: CLIPTextConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
embed_dim = config.hidden_size
|
||||
self.embeddings = CLIPTextEmbeddings(config)
|
||||
self.encoder = CLIPEncoder(config)
|
||||
self.final_layer_norm = nn.LayerNorm(embed_dim)
|
||||
|
||||
@add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if input_ids is None:
|
||||
raise ValueError("You have to specify either input_ids")
|
||||
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
|
||||
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask)
|
||||
|
||||
bsz, seq_len = input_shape
|
||||
# CLIP's text model uses causal mask, prepare it here.
|
||||
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
|
||||
causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len).to(hidden_states.device)
|
||||
|
||||
# expand attention_mask
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
last_hidden_state = encoder_outputs[0]
|
||||
last_hidden_state = self.final_layer_norm(last_hidden_state)
|
||||
|
||||
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
|
||||
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
||||
pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0]), input_ids.argmax(dim=-1)]
|
||||
|
||||
if not return_dict:
|
||||
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=last_hidden_state,
|
||||
pooler_output=pooled_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
def _build_causal_attention_mask(self, bsz, seq_len):
|
||||
# lazily create causal attention mask, with full attention between the vision tokens
|
||||
# pytorch uses additive attention mask; fill with -inf
|
||||
mask = torch.empty(bsz, seq_len, seq_len)
|
||||
mask.fill_(torch.tensor(float("-inf")))
|
||||
mask.triu_(1)  # zero out the diagonal and below, keeping -inf strictly above the diagonal
|
||||
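# e.g. for seq_len = 3 the additive mask is (0 = attend, -inf = blocked):
#   [[0., -inf, -inf],
#    [0.,   0., -inf],
#    [0.,   0.,   0.]]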
mask = mask.unsqueeze(1) # expand mask
|
||||
return mask
|
||||
|
||||
|
||||
class CLIPTextModel(CLIPPreTrainedModel):
|
||||
config_class = CLIPTextConfig
|
||||
|
||||
def __init__(self, config: CLIPTextConfig):
|
||||
super().__init__(config)
|
||||
self.text_model = CLIPTextTransformer(config)
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self) -> nn.Module:
|
||||
return self.text_model.embeddings.token_embedding
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.text_model.embeddings.token_embedding = value
|
||||
|
||||
@add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from transformers import CLIPTokenizer, CLIPTextModel
|
||||
|
||||
>>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||
>>> tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
|
||||
|
||||
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
|
||||
|
||||
>>> outputs = model(**inputs)
|
||||
>>> last_hidden_state = outputs.last_hidden_state
|
||||
>>> pooled_output = outputs.pooler_output # pooled (EOS token) states
|
||||
```"""
|
||||
return self.text_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
|
||||
#####################
|
||||
# END OF THE CLIP MODEL COPY-PASTE
|
||||
#####################
|
||||
|
||||
|
||||
def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
||||
"""
|
||||
Extract values from a 1-D numpy array for a batch of indices.
|
||||
|
||||
:param arr: the 1-D numpy array.
|
||||
:param timesteps: a tensor of indices into the array to extract.
|
||||
:param broadcast_shape: a larger shape of K dimensions with the batch
|
||||
dimension equal to the length of timesteps.
|
||||
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
|
||||
"""
|
||||
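# Illustrative sketch: with timesteps = tensor([3, 7]) and broadcast_shape = (2, 3, 64, 64),
# `res` is first gathered to shape (2,), unsqueezed to (2, 1, 1, 1), and finally
# broadcast to (2, 3, 64, 64) by the addition of zeros below.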
res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
|
||||
while len(res.shape) < len(broadcast_shape):
|
||||
res = res[..., None]
|
||||
return res + torch.zeros(broadcast_shape, device=timesteps.device)
|
||||
|
||||
|
||||
class GLIDE(DiffusionPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
text_unet: GLIDETextToImageUNetModel,
|
||||
text_noise_scheduler: ClassifierFreeGuidanceScheduler,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: GPT2Tokenizer,
|
||||
upscale_unet: GLIDESuperResUNetModel,
|
||||
upscale_noise_scheduler: DDIMScheduler,
|
||||
):
|
||||
super().__init__()
|
||||
self.register_modules(
|
||||
text_unet=text_unet,
|
||||
text_noise_scheduler=text_noise_scheduler,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
upscale_unet=upscale_unet,
|
||||
upscale_noise_scheduler=upscale_noise_scheduler,
|
||||
)
|
||||
|
||||
def q_posterior_mean_variance(self, scheduler, x_start, x_t, t):
|
||||
"""
|
||||
Compute the mean and variance of the diffusion posterior:
|
||||
|
||||
q(x_{t-1} | x_t, x_0)
|
||||
|
||||
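Concretely (a sketch of the standard DDPM posterior, Ho et al. 2020, Eq. 6-7):

    posterior_mean     = posterior_mean_coef1[t] * x_start + posterior_mean_coef2[t] * x_t
    posterior_variance = betas[t] * (1 - alphas_cumprod[t - 1]) / (1 - alphas_cumprod[t])

where the coefficient arrays are assumed to be precomputed on the scheduler.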
"""
|
||||
assert x_start.shape == x_t.shape
|
||||
posterior_mean = (
|
||||
_extract_into_tensor(scheduler.posterior_mean_coef1, t, x_t.shape) * x_start
|
||||
+ _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
|
||||
)
|
||||
posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape)
|
||||
posterior_log_variance_clipped = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x_t.shape)
|
||||
assert (
|
||||
posterior_mean.shape[0]
|
||||
== posterior_variance.shape[0]
|
||||
== posterior_log_variance_clipped.shape[0]
|
||||
== x_start.shape[0]
|
||||
)
|
||||
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
||||
|
||||
def p_mean_variance(self, model, scheduler, x, t, transformer_out=None, low_res=None, clip_denoised=True):
|
||||
"""
|
||||
Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
|
||||
the initial x, x_0.
|
||||
|
||||
:param model: the model, which takes a signal and a batch of timesteps
|
||||
as input.
|
||||
:param x: the [N x C x ...] tensor at time t.
|
||||
:param t: a 1-D Tensor of timesteps.
|
||||
:param clip_denoised: if True, clip the denoised signal into [-1, 1].
|
||||
:param transformer_out: text-encoder hidden states for the text2image model (or None for the super-res model).
:param low_res: the low-resolution conditioning image for the super-res model (or None).
|
||||
:return: a tuple (mean, variance, log_variance, pred_xstart), where:
|
||||
- 'mean': the model mean output.
|
||||
- 'variance': the model variance output.
|
||||
- 'log_variance': the log of 'variance'.
|
||||
- 'pred_xstart': the prediction for x_0.
|
||||
"""
|
||||
|
||||
B, C = x.shape[:2]
|
||||
assert t.shape == (B,)
|
||||
if transformer_out is None:
|
||||
# super-res model
|
||||
model_output = model(x, t, low_res)
|
||||
else:
|
||||
# text2image model
|
||||
model_output = model(x, t, transformer_out)
|
||||
|
||||
assert model_output.shape == (B, C * 2, *x.shape[2:])
|
||||
model_output, model_var_values = torch.split(model_output, C, dim=1)
|
||||
min_log = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x.shape)
|
||||
max_log = _extract_into_tensor(np.log(scheduler.betas), t, x.shape)
|
||||
# The model_var_values is [-1, 1] for [min_var, max_var].
|
||||
frac = (model_var_values + 1) / 2
|
||||
model_log_variance = frac * max_log + (1 - frac) * min_log
|
||||
model_variance = torch.exp(model_log_variance)
|
||||
|
||||
pred_xstart = self._predict_xstart_from_eps(scheduler, x_t=x, t=t, eps=model_output)
|
||||
if clip_denoised:
|
||||
pred_xstart = pred_xstart.clamp(-1, 1)
|
||||
model_mean, _, _ = self.q_posterior_mean_variance(scheduler, x_start=pred_xstart, x_t=x, t=t)
|
||||
|
||||
assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
|
||||
return model_mean, model_variance, model_log_variance, pred_xstart
|
||||
|
||||
def _predict_xstart_from_eps(self, scheduler, x_t, t, eps):
|
||||
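# Inverts the forward process x_t = sqrt(alphas_cumprod_t) * x_0 + sqrt(1 - alphas_cumprod_t) * eps,
# i.e. x_0 = sqrt(1 / alphas_cumprod_t) * x_t - sqrt(1 / alphas_cumprod_t - 1) * eps.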
assert x_t.shape == eps.shape
|
||||
return (
|
||||
_extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
|
||||
- _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
|
||||
)
|
||||
|
||||
def _predict_eps_from_xstart(self, scheduler, x_t, t, pred_xstart):
|
||||
return (
|
||||
_extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
|
||||
) / _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, prompt, generator=None, torch_device=None):
|
||||
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
self.text_unet.to(torch_device)
|
||||
self.text_encoder.to(torch_device)
|
||||
self.upscale_unet.to(torch_device)
|
||||
|
||||
# Create a classifier-free guidance sampling function
|
||||
guidance_scale = 3.0
|
||||
|
||||
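# Classifier-free guidance combines the conditional and unconditional noise predictions:
#   eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)
# The two halves of the batch share the same latent but are conditioned on the prompt and
# on the empty prompt respectively, so one forward pass yields both predictions.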
def text_model_fn(x_t, ts, transformer_out, **kwargs):
|
||||
half = x_t[: len(x_t) // 2]
|
||||
combined = torch.cat([half, half], dim=0)
|
||||
model_out = self.text_unet(combined, ts, transformer_out, **kwargs)
|
||||
eps, rest = model_out[:, :3], model_out[:, 3:]
|
||||
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
|
||||
half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
|
||||
eps = torch.cat([half_eps, half_eps], dim=0)
|
||||
return torch.cat([eps, rest], dim=1)
|
||||
|
||||
# 1. Sample gaussian noise
|
||||
batch_size = 2  # the second sample is conditioned on the empty prompt for classifier-free guidance
|
||||
image = self.text_noise_scheduler.sample_noise(
|
||||
(batch_size, self.text_unet.in_channels, 64, 64), device=torch_device, generator=generator
|
||||
)
|
||||
|
||||
# 2. Encode tokens
|
||||
# an empty ("") prompt is tokenized as well; it is the unconditional input for classifier-free guidance
|
||||
inputs = self.tokenizer([prompt, ""], padding="max_length", max_length=128, return_tensors="pt")
|
||||
input_ids = inputs["input_ids"].to(torch_device)
|
||||
attention_mask = inputs["attention_mask"].to(torch_device)
|
||||
transformer_out = self.text_encoder(input_ids, attention_mask).last_hidden_state
|
||||
|
||||
# 3. Run the text2image generation step
|
||||
num_timesteps = len(self.text_noise_scheduler)
|
||||
for i in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
|
||||
t = torch.tensor([i] * image.shape[0], device=torch_device)
|
||||
mean, variance, log_variance, pred_xstart = self.p_mean_variance(
|
||||
text_model_fn, self.text_noise_scheduler, image, t, transformer_out=transformer_out
|
||||
)
|
||||
noise = self.text_noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator)
|
||||
nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1))) # no noise when t == 0
|
||||
image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise
|
||||
|
||||
# 4. Run the upscaling step
|
||||
batch_size = 1
|
||||
image = image[:1]
|
||||
low_res = ((image + 1) * 127.5).round() / 127.5 - 1
|
||||
eta = 0.0
|
||||
|
||||
# Tune this parameter to control the sharpness of 256x256 images.
|
||||
# A value of 1.0 is sharper, but sometimes results in grainy artifacts.
|
||||
upsample_temp = 0.997
|
||||
|
||||
image = (
|
||||
self.upscale_noise_scheduler.sample_noise(
|
||||
(batch_size, 3, 256, 256), device=torch_device, generator=generator
|
||||
)
|
||||
* upsample_temp
|
||||
)
|
||||
|
||||
num_timesteps = len(self.upscale_noise_scheduler)
|
||||
for t in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
|
||||
# i) define coefficients for time step t
|
||||
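# A sketch of what the coefficients below do: `clipped_image_coeff` and `clipped_noise_coeff`
# recover the predicted x_0 from the noise residual, while `clipped_coeff` and `image_coeff`
# weight x_0 and x_t in the posterior mean of q(x_{t-1} | x_t, x_0).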
clipped_image_coeff = 1 / torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t))
|
||||
clipped_noise_coeff = torch.sqrt(1 / self.upscale_noise_scheduler.get_alpha_prod(t) - 1)
|
||||
image_coeff = (
|
||||
(1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1))
|
||||
* torch.sqrt(self.upscale_noise_scheduler.get_alpha(t))
|
||||
/ (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
|
||||
)
|
||||
clipped_coeff = (
|
||||
torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1))
|
||||
* self.upscale_noise_scheduler.get_beta(t)
|
||||
/ (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
|
||||
)
|
||||
|
||||
# ii) predict noise residual
|
||||
time_input = torch.tensor([t] * image.shape[0], device=torch_device)
|
||||
model_output = self.upscale_unet(image, time_input, low_res)
|
||||
noise_residual, pred_variance = torch.split(model_output, 3, dim=1)
|
||||
|
||||
# iii) compute predicted image from residual
|
||||
# See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
|
||||
pred_mean = clipped_image_coeff * image - clipped_noise_coeff * noise_residual
|
||||
pred_mean = torch.clamp(pred_mean, -1, 1)
|
||||
prev_image = clipped_coeff * pred_mean + image_coeff * image
|
||||
|
||||
# iv) sample variance
|
||||
prev_variance = self.upscale_noise_scheduler.sample_variance(
|
||||
t, prev_image.shape, device=torch_device, generator=generator
|
||||
)
|
||||
|
||||
# v) sample x_{t-1} ~ N(prev_image, prev_variance)
|
||||
sampled_prev_image = prev_image + prev_variance
|
||||
image = sampled_prev_image
|
||||
|
||||
image = image.permute(0, 2, 3, 1)
|
||||
|
||||
return image
|
||||
@@ -1,24 +0,0 @@
|
||||
import torch
|
||||
|
||||
import PIL.Image
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
|
||||
generator = torch.Generator()
|
||||
generator = generator.manual_seed(0)
|
||||
|
||||
model_id = "fusing/glide-base"
|
||||
|
||||
# load model and scheduler
|
||||
pipeline = DiffusionPipeline.from_pretrained(model_id)
|
||||
|
||||
# run inference (text-conditioned denoising + upscaling)
|
||||
img = pipeline("a crayon drawing of a corgi", generator)
|
||||
|
||||
# process image to PIL
|
||||
img = img.squeeze(0)
|
||||
img = ((img + 1) * 127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
|
||||
image_pil = PIL.Image.fromarray(img)
|
||||
|
||||
# save image
|
||||
image_pil.save("test.png")
|
||||
@@ -1,146 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" LDMBERT model configuration"""
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
LDMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"ldm-bert": "https://huggingface.co/ldm-bert/resolve/main/config.json",
|
||||
}
|
||||
|
||||
|
||||
class LDMBertConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`LDMBertModel`]. It is used to instantiate a
|
||||
LDMBERT model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of the LDMBERT
|
||||
[facebook/ldmbert-large](https://huggingface.co/facebook/ldmbert-large) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 50265):
|
||||
Vocabulary size of the LDMBERT model. Defines the number of different tokens that can be represented by the
|
||||
`input_ids` passed when calling [`LDMBertModel`].
|
||||
d_model (`int`, *optional*, defaults to 1024):
|
||||
Dimensionality of the layers and the pooler layer.
|
||||
encoder_layers (`int`, *optional*, defaults to 12):
|
||||
Number of encoder layers.
|
||||
decoder_layers (`int`, *optional*, defaults to 12):
|
||||
Number of decoder layers.
|
||||
encoder_attention_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
decoder_attention_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
decoder_ffn_dim (`int`, *optional*, defaults to 4096):
|
||||
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
|
||||
encoder_ffn_dim (`int`, *optional*, defaults to 4096):
|
||||
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
|
||||
activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
||||
dropout (`float`, *optional*, defaults to 0.1):
|
||||
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
activation_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for activations inside the fully connected layer.
|
||||
classifier_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for classifier.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 1024):
|
||||
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
||||
just in case (e.g., 512 or 1024 or 2048).
|
||||
init_std (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
encoder_layerdrop: (`float`, *optional*, defaults to 0.0):
|
||||
The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
|
||||
for more details.
|
||||
decoder_layerdrop: (`float`, *optional*, defaults to 0.0):
|
||||
The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
|
||||
for more details.
|
||||
scale_embedding (`bool`, *optional*, defaults to `False`):
|
||||
Scale embeddings by dividing by sqrt(d_model).
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models).
|
||||
num_labels: (`int`, *optional*, defaults to 3):
|
||||
The number of labels to use in [`LDMBertForSequenceClassification`].
|
||||
forced_eos_token_id (`int`, *optional*, defaults to 2):
|
||||
The id of the token to force as the last generated token when `max_length` is reached. Usually set to
|
||||
`eos_token_id`.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import LDMBertModel, LDMBertConfig
|
||||
|
||||
>>> # Initializing a LDMBERT facebook/ldmbert-large style configuration
|
||||
>>> configuration = LDMBertConfig()
|
||||
|
||||
>>> # Initializing a model from the facebook/ldmbert-large style configuration
|
||||
>>> model = LDMBertModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
model_type = "ldmbert"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=30522,
|
||||
max_position_embeddings=77,
|
||||
encoder_layers=32,
|
||||
encoder_ffn_dim=5120,
|
||||
encoder_attention_heads=8,
|
||||
head_dim=64,
|
||||
encoder_layerdrop=0.0,
|
||||
activation_function="gelu",
|
||||
d_model=1280,
|
||||
dropout=0.1,
|
||||
attention_dropout=0.0,
|
||||
activation_dropout=0.0,
|
||||
init_std=0.02,
|
||||
classifier_dropout=0.0,
|
||||
scale_embedding=False,
|
||||
use_cache=True,
|
||||
pad_token_id=0,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.d_model = d_model
|
||||
self.encoder_ffn_dim = encoder_ffn_dim
|
||||
self.encoder_layers = encoder_layers
|
||||
self.encoder_attention_heads = encoder_attention_heads
|
||||
self.head_dim = head_dim
|
||||
self.dropout = dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.activation_dropout = activation_dropout
|
||||
self.activation_function = activation_function
|
||||
self.init_std = init_std
|
||||
self.encoder_layerdrop = encoder_layerdrop
|
||||
self.classifier_dropout = classifier_dropout
|
||||
self.use_cache = use_cache
|
||||
self.num_hidden_layers = encoder_layers
|
||||
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
||||
|
||||
super().__init__(pad_token_id=pad_token_id, **kwargs)
|
||||
@@ -1,107 +0,0 @@
|
||||
import torch
|
||||
|
||||
import tqdm
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
from .configuration_ldmbert import LDMBertConfig # NOQA
|
||||
from .modeling_ldmbert import LDMBertModel # NOQA
|
||||
|
||||
# add these relative imports here, so we can load from hub
|
||||
from .modeling_vae import AutoencoderKL # NOQA
|
||||
|
||||
|
||||
class LatentDiffusion(DiffusionPipeline):
|
||||
def __init__(self, vqvae, bert, tokenizer, unet, noise_scheduler):
|
||||
super().__init__()
|
||||
self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, noise_scheduler=noise_scheduler)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
batch_size=1,
|
||||
generator=None,
|
||||
torch_device=None,
|
||||
eta=0.0,
|
||||
guidance_scale=1.0,
|
||||
num_inference_steps=50,
|
||||
):
|
||||
# eta corresponds to η in the DDIM paper and should be in [0, 1]
|
||||
|
||||
if torch_device is None:
|
||||
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
self.unet.to(torch_device)
|
||||
self.vqvae.to(torch_device)
|
||||
self.bert.to(torch_device)
|
||||
|
||||
# get unconditional embeddings for classifier-free guidance
|
||||
if guidance_scale != 1.0:
|
||||
uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(
|
||||
torch_device
|
||||
)
|
||||
uncond_embeddings = self.bert(uncond_input.input_ids)[0]
|
||||
|
||||
# get text embedding
|
||||
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
|
||||
text_embedding = self.bert(text_input.input_ids)[0]
|
||||
|
||||
num_trained_timesteps = self.noise_scheduler.timesteps
|
||||
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
|
||||
|
||||
image = self.noise_scheduler.sample_noise(
|
||||
(batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
|
||||
device=torch_device,
|
||||
generator=generator,
|
||||
)
|
||||
|
||||
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
|
||||
# Ideally, read the DDIM paper for an in-depth understanding
|
||||
|
||||
# Notation (<variable name> -> <name in paper>)
|
||||
# - pred_noise_t -> e_theta(x_t, t)
|
||||
# - pred_original_image -> f_theta(x_t, t) or x_0
|
||||
# - std_dev_t -> sigma_t
|
||||
# - eta -> η
|
||||
# - pred_image_direction -> "direction pointing to x_t"
|
||||
# - pred_prev_image -> "x_t-1"
|
||||
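# For reference, the DDIM update (Eq. 12 in the paper) that the scheduler step is assumed to implement:
#   x_{t-1} = sqrt(alpha_prod_{t-1}) * pred_original_image
#             + sqrt(1 - alpha_prod_{t-1} - sigma_t**2) * pred_noise_t
#             + sigma_t * noise
# with sigma_t = 0 (i.e. eta = 0) giving fully deterministic sampling.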
for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
|
||||
# guidance_scale of 1 means no guidance
|
||||
if guidance_scale == 1.0:
|
||||
image_in = image
|
||||
context = text_embedding
|
||||
timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
|
||||
else:
|
||||
# for classifier free guidance, we need to do two forward passes
|
||||
# here we concatenate the conditional and unconditional embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
image_in = torch.cat([image] * 2)
|
||||
context = torch.cat([uncond_embeddings, text_embedding])
|
||||
timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
|
||||
|
||||
# 1. predict noise residual
|
||||
pred_noise_t = self.unet(image_in, timesteps, context=context)
|
||||
|
||||
# perform guidance
|
||||
if guidance_scale != 1.0:
|
||||
pred_noise_t_uncond, pred_noise_t = pred_noise_t.chunk(2)
|
||||
pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond)
|
||||
|
||||
# 2. predict previous mean of image x_t-1
|
||||
pred_prev_image = self.noise_scheduler.step(pred_noise_t, image, t, num_inference_steps, eta)
|
||||
|
||||
# 3. optionally sample variance
|
||||
variance = 0
|
||||
if eta > 0:
|
||||
noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
|
||||
variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
|
||||
|
||||
# 4. set current image to prev_image: x_t -> x_t-1
|
||||
image = pred_prev_image + variance
|
||||
|
||||
# scale and decode image with vae
|
||||
image = 1 / 0.18215 * image
|
||||
image = self.vqvae.decode(image)
|
||||
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
return image
|
||||
@@ -1,706 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch LDMBERT model."""
|
||||
import copy
|
||||
import math
|
||||
import random
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
CausalLMOutputWithCrossAttentions,
|
||||
Seq2SeqLMOutput,
|
||||
Seq2SeqModelOutput,
|
||||
Seq2SeqQuestionAnsweringModelOutput,
|
||||
Seq2SeqSequenceClassifierOutput,
|
||||
)
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_end_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
|
||||
from .configuration_ldmbert import LDMBertConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "ldm-bert"
|
||||
_CONFIG_FOR_DOC = "LDMBertConfig"
|
||||
_TOKENIZER_FOR_DOC = "BartTokenizer"
|
||||
|
||||
# Base model docstring
|
||||
_EXPECTED_OUTPUT_SHAPE = [1, 8, 768]
|
||||
|
||||
# SequenceClassification docstring
|
||||
_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "valhalla/ldmbert-large-sst2"
|
||||
_SEQ_CLASS_EXPECTED_LOSS = 0.0
|
||||
_SEQ_CLASS_EXPECTED_OUTPUT = "'POSITIVE'"
|
||||
|
||||
# QuestionAnswering docstring
|
||||
_CHECKPOINT_FOR_QA = "valhalla/ldmbert-large-finetuned-squadv1"
|
||||
_QA_EXPECTED_LOSS = 0.59
|
||||
_QA_EXPECTED_OUTPUT = "' nice puppet'"
|
||||
|
||||
|
||||
LDMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"ldm-bert",
|
||||
# See all LDMBert models at https://huggingface.co/models?filter=ldmbert
|
||||
]
|
||||
|
||||
|
||||
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
"""
|
||||
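# Illustrative sketch: mask = [[1, 1, 0]] (1 = keep, 0 = pad) becomes an additive mask of
# shape [1, 1, tgt_len, 3] whose last column is the most negative finite value of `dtype`,
# so padded positions are suppressed after the softmax.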
bsz, src_len = mask.size()
|
||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||
|
||||
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->LDMBert
|
||||
class LDMBertAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
head_dim: int,
|
||||
dropout: float = 0.0,
|
||||
is_decoder: bool = False,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.dropout = dropout
|
||||
self.head_dim = head_dim
|
||||
self.inner_dim = head_dim * num_heads
|
||||
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.is_decoder = is_decoder
|
||||
|
||||
self.k_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias)
|
||||
self.v_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias)
|
||||
self.q_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias)
|
||||
self.out_proj = nn.Linear(self.inner_dim, embed_dim)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
key_value_states: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
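# Shape walkthrough (self-attention case): hidden_states is (bsz, tgt_len, embed_dim);
# q/k/v are projected and reshaped to (bsz, num_heads, seq_len, head_dim), flattened to
# (bsz * num_heads, seq_len, head_dim) for the bmm, and the output is reshaped back to
# (bsz, tgt_len, num_heads * head_dim) before `out_proj`.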
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states) * self.scaling
|
||||
# get key, value proj
|
||||
if is_cross_attention and past_key_value is not None:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value[0]
|
||||
value_states = past_key_value[1]
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
|
||||
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
||||
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
||||
key_states = key_states.view(*proj_shape)
|
||||
value_states = value_states.view(*proj_shape)
|
||||
|
||||
src_len = key_states.size(1)
|
||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||
|
||||
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
if layer_head_mask is not None:
|
||||
if layer_head_mask.size() != (self.num_heads,):
|
||||
raise ValueError(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
|
||||
f" {layer_head_mask.size()}"
|
||||
)
|
||||
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
if output_attentions:
|
||||
# this operation is a bit awkward, but it's required to
|
||||
# make sure that attn_weights keeps its gradient.
|
||||
# In order to do so, attn_weights have to be reshaped
|
||||
# twice and have to be reused in the following
|
||||
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
else:
|
||||
attn_weights_reshaped = None
|
||||
|
||||
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||
|
||||
attn_output = torch.bmm(attn_probs, value_states)
|
||||
|
||||
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
|
||||
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
||||
# partitioned across GPUs when using tensor-parallelism.
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, self.inner_dim)
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights_reshaped, past_key_value
|
||||
|
||||
|
||||
class LDMBertEncoderLayer(nn.Module):
|
||||
def __init__(self, config: LDMBertConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
self.self_attn = LDMBertAttention(
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.encoder_attention_heads,
|
||||
head_dim=config.head_dim,
|
||||
dropout=config.attention_dropout,
|
||||
)
|
||||
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
self.dropout = config.dropout
|
||||
self.activation_fn = ACT2FN[config.activation_function]
|
||||
self.activation_dropout = config.activation_dropout
|
||||
self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
|
||||
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
|
||||
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
attention_mask: torch.FloatTensor,
|
||||
layer_head_mask: torch.FloatTensor,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||
`(encoder_attention_heads,)`.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
"""
|
||||
residual = hidden_states
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
hidden_states, attn_weights, _ = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
if hidden_states.dtype == torch.float16 and (
|
||||
torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
|
||||
):
|
||||
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
||||
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (attn_weights,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartPretrainedModel with Bart->LDMBert
|
||||
class LDMBertPreTrainedModel(PreTrainedModel):
|
||||
config_class = LDMBertConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.init_std
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, (LDMBertDecoder, LDMBertEncoder)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
@property
|
||||
def dummy_inputs(self):
|
||||
pad_token = self.config.pad_token_id
|
||||
input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
|
||||
dummy_inputs = {
|
||||
"attention_mask": input_ids.ne(pad_token),
|
||||
"input_ids": input_ids,
|
||||
}
|
||||
return dummy_inputs
|
||||
|
||||
|
||||
LDMBERT_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
|
||||
and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`LDMBertConfig`]):
|
||||
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
||||
load the weights associated with the model, only the configuration. Check out the
|
||||
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
LDMBERT_GENERATION_EXAMPLE = r"""
|
||||
Summarization example:
|
||||
|
||||
```python
|
||||
>>> from transformers import BartTokenizer, LDMBertForConditionalGeneration
|
||||
|
||||
>>> model = LDMBertForConditionalGeneration.from_pretrained("facebook/ldmbert-large-cnn")
|
||||
>>> tokenizer = BartTokenizer.from_pretrained("facebook/ldmbert-large-cnn")
|
||||
|
||||
>>> ARTICLE_TO_SUMMARIZE = (
|
||||
... "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
|
||||
... "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
|
||||
... "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
|
||||
... )
|
||||
>>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt")
|
||||
|
||||
>>> # Generate Summary
|
||||
>>> summary_ids = model.generate(inputs["input_ids"], num_beams=2, min_length=0, max_length=20)
|
||||
>>> tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
'PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions'
|
||||
```
|
||||
|
||||
Mask filling example:
|
||||
|
||||
```python
|
||||
>>> from transformers import BartTokenizer, LDMBertForConditionalGeneration
|
||||
|
||||
>>> tokenizer = BartTokenizer.from_pretrained("ldm-bert")
|
||||
>>> model = LDMBertForConditionalGeneration.from_pretrained("ldm-bert")
|
||||
|
||||
>>> TXT = "My friends are <mask> but they eat too many carbs."
|
||||
>>> input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"]
|
||||
>>> logits = model(input_ids).logits
|
||||
|
||||
>>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
|
||||
>>> probs = logits[0, masked_index].softmax(dim=0)
|
||||
>>> values, predictions = probs.topk(5)
|
||||
|
||||
>>> tokenizer.decode(predictions).split()
|
||||
['not', 'good', 'healthy', 'great', 'very']
|
||||
```
|
||||
"""
|
||||
|
||||
LDMBERT_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||
it.
|
||||
|
||||
Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
||||
Indices of decoder input sequence tokens in the vocabulary.
|
||||
|
||||
Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are decoder input IDs?](../glossary#decoder-input-ids)
|
||||
|
||||
LDMBert uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If
|
||||
`past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
||||
`past_key_values`).
|
||||
|
||||
For translation and summarization training, `decoder_input_ids` should be provided. If no
|
||||
`decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
|
||||
for denoising pre-training following the paper.
|
||||
decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
||||
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
||||
be used by default.
|
||||
|
||||
If you want to change padding behavior, you should read
|
||||
[`modeling_ldmbert._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the
|
||||
paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
|
||||
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
|
||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
|
||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
|
||||
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
|
||||
1]`:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
|
||||
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
|
||||
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
|
||||
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
||||
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
||||
|
||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
||||
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
||||
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
|
||||
decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
|
||||
Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
|
||||
representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
|
||||
input (see `past_key_values`). This is useful if you want more control over how to convert
|
||||
`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
|
||||
|
||||
If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
|
||||
of `inputs_embeds`.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||
`past_key_values`).
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
class LDMBertEncoder(LDMBertPreTrainedModel):
|
||||
"""
|
||||
Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
|
||||
[`LDMBertEncoderLayer`].
|
||||
|
||||
Args:
|
||||
config: LDMBertConfig
|
||||
embed_tokens (nn.Embedding): output embedding
|
||||
"""
|
||||
|
||||
def __init__(self, config: LDMBertConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.dropout = config.dropout
|
||||
|
||||
embed_dim = config.d_model
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.max_source_positions = config.max_position_embeddings
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim)
|
||||
self.embed_positions = nn.Embedding(config.max_position_embeddings, embed_dim)
|
||||
self.layers = nn.ModuleList([LDMBertEncoderLayer(config) for _ in range(config.encoder_layers)])
|
||||
self.layer_norm = nn.LayerNorm(embed_dim)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embed_tokens = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutput]:
|
||||
r"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
||||
provide it.
|
||||
|
||||
Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
||||
for more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
seq_len = input_shape[1]
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(seq_len, dtype=torch.long, device=inputs_embeds.device).expand((1, -1))
|
||||
embed_pos = self.embed_positions(position_ids)
|
||||
|
||||
hidden_states = inputs_embeds + embed_pos
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
|
||||
# expand attention_mask
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
|
||||
|
||||
encoder_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if head_mask is not None:
|
||||
if head_mask.size()[0] != (len(self.layers)):
|
||||
raise ValueError(
|
||||
f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
|
||||
f" {head_mask.size()[0]}."
|
||||
)
|
||||
|
||||
for idx, encoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs, output_attentions)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(encoder_layer),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
(head_mask[idx] if head_mask is not None else None),
|
||||
)
|
||||
else:
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_attentions = all_attentions + (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
||||
)
class LDMBertModel(LDMBertPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = LDMBertEncoder(config)
|
||||
self.to_logits = nn.Linear(config.hidden_size, config.vocab_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
|
||||
outputs = self.model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
# logits = self.to_logits(sequence_output)
|
||||
# outputs = (logits,) + outputs[1:]
|
||||
|
||||
# if labels is not None:
|
||||
# loss_fct = CrossEntropyLoss()
|
||||
# loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
|
||||
# outputs = (loss,) + outputs
|
||||
|
||||
# if not return_dict:
|
||||
# return outputs
|
||||
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=sequence_output,
|
||||
# hidden_states=outputs[1],
|
||||
# attentions=outputs[2],
|
||||
)
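# Illustrative usage of the text encoder above (config values are placeholders, not a released checkpoint):
#
#     config = LDMBertConfig(vocab_size=30522, d_model=768, encoder_layers=12)
#     text_encoder = LDMBertModel(config)
#     outputs = text_encoder(input_ids=token_ids, attention_mask=attention_mask)
#     text_embeddings = outputs.last_hidden_state  # (batch_size, seq_len, d_model)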
@@ -1,859 +0,0 @@
|
||||
# pytorch_diffusion + derived encoder decoder
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import tqdm
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.configuration_utils import ConfigMixin
|
||||
from diffusers.modeling_utils import ModelMixin
|
||||
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim):
|
||||
"""
|
||||
Build sinusoidal timestep embeddings, as used in Denoising Diffusion Probabilistic Models (ported from
Fairseq). This matches the implementation in tensor2tensor, but differs slightly from the description in
Section 3.5 of "Attention Is All You Need".
|
||||
"""
|
||||
assert len(timesteps.shape) == 1
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
||||
emb = emb.to(device=timesteps.device)
|
||||
emb = timesteps.float()[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if embedding_dim % 2 == 1: # zero pad
|
||||
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
||||
return emb
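# Worked shape example: for timesteps = torch.tensor([5, 10]) and embedding_dim = 4, half_dim = 2 and each
# timestep t maps to [sin(t * w_0), sin(t * w_1), cos(t * w_0), cos(t * w_1)] with
# w_k = exp(-k * ln(10000) / (half_dim - 1)), giving a (2, 4) tensor; an odd embedding_dim is zero-padded by one column.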
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
def Normalize(in_channels):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
if self.with_conv:
|
||||
pad = (0, 1, 0, 1)
|
||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
return x
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
self.norm1 = Normalize(in_channels)
|
||||
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
||||
self.norm2 = Normalize(out_channels)
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
else:
|
||||
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x, temb):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
|
||||
|
||||
h = self.norm2(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b, c, h, w = q.shape
|
||||
q = q.reshape(b, c, h * w)
|
||||
q = q.permute(0, 2, 1) # b,hw,c
|
||||
k = k.reshape(b, c, h * w) # b,c,hw
|
||||
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||
w_ = w_ * (int(c) ** (-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = v.reshape(b, c, h * w)
|
||||
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
||||
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
||||
h_ = h_.reshape(b, c, h, w)
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x + h_
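# Note: this block is single-head self-attention over the h * w spatial positions,
# i.e. x + proj_out(softmax(Q @ K^T / sqrt(C)) @ V) with Q, K, V produced by 1x1 convolutions.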
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
use_timestep=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = self.ch * 4
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.use_timestep = use_timestep
|
||||
if self.use_timestep:
|
||||
# timestep embedding
|
||||
self.temb = nn.Module()
|
||||
self.temb.dense = nn.ModuleList(
|
||||
[
|
||||
torch.nn.Linear(self.ch, self.temb_ch),
|
||||
torch.nn.Linear(self.temb_ch, self.temb_ch),
|
||||
]
|
||||
)
|
||||
|
||||
# downsampling
|
||||
self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(AttnBlock(block_in))
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
down.downsample = Downsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
|
||||
)
|
||||
self.mid.attn_1 = AttnBlock(block_in)
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
|
||||
)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
skip_in = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
if i_block == self.num_res_blocks:
|
||||
skip_in = ch * in_ch_mult[i_level]
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in + skip_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(AttnBlock(block_in))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x, t=None):
|
||||
# assert x.shape[2] == x.shape[3] == self.resolution
|
||||
|
||||
if self.use_timestep:
|
||||
# timestep embedding
|
||||
assert t is not None
|
||||
temb = get_timestep_embedding(t, self.ch)
|
||||
temb = self.temb.dense[0](temb)
|
||||
temb = nonlinearity(temb)
|
||||
temb = self.temb.dense[1](temb)
|
||||
else:
|
||||
temb = None
|
||||
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1], temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
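# Note: the downsampling path stores every intermediate activation in `hs`; the upsampling path pops them in
# reverse order and concatenates them along the channel dimension (U-Net skip connections), which is why the
# up-blocks are constructed with `block_in + skip_in` input channels.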
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
double_z=True,
|
||||
**ignore_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
|
||||
# downsampling
|
||||
self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(AttnBlock(block_in))
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
down.downsample = Downsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
|
||||
)
|
||||
self.mid.attn_1 = AttnBlock(block_in)
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
|
||||
)
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(
|
||||
block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
|
||||
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1], temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
give_pre_end=False,
|
||||
**ignorekwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.give_pre_end = give_pre_end
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
|
||||
)
|
||||
self.mid.attn_1 = AttnBlock(block_in)
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
|
||||
)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(AttnBlock(block_in))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, z):
|
||||
# assert z.shape[1:] == self.z_shape[1:]
|
||||
self.last_z_shape = z.shape
|
||||
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# z to block_in
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](h, temb)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
class VectorQuantizer(nn.Module):
|
||||
"""
|
||||
Improved version of the original VectorQuantizer that can be used as a drop-in replacement. It mostly
avoids costly matrix multiplications and allows for post-hoc remapping of indices.
"""
|
||||
|
||||
# NOTE: due to a bug the beta term was applied to the wrong term. for
|
||||
# backwards compatibility we use the buggy version by default, but you can
|
||||
# specify legacy=False to fix it.
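# Concretely: with legacy=True, `beta` weights the codebook term mean((z_q - z.detach()) ** 2), whereas the
# original VQ-VAE loss (legacy=False) puts `beta` on the commitment term mean((z_q.detach() - z) ** 2).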
def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True):
|
||||
super().__init__()
|
||||
self.n_e = n_e
|
||||
self.e_dim = e_dim
|
||||
self.beta = beta
|
||||
self.legacy = legacy
|
||||
|
||||
self.embedding = nn.Embedding(self.n_e, self.e_dim)
|
||||
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
||||
|
||||
self.remap = remap
|
||||
if self.remap is not None:
|
||||
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
||||
self.re_embed = self.used.shape[0]
|
||||
self.unknown_index = unknown_index # "random" or "extra" or integer
|
||||
if self.unknown_index == "extra":
|
||||
self.unknown_index = self.re_embed
|
||||
self.re_embed = self.re_embed + 1
|
||||
print(
|
||||
f"Remapping {self.n_e} indices to {self.re_embed} indices. "
|
||||
f"Using {self.unknown_index} for unknown indices."
|
||||
)
|
||||
else:
|
||||
self.re_embed = n_e
|
||||
|
||||
self.sane_index_shape = sane_index_shape
|
||||
|
||||
def remap_to_used(self, inds):
|
||||
ishape = inds.shape
|
||||
assert len(ishape) > 1
|
||||
inds = inds.reshape(ishape[0], -1)
|
||||
used = self.used.to(inds)
|
||||
match = (inds[:, :, None] == used[None, None, ...]).long()
|
||||
new = match.argmax(-1)
|
||||
unknown = match.sum(2) < 1
|
||||
if self.unknown_index == "random":
|
||||
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
|
||||
else:
|
||||
new[unknown] = self.unknown_index
|
||||
return new.reshape(ishape)
|
||||
|
||||
def unmap_to_all(self, inds):
|
||||
ishape = inds.shape
|
||||
assert len(ishape) > 1
|
||||
inds = inds.reshape(ishape[0], -1)
|
||||
used = self.used.to(inds)
|
||||
if self.re_embed > self.used.shape[0]: # extra token
|
||||
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
|
||||
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
|
||||
return back.reshape(ishape)
|
||||
|
||||
def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
|
||||
assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
|
||||
assert rescale_logits == False, "Only for interface compatible with Gumbel"
|
||||
assert return_logits == False, "Only for interface compatible with Gumbel"
|
||||
# reshape z -> (batch, height, width, channel) and flatten
|
||||
z = rearrange(z, "b c h w -> b h w c").contiguous()
|
||||
z_flattened = z.view(-1, self.e_dim)
|
||||
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
||||
|
||||
d = (
|
||||
torch.sum(z_flattened**2, dim=1, keepdim=True)
|
||||
+ torch.sum(self.embedding.weight**2, dim=1)
|
||||
- 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n"))
|
||||
)
|
||||
|
||||
min_encoding_indices = torch.argmin(d, dim=1)
|
||||
z_q = self.embedding(min_encoding_indices).view(z.shape)
|
||||
perplexity = None
|
||||
min_encodings = None
|
||||
|
||||
# compute loss for embedding
|
||||
if not self.legacy:
|
||||
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
|
||||
else:
|
||||
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
|
||||
|
||||
# preserve gradients
|
||||
z_q = z + (z_q - z).detach()
|
||||
|
||||
# reshape back to match original input shape
|
||||
z_q = rearrange(z_q, "b h w c -> b c h w").contiguous()
|
||||
|
||||
if self.remap is not None:
|
||||
min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
|
||||
min_encoding_indices = self.remap_to_used(min_encoding_indices)
|
||||
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
|
||||
|
||||
if self.sane_index_shape:
|
||||
min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
|
||||
|
||||
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
def get_codebook_entry(self, indices, shape):
|
||||
# shape specifying (batch, height, width, channel)
|
||||
if self.remap is not None:
|
||||
indices = indices.reshape(shape[0], -1) # add batch axis
|
||||
indices = self.unmap_to_all(indices)
|
||||
indices = indices.reshape(-1) # flatten again
|
||||
|
||||
# get quantized latent vectors
|
||||
z_q = self.embedding(indices)
|
||||
|
||||
if shape is not None:
|
||||
z_q = z_q.view(shape)
|
||||
# reshape back to match original input shape
|
||||
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
return z_q
class VQModel(ModelMixin, ConfigMixin):
|
||||
def __init__(
|
||||
self,
|
||||
ch,
|
||||
out_ch,
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
n_embed,
|
||||
embed_dim,
|
||||
remap=None,
|
||||
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
dropout=0.0,
|
||||
double_z=True,
|
||||
resamp_with_conv=True,
|
||||
give_pre_end=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# register all __init__ params with self.register
|
||||
self.register_to_config(
|
||||
ch=ch,
|
||||
out_ch=out_ch,
|
||||
num_res_blocks=num_res_blocks,
|
||||
attn_resolutions=attn_resolutions,
|
||||
in_channels=in_channels,
|
||||
resolution=resolution,
|
||||
z_channels=z_channels,
|
||||
n_embed=n_embed,
|
||||
embed_dim=embed_dim,
|
||||
remap=remap,
|
||||
sane_index_shape=sane_index_shape,
|
||||
ch_mult=ch_mult,
|
||||
dropout=dropout,
|
||||
double_z=double_z,
|
||||
resamp_with_conv=resamp_with_conv,
|
||||
give_pre_end=give_pre_end,
|
||||
)
|
||||
|
||||
# pass init params to Encoder
|
||||
self.encoder = Encoder(
|
||||
ch=ch,
|
||||
out_ch=out_ch,
|
||||
num_res_blocks=num_res_blocks,
|
||||
attn_resolutions=attn_resolutions,
|
||||
in_channels=in_channels,
|
||||
resolution=resolution,
|
||||
z_channels=z_channels,
|
||||
ch_mult=ch_mult,
|
||||
dropout=dropout,
|
||||
resamp_with_conv=resamp_with_conv,
|
||||
double_z=double_z,
|
||||
give_pre_end=give_pre_end,
|
||||
)
|
||||
|
||||
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, remap=remap, sane_index_shape=sane_index_shape)
|
||||
|
||||
# pass init params to Decoder
|
||||
self.decoder = Decoder(
|
||||
ch=ch,
|
||||
out_ch=out_ch,
|
||||
num_res_blocks=num_res_blocks,
|
||||
attn_resolutions=attn_resolutions,
|
||||
in_channels=in_channels,
|
||||
resolution=resolution,
|
||||
z_channels=z_channels,
|
||||
ch_mult=ch_mult,
|
||||
dropout=dropout,
|
||||
resamp_with_conv=resamp_with_conv,
|
||||
give_pre_end=give_pre_end,
|
||||
)
|
||||
|
||||
def encode(self, x):
|
||||
h = self.encoder(x)
|
||||
h = self.quant_conv(h)
|
||||
return h
|
||||
|
||||
def decode(self, h, force_not_quantize=False):
|
||||
# also go through quantization layer
|
||||
if not force_not_quantize:
|
||||
quant, emb_loss, info = self.quantize(h)
|
||||
else:
|
||||
quant = h
|
||||
quant = self.post_quant_conv(quant)
|
||||
dec = self.decoder(quant)
|
||||
return dec
class DiagonalGaussianDistribution(object):
|
||||
def __init__(self, parameters, deterministic=False):
|
||||
self.parameters = parameters
|
||||
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
||||
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
||||
self.deterministic = deterministic
|
||||
self.std = torch.exp(0.5 * self.logvar)
|
||||
self.var = torch.exp(self.logvar)
|
||||
if self.deterministic:
|
||||
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
|
||||
|
||||
def sample(self):
|
||||
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
|
||||
return x
|
||||
|
||||
def kl(self, other=None):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.0])
|
||||
else:
|
||||
if other is None:
|
||||
return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
|
||||
else:
|
||||
return 0.5 * torch.sum(
|
||||
torch.pow(self.mean - other.mean, 2) / other.var
|
||||
+ self.var / other.var
|
||||
- 1.0
|
||||
- self.logvar
|
||||
+ other.logvar,
|
||||
dim=[1, 2, 3],
|
||||
)
|
||||
|
||||
def nll(self, sample, dims=[1, 2, 3]):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.0])
|
||||
logtwopi = np.log(2.0 * np.pi)
|
||||
return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
|
||||
|
||||
def mode(self):
|
||||
return self.mean
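# Note: sample() draws x = mean + std * eps with eps ~ N(0, I) (the reparameterization trick), and
# kl(other=None) is the closed-form KL divergence to a standard normal, summed over the C, H, W axes.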
|
||||
|
||||
class AutoencoderKL(ModelMixin, ConfigMixin):
|
||||
def __init__(
|
||||
self,
|
||||
ch,
|
||||
out_ch,
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
embed_dim,
|
||||
remap=None,
|
||||
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
dropout=0.0,
|
||||
double_z=True,
|
||||
resamp_with_conv=True,
|
||||
give_pre_end=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# register all __init__ params with self.register
|
||||
self.register_to_config(
|
||||
ch=ch,
|
||||
out_ch=out_ch,
|
||||
num_res_blocks=num_res_blocks,
|
||||
attn_resolutions=attn_resolutions,
|
||||
in_channels=in_channels,
|
||||
resolution=resolution,
|
||||
z_channels=z_channels,
|
||||
embed_dim=embed_dim,
|
||||
remap=remap,
|
||||
sane_index_shape=sane_index_shape,
|
||||
ch_mult=ch_mult,
|
||||
dropout=dropout,
|
||||
double_z=double_z,
|
||||
resamp_with_conv=resamp_with_conv,
|
||||
give_pre_end=give_pre_end,
|
||||
)
|
||||
|
||||
# pass init params to Encoder
|
||||
self.encoder = Encoder(
|
||||
ch=ch,
|
||||
out_ch=out_ch,
|
||||
num_res_blocks=num_res_blocks,
|
||||
attn_resolutions=attn_resolutions,
|
||||
in_channels=in_channels,
|
||||
resolution=resolution,
|
||||
z_channels=z_channels,
|
||||
ch_mult=ch_mult,
|
||||
dropout=dropout,
|
||||
resamp_with_conv=resamp_with_conv,
|
||||
double_z=double_z,
|
||||
give_pre_end=give_pre_end,
|
||||
)
|
||||
|
||||
# pass init params to Decoder
|
||||
self.decoder = Decoder(
|
||||
ch=ch,
|
||||
out_ch=out_ch,
|
||||
num_res_blocks=num_res_blocks,
|
||||
attn_resolutions=attn_resolutions,
|
||||
in_channels=in_channels,
|
||||
resolution=resolution,
|
||||
z_channels=z_channels,
|
||||
ch_mult=ch_mult,
|
||||
dropout=dropout,
|
||||
resamp_with_conv=resamp_with_conv,
|
||||
give_pre_end=give_pre_end,
|
||||
)
|
||||
|
||||
self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
def encode(self, x):
|
||||
h = self.encoder(x)
|
||||
moments = self.quant_conv(h)
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
return posterior
def decode(self, z):
|
||||
z = self.post_quant_conv(z)
|
||||
dec = self.decoder(z)
|
||||
return dec
def forward(self, input, sample_posterior=True):
|
||||
posterior = self.encode(input)
|
||||
if sample_posterior:
|
||||
z = posterior.sample()
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z)
|
||||
return dec, posterior
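# Illustrative round trip (argument values are placeholders, not a released configuration):
#
#     vae = AutoencoderKL(ch=64, out_ch=3, num_res_blocks=2, attn_resolutions=[],
#                         in_channels=3, resolution=256, z_channels=4, embed_dim=4)
#     posterior = vae.encode(images)    # DiagonalGaussianDistribution over the latent
#     latents = posterior.sample()      # (B, embed_dim, resolution / 8, resolution / 8) with the default ch_mult
#     reconstruction = vae.decode(latents)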
@@ -291,7 +291,7 @@ class BDDM(DiffusionPipeline):
|
||||
# Sample gaussian noise to begin loop
|
||||
audio = torch.normal(0, 1, size=audio_size, generator=generator).to(torch_device)
|
||||
|
||||
timestep_values = self.noise_scheduler.timestep_values
|
||||
timestep_values = self.noise_scheduler.config.timestep_values
|
||||
num_prediction_steps = len(self.noise_scheduler)
|
||||
for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
|
||||
# 1. predict noise residual
|
||||
|
||||
@@ -32,7 +32,7 @@ class DDIM(DiffusionPipeline):
|
||||
if torch_device is None:
|
||||
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
num_trained_timesteps = self.noise_scheduler.timesteps
|
||||
num_trained_timesteps = self.noise_scheduler.config.timesteps
|
||||
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
|
||||
|
||||
self.unet.to(torch_device)
|
||||
|
||||
@@ -24,17 +24,11 @@ import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
import tqdm
|
||||
|
||||
|
||||
try:
|
||||
from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig, GPT2Tokenizer
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, replace_return_docstrings
|
||||
except:
|
||||
print("Transformers is not installed")
|
||||
pass
|
||||
from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig, GPT2Tokenizer
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, replace_return_docstrings
|
||||
|
||||
from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
|
||||
@@ -472,7 +472,7 @@ class GradTTS(DiffusionPipeline):
|
||||
t = (1.0 - (t + 0.5) * h) * torch.ones(z.shape[0], dtype=z.dtype, device=z.device)
|
||||
time = t.unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
residual = self.unet(xt, y_mask, mu_y, t, speaker_id)
|
||||
residual = self.unet(xt, t, mu_y, y_mask, speaker_id)
|
||||
|
||||
xt = self.noise_scheduler.step(xt, residual, mu_y, h, time)
|
||||
xt = xt * y_mask
|
||||
|
||||
@@ -897,7 +897,7 @@ class LatentDiffusion(DiffusionPipeline):
|
||||
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
|
||||
text_embedding = self.bert(text_input.input_ids)[0]
|
||||
|
||||
num_trained_timesteps = self.noise_scheduler.timesteps
|
||||
num_trained_timesteps = self.noise_scheduler.config.timesteps
|
||||
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
|
||||
|
||||
image = torch.randn(
|
||||
|
||||
@@ -42,9 +42,9 @@ class PNDM(DiffusionPipeline):
|
||||
)
|
||||
image = image.to(torch_device)
|
||||
|
||||
warmup_time_steps = self.noise_scheduler.get_warmup_time_steps(num_inference_steps)
|
||||
for t in tqdm.tqdm(range(len(warmup_time_steps))):
|
||||
t_orig = warmup_time_steps[t]
|
||||
prk_time_steps = self.noise_scheduler.get_prk_time_steps(num_inference_steps)
|
||||
for t in tqdm.tqdm(range(len(prk_time_steps))):
|
||||
t_orig = prk_time_steps[t]
|
||||
residual = self.unet(image, t_orig)
|
||||
|
||||
image = self.noise_scheduler.step_prk(residual, image, t, num_inference_steps)
|
||||
|
||||
@@ -61,7 +61,6 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
|
||||
timesteps=timesteps,
|
||||
beta_schedule=beta_schedule,
|
||||
)
|
||||
self.timesteps = int(timesteps)
|
||||
|
||||
if beta_schedule == "squaredcos_cap_v2":
|
||||
# GLIDE cosine schedule
|
||||
@@ -94,4 +93,4 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
|
||||
return torch.randn(shape, generator=generator).to(device)
|
||||
|
||||
def __len__(self):
|
||||
return self.timesteps
|
||||
return self.config.timesteps
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -11,12 +11,40 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
|
||||
# and https://github.com/hojonathanho/diffusion
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from .scheduling_utils import SchedulerMixin, betas_for_alpha_bar, linear_beta_schedule
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function,
|
||||
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
||||
|
||||
:param num_diffusion_timesteps: the number of betas to produce.
|
||||
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
||||
produces the cumulative product of (1-beta) up to that
|
||||
part of the diffusion process.
|
||||
:param max_beta: the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
"""
|
||||
|
||||
def alpha_bar(time_step):
|
||||
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
return np.array(betas, dtype=np.float32)
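# Intuition: betas[i] is chosen so that the running product of (1 - beta) tracks alpha_bar(i / T); the schedule
# starts near zero and grows toward max_beta at the end of the diffusion process (the "squaredcos_cap_v2"
# cosine schedule referred to below as the GLIDE schedule).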
|
||||
|
||||
class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
@@ -37,19 +65,16 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
beta_start=beta_start,
|
||||
beta_end=beta_end,
|
||||
beta_schedule=beta_schedule,
|
||||
trained_betas=trained_betas,
|
||||
timestep_values=timestep_values,
|
||||
clip_sample=clip_sample,
|
||||
)
|
||||
self.timesteps = int(timesteps)
|
||||
self.timestep_values = timestep_values # save the fixed timestep values for BDDM
|
||||
self.clip_sample = clip_sample
|
||||
|
||||
if beta_schedule == "linear":
|
||||
self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
|
||||
self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)
|
||||
elif beta_schedule == "squaredcos_cap_v2":
|
||||
# GLIDE cosine schedule
|
||||
self.betas = betas_for_alpha_bar(
|
||||
timesteps,
|
||||
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
|
||||
)
|
||||
self.betas = betas_for_alpha_bar(timesteps)
|
||||
else:
|
||||
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
|
||||
|
||||
@@ -59,51 +84,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self.set_format(tensor_format=tensor_format)
|
||||
|
||||
# alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
|
||||
# TODO(PVP) - check how much of these is actually necessary!
|
||||
# LDM only uses "fixed_small"; glide seems to use a weird mix of the two, ...
|
||||
# https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246
|
||||
# variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
|
||||
# if variance_type == "fixed_small":
|
||||
# log_variance = torch.log(variance.clamp(min=1e-20))
|
||||
# elif variance_type == "fixed_large":
|
||||
# log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))
|
||||
#
|
||||
#
|
||||
# self.register_buffer("log_variance", log_variance.to(torch.float32))
|
||||
|
||||
# def rescale_betas(self, num_timesteps):
|
||||
# # GLIDE scaling
|
||||
# if self.beta_schedule == "linear":
|
||||
# scale = self.timesteps / num_timesteps
|
||||
# self.betas = linear_beta_schedule(
|
||||
# num_timesteps, beta_start=self.beta_start * scale, beta_end=self.beta_end * scale
|
||||
# )
|
||||
# self.alphas = 1.0 - self.betas
|
||||
# self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
|
||||
|
||||
def get_alpha(self, time_step):
|
||||
return self.alphas[time_step]
|
||||
|
||||
def get_beta(self, time_step):
|
||||
return self.betas[time_step]
|
||||
|
||||
def get_alpha_prod(self, time_step):
|
||||
if time_step < 0:
|
||||
return self.one
|
||||
return self.alphas_cumprod[time_step]
|
||||
|
||||
def get_orig_t(self, t, num_inference_steps):
|
||||
if t < 0:
|
||||
return -1
|
||||
return self.timesteps // num_inference_steps * t
|
||||
|
||||
def get_variance(self, t, num_inference_steps):
|
||||
orig_t = self.get_orig_t(t, num_inference_steps)
|
||||
orig_prev_t = self.get_orig_t(t - 1, num_inference_steps)
|
||||
orig_t = self.config.timesteps // num_inference_steps * t
|
||||
orig_prev_t = self.config.timesteps // num_inference_steps * (t - 1) if t > 0 else -1
|
||||
|
||||
alpha_prod_t = self.get_alpha_prod(orig_t)
|
||||
alpha_prod_t_prev = self.get_alpha_prod(orig_prev_t)
|
||||
alpha_prod_t = self.alphas_cumprod[orig_t]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[orig_prev_t] if orig_prev_t >= 0 else self.one
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
||||
|
||||
@@ -124,12 +110,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
# - pred_prev_sample -> "x_t-1"
|
||||
|
||||
# 1. get actual t and t-1
|
||||
orig_t = self.get_orig_t(t, num_inference_steps)
|
||||
orig_prev_t = self.get_orig_t(t - 1, num_inference_steps)
|
||||
orig_t = self.config.timesteps // num_inference_steps * t
|
||||
orig_prev_t = self.config.timesteps // num_inference_steps * (t - 1) if t > 0 else -1
|
||||
|
||||
# 2. compute alphas, betas
|
||||
alpha_prod_t = self.get_alpha_prod(orig_t)
|
||||
alpha_prod_t_prev = self.get_alpha_prod(orig_prev_t)
|
||||
alpha_prod_t = self.alphas_cumprod[orig_t]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[orig_prev_t] if orig_prev_t >= 0 else self.one
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
|
||||
# 3. compute predicted original sample from predicted noise also called
|
||||
@@ -137,7 +123,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
pred_original_sample = (sample - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)
|
||||
|
||||
# 4. Clip "predicted x_0"
|
||||
if self.clip_sample:
|
||||
if self.config.clip_sample:
|
||||
pred_original_sample = self.clip(pred_original_sample, -1, 1)
|
||||
|
||||
# 5. compute variance: "sigma_t(η)" -> see formula (16)
|
||||
@@ -158,4 +144,4 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
return pred_prev_sample
|
||||
|
||||
def __len__(self):
|
||||
return self.timesteps
|
||||
return self.config.timesteps
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -11,12 +11,39 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from .scheduling_utils import SchedulerMixin, betas_for_alpha_bar, linear_beta_schedule
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function,
|
||||
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
||||
|
||||
:param num_diffusion_timesteps: the number of betas to produce.
|
||||
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
||||
produces the cumulative product of (1-beta) up to that
|
||||
part of the diffusion process.
|
||||
:param max_beta: the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
"""
|
||||
|
||||
def alpha_bar(time_step):
|
||||
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
return np.array(betas, dtype=np.float32)
class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
@@ -43,21 +70,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
variance_type=variance_type,
|
||||
clip_sample=clip_sample,
|
||||
)
|
||||
self.timesteps = int(timesteps)
|
||||
self.timestep_values = timestep_values # save the fixed timestep values for BDDM
|
||||
self.clip_sample = clip_sample
|
||||
self.variance_type = variance_type
|
||||
|
||||
if trained_betas is not None:
|
||||
self.betas = np.asarray(trained_betas)
|
||||
elif beta_schedule == "linear":
|
||||
self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
|
||||
self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)
|
||||
elif beta_schedule == "squaredcos_cap_v2":
|
||||
# GLIDE cosine schedule
|
||||
self.betas = betas_for_alpha_bar(
|
||||
timesteps,
|
||||
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
|
||||
)
|
||||
self.betas = betas_for_alpha_bar(timesteps)
|
||||
else:
|
||||
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
|
||||
|
||||
@@ -67,70 +87,48 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self.set_format(tensor_format=tensor_format)
|
||||
|
||||
# self.register_buffer("betas", betas.to(torch.float32))
|
||||
# self.register_buffer("alphas", alphas.to(torch.float32))
|
||||
# self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))
|
||||
|
||||
# alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
|
||||
# TODO(PVP) - check how much of these is actually necessary!
|
||||
# LDM only uses "fixed_small"; glide seems to use a weird mix of the two, ...
|
||||
# https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246
|
||||
# variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
|
||||
# if variance_type == "fixed_small":
|
||||
# log_variance = torch.log(variance.clamp(min=1e-20))
|
||||
# elif variance_type == "fixed_large":
|
||||
# log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))
|
||||
#
|
||||
#
|
||||
# self.register_buffer("log_variance", log_variance.to(torch.float32))
|
||||
|
||||
def get_alpha(self, time_step):
|
||||
return self.alphas[time_step]
|
||||
|
||||
def get_beta(self, time_step):
|
||||
return self.betas[time_step]
|
||||
|
||||
def get_alpha_prod(self, time_step):
|
||||
if time_step < 0:
|
||||
return self.one
|
||||
return self.alphas_cumprod[time_step]
|
||||
|
||||
def get_variance(self, t):
|
||||
alpha_prod_t = self.get_alpha_prod(t)
|
||||
alpha_prod_t_prev = self.get_alpha_prod(t - 1)
|
||||
alpha_prod_t = self.alphas_cumprod[t]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
|
||||
|
||||
# For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
|
||||
# and sample from it to get previous sample
|
||||
# x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
|
||||
variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.get_beta(t)
|
||||
variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]
|
||||
|
||||
# hacks - were probs added for training stability
|
||||
if self.variance_type == "fixed_small":
|
||||
if self.config.variance_type == "fixed_small":
|
||||
variance = self.clip(variance, min_value=1e-20)
|
||||
elif self.variance_type == "fixed_large":
|
||||
variance = self.get_beta(t)
|
||||
# for rl-diffuser https://arxiv.org/abs/2205.09991
|
||||
elif self.config.variance_type == "fixed_small_log":
|
||||
variance = self.log(self.clip(variance, min_value=1e-20))
|
||||
elif self.config.variance_type == "fixed_large":
|
||||
variance = self.betas[t]
|
||||
|
||||
return variance
|
||||
|
||||
def step(self, residual, sample, t):
|
||||
def step(self, residual, sample, t, predict_epsilon=True):
|
||||
# 1. compute alphas, betas
|
||||
alpha_prod_t = self.get_alpha_prod(t)
|
||||
alpha_prod_t_prev = self.get_alpha_prod(t - 1)
|
||||
alpha_prod_t = self.alphas_cumprod[t]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
||||
|
||||
# 2. compute predicted original sample from predicted noise also called
|
||||
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
|
||||
pred_original_sample = (sample - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)
|
||||
if predict_epsilon:
|
||||
pred_original_sample = (sample - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)
|
||||
else:
|
||||
pred_original_sample = residual
|
||||
|
||||
# 3. Clip "predicted x_0"
|
||||
if self.clip_sample:
|
||||
if self.config.clip_sample:
|
||||
pred_original_sample = self.clip(pred_original_sample, -1, 1)
|
||||
|
||||
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
|
||||
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
|
||||
pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.get_beta(t)) / beta_prod_t
|
||||
current_sample_coeff = self.get_alpha(t) ** (0.5) * beta_prod_t_prev / beta_prod_t
|
||||
pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t
|
||||
current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t
|
||||
|
||||
# 5. Compute predicted previous sample µ_t
|
||||
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
|
||||
@@ -139,10 +137,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
return pred_prev_sample
|
||||
|
||||
def forward_step(self, original_sample, noise, t):
|
||||
sqrt_alpha_prod = self.get_alpha_prod(t) ** 0.5
|
||||
sqrt_one_minus_alpha_prod = (1 - self.get_alpha_prod(t)) ** 0.5
|
||||
sqrt_alpha_prod = self.alpha_prod_t[t] ** 0.5
|
||||
sqrt_one_minus_alpha_prod = (1 - self.alpha_prod_t[t]) ** 0.5
|
||||
noisy_sample = sqrt_alpha_prod * original_sample + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_sample
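# forward_step applies the closed-form forward process q(x_t | x_0) from DDPM equation (4):
# x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, so any noise level t is reachable in one step.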
|
||||
def __len__(self):
|
||||
return self.timesteps
|
||||
return self.config.timesteps
|
||||
|
||||
@@ -30,8 +30,6 @@ class GradTTSScheduler(SchedulerMixin, ConfigMixin):
|
||||
beta_start=beta_start,
|
||||
beta_end=beta_end,
|
||||
)
|
||||
self.timesteps = int(timesteps)
|
||||
|
||||
self.set_format(tensor_format=tensor_format)
|
||||
|
||||
def sample_noise(self, timestep):
|
||||
@@ -46,4 +44,4 @@ class GradTTSScheduler(SchedulerMixin, ConfigMixin):
|
||||
return xt
|
||||
|
||||
def __len__(self):
|
||||
return self.timesteps
|
||||
return len(self.config.timesteps)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2022 Zhejiang University Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -11,12 +11,39 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from .scheduling_utils import SchedulerMixin, betas_for_alpha_bar, linear_beta_schedule
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function,
|
||||
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
||||
|
||||
:param num_diffusion_timesteps: the number of betas to produce.
|
||||
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
||||
produces the cumulative product of (1-beta) up to that
|
||||
part of the diffusion process.
|
||||
:param max_beta: the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
"""
|
||||
|
||||
def alpha_bar(time_step):
|
||||
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
return np.array(betas, dtype=np.float32)
class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
@@ -35,16 +62,12 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
beta_end=beta_end,
|
||||
beta_schedule=beta_schedule,
|
||||
)
|
||||
self.timesteps = int(timesteps)
|
||||
|
||||
if beta_schedule == "linear":
|
||||
self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
|
||||
self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)
|
||||
elif beta_schedule == "squaredcos_cap_v2":
|
||||
# GLIDE cosine schedule
|
||||
self.betas = betas_for_alpha_bar(
|
||||
timesteps,
|
||||
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
|
||||
)
|
||||
self.betas = betas_for_alpha_bar(timesteps)
|
||||
else:
|
||||
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
|
||||
|
||||
@@ -57,55 +80,58 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# For now we only support F-PNDM, i.e. the runge-kutta method
|
||||
# For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
|
||||
# mainly at equations (12) and (13) and the Algorithm 2.
|
||||
# mainly at formula (9), (12), (13) and the Algorithm 2.
|
||||
self.pndm_order = 4
|
||||
|
||||
# running values
|
||||
self.cur_residual = 0
|
||||
self.cur_sample = None
|
||||
self.ets = []
|
||||
self.warmup_time_steps = {}
|
||||
self.prk_time_steps = {}
|
||||
self.time_steps = {}
|
||||
self.set_prk_mode()
|
||||
|
||||
def get_alpha(self, time_step):
|
||||
return self.alphas[time_step]
|
||||
def get_prk_time_steps(self, num_inference_steps):
|
||||
if num_inference_steps in self.prk_time_steps:
|
||||
return self.prk_time_steps[num_inference_steps]
|
||||
|
||||
def get_beta(self, time_step):
|
||||
return self.betas[time_step]
|
||||
inference_step_times = list(range(0, self.config.timesteps, self.config.timesteps // num_inference_steps))
|
||||
|
||||
def get_alpha_prod(self, time_step):
|
||||
if time_step < 0:
|
||||
return self.one
|
||||
return self.alphas_cumprod[time_step]
|
||||
|
||||
def get_warmup_time_steps(self, num_inference_steps):
|
||||
if num_inference_steps in self.warmup_time_steps:
|
||||
return self.warmup_time_steps[num_inference_steps]
|
||||
|
||||
inference_step_times = list(range(0, self.timesteps, self.timesteps // num_inference_steps))
|
||||
|
||||
warmup_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile(
|
||||
np.array([0, self.timesteps // num_inference_steps // 2]), self.pndm_order
|
||||
prk_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile(
|
||||
np.array([0, self.config.timesteps // num_inference_steps // 2]), self.pndm_order
|
||||
)
|
||||
self.warmup_time_steps[num_inference_steps] = list(reversed(warmup_time_steps[:-1].repeat(2)[1:-1]))
|
||||
self.prk_time_steps[num_inference_steps] = list(reversed(prk_time_steps[:-1].repeat(2)[1:-1]))
|
||||
|
||||
return self.warmup_time_steps[num_inference_steps]
|
||||
return self.prk_time_steps[num_inference_steps]
|
||||
|
||||
def get_time_steps(self, num_inference_steps):
|
||||
if num_inference_steps in self.time_steps:
|
||||
return self.time_steps[num_inference_steps]
|
||||
|
||||
inference_step_times = list(range(0, self.timesteps, self.timesteps // num_inference_steps))
|
||||
inference_step_times = list(range(0, self.config.timesteps, self.config.timesteps // num_inference_steps))
|
||||
self.time_steps[num_inference_steps] = list(reversed(inference_step_times[:-3]))
|
||||
|
||||
return self.time_steps[num_inference_steps]
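To make the index arithmetic in the two helpers above easier to follow, here is a hedged sketch of what they return for a small configuration; it simply inlines the same expressions (the real methods read `self.config.timesteps` and cache their results):

import numpy as np

timesteps, num_inference_steps, pndm_order = 100, 10, 4

inference_step_times = list(range(0, timesteps, timesteps // num_inference_steps))
# [0, 10, 20, ..., 90]

prk = np.array(inference_step_times[-pndm_order:]).repeat(2) + np.tile(
    np.array([0, timesteps // num_inference_steps // 2]), pndm_order
)
prk_time_steps = list(reversed(prk[:-1].repeat(2)[1:-1]))
# 12 Runge-Kutta evaluation times: [90, 85, 85, 80, 80, 75, 75, 70, 70, 65, 65, 60]

plms_time_steps = list(reversed(inference_step_times[:-3]))
# remaining plain linear multi-step times: [60, 50, 40, 30, 20, 10, 0]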
|
||||
|
||||
def step_prk(self, residual, sample, t, num_inference_steps):
|
||||
# TODO(Patrick) - need to rethink whether the "warmup" way is the correct API design here
|
||||
warmup_time_steps = self.get_warmup_time_steps(num_inference_steps)
|
||||
def set_prk_mode(self):
|
||||
self.mode = "prk"
|
||||
|
||||
t_prev = warmup_time_steps[t // 4 * 4]
|
||||
t_next = warmup_time_steps[min(t + 1, len(warmup_time_steps) - 1)]
|
||||
def set_plms_mode(self):
|
||||
self.mode = "plms"
|
||||
|
||||
def step(self, *args, **kwargs):
|
||||
if self.mode == "prk":
|
||||
return self.step_prk(*args, **kwargs)
|
||||
if self.mode == "plms":
|
||||
return self.step_plms(*args, **kwargs)
|
||||
|
||||
raise ValueError(f"mode {self.mode} does not exist.")
|
||||
|
||||
def step_prk(self, residual, sample, t, num_inference_steps):
|
||||
prk_time_steps = self.get_prk_time_steps(num_inference_steps)
|
||||
|
||||
t_orig = prk_time_steps[t // 4 * 4]
|
||||
t_orig_prev = prk_time_steps[min(t + 1, len(prk_time_steps) - 1)]
|
||||
|
||||
if t % 4 == 0:
|
||||
self.cur_residual += 1 / 6 * residual
|
||||
@@ -119,33 +145,63 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
residual = self.cur_residual + 1 / 6 * residual
|
||||
self.cur_residual = 0
|
||||
|
||||
return self.transfer(self.cur_sample, t_prev, t_next, residual)
|
||||
# cur_sample should not be `None`
|
||||
cur_sample = self.cur_sample if self.cur_sample is not None else sample
|
||||
|
||||
return self.get_prev_sample(cur_sample, t_orig, t_orig_prev, residual)
|
||||
|
||||
def step_plms(self, residual, sample, t, num_inference_steps):
|
||||
if len(self.ets) < 3:
|
||||
raise ValueError(
|
||||
f"{self.__class__} can only be run AFTER scheduler has been run "
|
||||
"in 'prk' mode for at least 12 iterations "
|
||||
"See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py "
|
||||
"for more information."
|
||||
)
|
||||
|
||||
timesteps = self.get_time_steps(num_inference_steps)
|
||||
|
||||
t_prev = timesteps[t]
|
||||
t_next = timesteps[min(t + 1, len(timesteps) - 1)]
|
||||
t_orig = timesteps[t]
|
||||
t_orig_prev = timesteps[min(t + 1, len(timesteps) - 1)]
|
||||
self.ets.append(residual)
|
||||
|
||||
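# linear multi-step update: 4th-order Adams-Bashforth combination of the last
# four residuals, as used by Algorithm 2 of the PNDM paper linked above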
residual = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
|
||||
|
||||
return self.transfer(sample, t_prev, t_next, residual)
|
||||
return self.get_prev_sample(sample, t_orig, t_orig_prev, residual)
|
||||
|
||||
def transfer(self, x, t, t_next, et):
|
||||
# TODO(Patrick): clean up to be compatible with numpy and give better names
|
||||
def get_prev_sample(self, sample, t_orig, t_orig_prev, residual):
|
||||
# See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
|
||||
# this function computes x_(t−δ) using the formula of (9)
|
||||
# Note that x_t needs to be added to both sides of the equation
|
||||
|
||||
alphas_cump = self.alphas_cumprod.to(x.device)
|
||||
at = alphas_cump[t + 1].view(-1, 1, 1, 1)
|
||||
at_next = alphas_cump[t_next + 1].view(-1, 1, 1, 1)
|
||||
# Notation (<variable name> -> <name in paper>)
|
||||
# alpha_prod_t -> α_t
|
||||
# alpha_prod_t_prev -> α_(t−δ)
|
||||
# beta_prod_t -> (1 - α_t)
|
||||
# beta_prod_t_prev -> (1 - α_(t−δ))
|
||||
# sample -> x_t
|
||||
# residual -> e_θ(x_t, t)
|
||||
# prev_sample -> x_(t−δ)
|
||||
alpha_prod_t = self.alphas_cumprod[t_orig + 1]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[t_orig_prev + 1]
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
||||
|
||||
x_delta = (at_next - at) * (
|
||||
(1 / (at.sqrt() * (at.sqrt() + at_next.sqrt()))) * x
|
||||
- 1 / (at.sqrt() * (((1 - at_next) * at).sqrt() + ((1 - at) * at_next).sqrt())) * et
|
||||
)
|
||||
# corresponds to (α_(t−δ) - α_t) divided by
|
||||
# denominator of x_t in formula (9) and plus 1
|
||||
# Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqrt(α_t))) =
|
||||
# sqrt(α_(t−δ)) / sqrt(α_t)
|
||||
sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5)
|
||||
|
||||
x_next = x + x_delta
|
||||
return x_next
|
||||
# corresponds to denominator of e_θ(x_t, t) in formula (9)
|
||||
residual_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + (
|
||||
alpha_prod_t * beta_prod_t * alpha_prod_t_prev
|
||||
) ** (0.5)
|
||||
|
||||
# full formula (9)
|
||||
prev_sample = sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * residual / residual_denom_coeff
|
||||
|
||||
return prev_sample
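For reference, the update assembled from `sample_coeff` and `residual_denom_coeff` above is formula (9) written in terms of the cumulative products, i.e.

$x_{t-\delta} = \sqrt{\frac{\bar\alpha_{t-\delta}}{\bar\alpha_t}}\, x_t - \frac{(\bar\alpha_{t-\delta} - \bar\alpha_t)\,\epsilon_\theta(x_t, t)}{\bar\alpha_t \sqrt{1-\bar\alpha_{t-\delta}} + \sqrt{\bar\alpha_t (1-\bar\alpha_t)\,\bar\alpha_{t-\delta}}}$

with $\bar\alpha$ denoting `alphas_cumprod` as in the notation block at the top of this method.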
|
||||
|
||||
def __len__(self):
|
||||
return self.timesteps
|
||||
return self.config.timesteps
|
||||
|
||||
@@ -18,30 +18,6 @@ import torch
|
||||
SCHEDULER_CONFIG_NAME = "scheduler_config.json"
|
||||
|
||||
|
||||
def linear_beta_schedule(timesteps, beta_start, beta_end):
|
||||
return np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function,
|
||||
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
||||
|
||||
:param num_diffusion_timesteps: the number of betas to produce.
|
||||
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
||||
produces the cumulative product of (1-beta) up to that
|
||||
part of the diffusion process.
|
||||
:param max_beta: the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
"""
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
return np.array(betas, dtype=np.float32)
|
||||
|
||||
|
||||
class SchedulerMixin:
|
||||
|
||||
config_name = SCHEDULER_CONFIG_NAME
|
||||
@@ -64,3 +40,13 @@ class SchedulerMixin:
|
||||
return torch.clamp(tensor, min_value, max_value)
|
||||
|
||||
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
|
||||
|
||||
def log(self, tensor):
|
||||
tensor_format = getattr(self, "tensor_format", "pt")
|
||||
|
||||
if tensor_format == "np":
|
||||
return np.log(tensor)
|
||||
elif tensor_format == "pt":
|
||||
return torch.log(tensor)
|
||||
|
||||
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
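A small, hedged illustration of the `tensor_format` dispatch above, using `DDPMScheduler` (any scheduler inheriting `SchedulerMixin` behaves the same way; default constructor arguments are assumed):

import numpy as np
import torch
from diffusers import DDPMScheduler

np_sched = DDPMScheduler(tensor_format="np")  # constructor defaults assumed otherwise
pt_sched = DDPMScheduler(tensor_format="pt")

# the same helper returns the backend-native type in each case
assert isinstance(np_sched.log(np.array([1.0, 2.0])), np.ndarray)
assert isinstance(pt_sched.log(torch.tensor([1.0, 2.0])), torch.Tensor)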
|
||||
|
||||
@@ -1,12 +1,3 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
|
||||
# flake8: noqa
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all.
|
||||
|
||||
import os
|
||||
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -20,8 +11,18 @@ import os
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
|
||||
import importlib_metadata
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from .logging import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
hf_cache_home = os.path.expanduser(
|
||||
os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
|
||||
@@ -36,6 +37,18 @@ DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
|
||||
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
|
||||
|
||||
|
||||
_transformers_available = importlib.util.find_spec("transformers") is not None
|
||||
try:
|
||||
_transformers_version = importlib_metadata.version("transformers")
|
||||
logger.debug(f"Successfully imported transformers version {_transformers_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_transformers_available = False
|
||||
|
||||
|
||||
def is_transformers_available():
|
||||
return _transformers_available
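The availability flag above is intended as an import guard, roughly like this (a hedged sketch; `AutoTokenizer` is just an arbitrary example of a transformers object, not something this module imports):

from diffusers.utils import is_transformers_available

if is_transformers_available():
    from transformers import AutoTokenizer  # safe: transformers is installed (example object)
else:
    print("transformers is not installed; transformer-dependent pipelines are unavailable")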
|
||||
|
||||
|
||||
class RepositoryNotFoundError(HTTPError):
|
||||
"""
|
||||
Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does
|
||||
@@ -49,3 +62,39 @@ class EntryNotFoundError(HTTPError):
|
||||
|
||||
class RevisionNotFoundError(HTTPError):
|
||||
"""Raised when trying to access a hf.co URL with a valid repository but an invalid revision."""
|
||||
|
||||
|
||||
TRANSFORMERS_IMPORT_ERROR = """
|
||||
{0} requires the transformers library but it was not found in your environment. You can install it with pip:
|
||||
`pip install transformers`
|
||||
"""
|
||||
|
||||
|
||||
BACKENDS_MAPPING = OrderedDict(
|
||||
[
|
||||
("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def requires_backends(obj, backends):
|
||||
if not isinstance(backends, (list, tuple)):
|
||||
backends = [backends]
|
||||
|
||||
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
|
||||
checks = (BACKENDS_MAPPING[backend] for backend in backends)
|
||||
failed = [msg.format(name) for available, msg in checks if not available()]
|
||||
if failed:
|
||||
raise ImportError("".join(failed))
|
||||
|
||||
|
||||
class DummyObject(type):
|
||||
"""
|
||||
Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by
|
||||
`requires_backends` each time a user tries to access any method of that class.
|
||||
"""
|
||||
|
||||
def __getattr__(cls, key):
|
||||
if key.startswith("_"):
|
||||
return super().__getattr__(cls, key)
|
||||
requires_backends(cls, cls._backends)
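Together, `requires_backends` and `DummyObject` make a missing optional dependency fail lazily with a readable message rather than at import time. A hedged sketch mirroring the autogenerated dummies in `utils/dummy_transformers_objects.py` below (the class name here is hypothetical):

class LatentDiffusionDummy(metaclass=DummyObject):
    # hypothetical example class; the real dummies are generated by `make fix-copies`
    _backends = ["transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["transformers"])

# With transformers missing, instantiating the class (or touching any public class
# attribute, via the metaclass __getattr__) raises ImportError carrying
# TRANSFORMERS_IMPORT_ERROR formatted with the class name.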
|
||||
|
||||
src/diffusers/utils/dummy_transformers_objects.py (new file, 48 lines)
@@ -0,0 +1,48 @@
|
||||
# This file is autogenerated by the command `make fix-copies`, do not edit.
|
||||
# flake8: noqa
|
||||
from ..utils import DummyObject, requires_backends
|
||||
|
||||
|
||||
class GLIDESuperResUNetModel(metaclass=DummyObject):
|
||||
_backends = ["transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["transformers"])
|
||||
|
||||
|
||||
class GLIDETextToImageUNetModel(metaclass=DummyObject):
|
||||
_backends = ["transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["transformers"])
|
||||
|
||||
|
||||
class GLIDEUNetModel(metaclass=DummyObject):
|
||||
_backends = ["transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["transformers"])
|
||||
|
||||
|
||||
class UNetGradTTSModel(metaclass=DummyObject):
|
||||
_backends = ["transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["transformers"])
|
||||
|
||||
|
||||
GLIDE = None
|
||||
|
||||
|
||||
class GradTTS(metaclass=DummyObject):
|
||||
_backends = ["transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["transformers"])
|
||||
|
||||
|
||||
class LatentDiffusion(metaclass=DummyObject):
|
||||
_backends = ["transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["transformers"])
|
||||
@@ -14,11 +14,14 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import inspect
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import pytest
|
||||
from diffusers import (
|
||||
BDDM,
|
||||
DDIM,
|
||||
@@ -27,9 +30,12 @@ from diffusers import (
|
||||
PNDM,
|
||||
DDIMScheduler,
|
||||
DDPMScheduler,
|
||||
GLIDESuperResUNetModel,
|
||||
LatentDiffusion,
|
||||
PNDMScheduler,
|
||||
UNetModel,
|
||||
UNetLDMModel,
|
||||
UNetGradTTSModel,
|
||||
)
|
||||
from diffusers.configuration_utils import ConfigMixin
|
||||
from diffusers.pipeline_utils import DiffusionPipeline
|
||||
@@ -82,7 +88,108 @@ class ConfigTester(unittest.TestCase):
|
||||
assert config == new_config
|
||||
|
||||
|
||||
class ModelTesterMixin(unittest.TestCase):
|
||||
class ModelTesterMixin:
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
new_model = self.model_class.from_pretrained(tmpdirname)
|
||||
new_model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
image = model(**inputs_dict)
|
||||
new_image = new_model(**inputs_dict)
|
||||
|
||||
max_diff = (image - new_image).abs().sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-5, "Models give different forward passes")
|
||||
|
||||
def test_determinism(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
first = model(**inputs_dict)
|
||||
second = model(**inputs_dict)
|
||||
|
||||
out_1 = first.cpu().numpy()
|
||||
out_2 = second.cpu().numpy()
|
||||
out_1 = out_1[~np.isnan(out_1)]
|
||||
out_2 = out_2[~np.isnan(out_2)]
|
||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
|
||||
def test_output(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["x"].shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
def test_forward_signature(self):
|
||||
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
signature = inspect.signature(model.forward)
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
expected_arg_names = ["x", "timesteps"]
|
||||
self.assertListEqual(arg_names[:2], expected_arg_names)
|
||||
|
||||
def test_model_from_config(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# test if the model can be loaded from the config
|
||||
# and has all the expected shape
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_config(tmpdirname)
|
||||
new_model = self.model_class.from_config(tmpdirname)
|
||||
new_model.to(torch_device)
|
||||
new_model.eval()
|
||||
|
||||
# check that all parameter shapes are the same
|
||||
for param_name in model.state_dict().keys():
|
||||
param_1 = model.state_dict()[param_name]
|
||||
param_2 = new_model.state_dict()[param_name]
|
||||
self.assertEqual(param_1.shape, param_2.shape)
|
||||
|
||||
with torch.no_grad():
|
||||
output_1 = model(**inputs_dict)
|
||||
output_2 = new_model(**inputs_dict)
|
||||
|
||||
self.assertEqual(output_1.shape, output_2.shape)
|
||||
|
||||
def test_training(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
output = model(**inputs_dict)
|
||||
noise = torch.randn((inputs_dict["x"].shape[0],) + self.get_output_shape).to(torch_device)
|
||||
loss = torch.nn.functional.mse_loss(output, noise)
|
||||
loss.backward()
|
||||
|
||||
|
||||
class UnetModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = UNetModel
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
@@ -92,32 +199,289 @@ class ModelTesterMixin(unittest.TestCase):
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor([10]).to(torch_device)
|
||||
|
||||
return (noise, time_step)
|
||||
return {"x": noise, "timesteps": time_step}
|
||||
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)
|
||||
model.to(torch_device)
|
||||
@property
|
||||
def get_input_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
new_model = UNetModel.from_pretrained(tmpdirname)
|
||||
new_model.to(torch_device)
|
||||
@property
|
||||
def get_output_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
dummy_input = self.dummy_input
|
||||
|
||||
image = model(*dummy_input)
|
||||
new_image = new_model(*dummy_input)
|
||||
|
||||
assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"ch": 32,
|
||||
"ch_mult": (1, 2),
|
||||
"num_res_blocks": 2,
|
||||
"attn_resolutions": (16,),
|
||||
"resolution": 32,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_from_pretrained_hub(self):
|
||||
model = UNetModel.from_pretrained("fusing/ddpm_dummy")
|
||||
model.to(torch_device)
|
||||
model, loading_info = UNetModel.from_pretrained("fusing/ddpm_dummy", output_loading_info=True)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertEqual(len(loading_info["missing_keys"]), 0)
|
||||
|
||||
image = model(*self.dummy_input)
|
||||
model.to(torch_device)
|
||||
image = model(**self.dummy_input)
|
||||
|
||||
assert image is not None, "Make sure output is not None"
|
||||
|
||||
def test_output_pretrained(self):
|
||||
model = UNetModel.from_pretrained("fusing/ddpm_dummy")
|
||||
model.eval()
|
||||
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
|
||||
noise = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution)
|
||||
time_step = torch.tensor([10])
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(noise, time_step)
|
||||
|
||||
output_slice = output[0, -1, -3:, -3:].flatten()
|
||||
# fmt: off
|
||||
expected_output_slice = torch.tensor([ 0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
|
||||
# fmt: on
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
|
||||
|
||||
class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = GLIDESuperResUNetModel
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_channels = 6
|
||||
sizes = (32, 32)
|
||||
low_res_size = (4, 4)
|
||||
|
||||
torch_device = "cpu"
|
||||
|
||||
noise = torch.randn((batch_size, num_channels // 2) + sizes).to(torch_device)
|
||||
low_res = torch.randn((batch_size, 3) + low_res_size).to(torch_device)
|
||||
time_step = torch.tensor([10] * noise.shape[0], device=torch_device)
|
||||
|
||||
return {"x": noise, "timesteps": time_step, "low_res": low_res}
|
||||
|
||||
@property
|
||||
def get_input_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def get_output_shape(self):
|
||||
return (6, 32, 32)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"attention_resolutions": (2,),
|
||||
"channel_mult": (1, 2),
|
||||
"in_channels": 6,
|
||||
"out_channels": 6,
|
||||
"model_channels": 32,
|
||||
"num_head_channels": 8,
|
||||
"num_heads_upsample": 1,
|
||||
"num_res_blocks": 2,
|
||||
"resblock_updown": True,
|
||||
"resolution": 32,
|
||||
"use_scale_shift_norm": True,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_output(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
|
||||
output, _ = torch.split(output, 3, dim=1)
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["x"].shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
def test_from_pretrained_hub(self):
|
||||
model, loading_info = GLIDESuperResUNetModel.from_pretrained(
|
||||
"fusing/glide-super-res-dummy", output_loading_info=True
|
||||
)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertEqual(len(loading_info["missing_keys"]), 0)
|
||||
|
||||
model.to(torch_device)
|
||||
image = model(**self.dummy_input)
|
||||
|
||||
assert image is not None, "Make sure output is not None"
|
||||
|
||||
def test_output_pretrained(self):
|
||||
model = GLIDESuperResUNetModel.from_pretrained("fusing/glide-super-res-dummy")
|
||||
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
|
||||
noise = torch.randn(1, 3, 64, 64)
|
||||
low_res = torch.randn(1, 3, 4, 4)
|
||||
time_step = torch.tensor([42] * noise.shape[0])
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(noise, time_step, low_res)
|
||||
|
||||
output, _ = torch.split(output, 3, dim=1)
|
||||
output_slice = output[0, -1, -3:, -3:].flatten()
|
||||
# fmt: off
|
||||
expected_output_slice = torch.tensor([-22.8782, -23.2652, -15.3966, -22.8034, -23.3159, -15.5640, -15.3970, -15.4614, -10.4370])
|
||||
# fmt: on
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
|
||||
class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = UNetLDMModel
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_channels = 4
|
||||
sizes = (32, 32)
|
||||
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor([10]).to(torch_device)
|
||||
|
||||
return {"x": noise, "timesteps": time_step}
|
||||
|
||||
@property
|
||||
def get_input_shape(self):
|
||||
return (4, 32, 32)
|
||||
|
||||
@property
|
||||
def get_output_shape(self):
|
||||
return (4, 32, 32)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"image_size": 32,
|
||||
"in_channels": 4,
|
||||
"out_channels": 4,
|
||||
"model_channels": 32,
|
||||
"num_res_blocks": 2,
|
||||
"attention_resolutions": (16,),
|
||||
"channel_mult": (1, 2),
|
||||
"num_heads": 2,
|
||||
"conv_resample": True,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_from_pretrained_hub(self):
|
||||
model, loading_info = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy", output_loading_info=True)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertEqual(len(loading_info["missing_keys"]), 0)
|
||||
|
||||
model.to(torch_device)
|
||||
image = model(**self.dummy_input)
|
||||
|
||||
assert image is not None, "Make sure output is not None"
|
||||
|
||||
def test_output_pretrained(self):
|
||||
model = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy")
|
||||
model.eval()
|
||||
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
|
||||
noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
|
||||
time_step = torch.tensor([10] * noise.shape[0])
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(noise, time_step)
|
||||
|
||||
output_slice = output[0, -1, -3:, -3:].flatten()
|
||||
# fmt: off
|
||||
expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800])
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
|
||||
|
||||
class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = UNetGradTTSModel
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_features = 32
|
||||
seq_len = 16
|
||||
|
||||
noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
|
||||
condition = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
|
||||
mask = floats_tensor((batch_size, 1, seq_len)).to(torch_device)
|
||||
time_step = torch.tensor([10] * batch_size).to(torch_device)
|
||||
|
||||
return {"x": noise, "timesteps": time_step, "mu": condition, "mask": mask}
|
||||
|
||||
@property
|
||||
def get_input_shape(self):
|
||||
return (4, 32, 16)
|
||||
|
||||
@property
|
||||
def get_output_shape(self):
|
||||
return (4, 32, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"dim": 64,
|
||||
"groups": 4,
|
||||
"dim_mults": (1, 2),
|
||||
"n_feats": 32,
|
||||
"pe_scale": 1000,
|
||||
"n_spks": 1,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_from_pretrained_hub(self):
|
||||
model, loading_info = UNetGradTTSModel.from_pretrained("fusing/unet-grad-tts-dummy", output_loading_info=True)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertEqual(len(loading_info["missing_keys"]), 0)
|
||||
|
||||
model.to(torch_device)
|
||||
image = model(**self.dummy_input)
|
||||
|
||||
assert image is not None, "Make sure output is not None"
|
||||
|
||||
def test_output_pretrained(self):
|
||||
model = UNetGradTTSModel.from_pretrained("fusing/unet-grad-tts-dummy")
|
||||
model.eval()
|
||||
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
|
||||
num_features = model.config.n_feats
|
||||
seq_len = 16
|
||||
noise = torch.randn((1, num_features, seq_len))
|
||||
condition = torch.randn((1, num_features, seq_len))
|
||||
mask = torch.randn((1, 1, seq_len))
|
||||
time_step = torch.tensor([10])
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(noise, time_step, condition, mask)
|
||||
|
||||
output_slice = output[0, -3:, -3:].flatten()
|
||||
# fmt: off
|
||||
expected_output_slice = torch.tensor([-0.0690, -0.0531, 0.0633, -0.0660, -0.0541, 0.0650, -0.0656, -0.0555, 0.0617])
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
|
||||
|
||||
class PipelineTesterMixin(unittest.TestCase):
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
@@ -223,7 +587,6 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||
image = ldm([prompt], generator=generator, num_inference_steps=20)
|
||||
|
||||
image_slice = image[0, -1, -3:, -3:].cpu()
|
||||
print(image_slice.shape)
|
||||
|
||||
assert image.shape == (1, 3, 256, 256)
|
||||
expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458])
|
||||
|
||||
@@ -20,7 +20,7 @@ import unittest
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import DDIMScheduler, DDPMScheduler
|
||||
from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler
|
||||
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
@@ -31,37 +31,37 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
forward_default_kwargs = ()
|
||||
|
||||
@property
|
||||
def dummy_image(self):
|
||||
def dummy_sample(self):
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
height = 8
|
||||
width = 8
|
||||
|
||||
image = np.random.rand(batch_size, num_channels, height, width)
|
||||
sample = np.random.rand(batch_size, num_channels, height, width)
|
||||
|
||||
return image
|
||||
return sample
|
||||
|
||||
@property
|
||||
def dummy_image_deter(self):
|
||||
def dummy_sample_deter(self):
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
height = 8
|
||||
width = 8
|
||||
|
||||
num_elems = batch_size * num_channels * height * width
|
||||
image = np.arange(num_elems)
|
||||
image = image.reshape(num_channels, height, width, batch_size)
|
||||
image = image / num_elems
|
||||
image = image.transpose(3, 0, 1, 2)
|
||||
sample = np.arange(num_elems)
|
||||
sample = sample.reshape(num_channels, height, width, batch_size)
|
||||
sample = sample / num_elems
|
||||
sample = sample.transpose(3, 0, 1, 2)
|
||||
|
||||
return image
|
||||
return sample
|
||||
|
||||
def get_scheduler_config(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def dummy_model(self):
|
||||
def model(image, t, *args):
|
||||
return image * t / (t + 1)
|
||||
def model(sample, t, *args):
|
||||
return sample * t / (t + 1)
|
||||
|
||||
return model
|
||||
|
||||
@@ -70,8 +70,8 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
image = self.dummy_image
|
||||
residual = 0.1 * image
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
|
||||
scheduler_config = self.get_scheduler_config(**config)
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
@@ -80,8 +80,8 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_config(tmpdirname)
|
||||
|
||||
output = scheduler.step(residual, image, time_step, **kwargs)
|
||||
new_output = new_scheduler.step(residual, image, time_step, **kwargs)
|
||||
output = scheduler.step(residual, sample, time_step, **kwargs)
|
||||
new_output = new_scheduler.step(residual, sample, time_step, **kwargs)
|
||||
|
||||
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
@@ -90,10 +90,10 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
kwargs.update(forward_kwargs)
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
image = self.dummy_image
|
||||
residual = 0.1 * image
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
@@ -101,8 +101,8 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_config(tmpdirname)
|
||||
|
||||
output = scheduler.step(residual, image, time_step, **kwargs)
|
||||
new_output = new_scheduler.step(residual, image, time_step, **kwargs)
|
||||
output = scheduler.step(residual, sample, time_step, **kwargs)
|
||||
new_output = new_scheduler.step(residual, sample, time_step, **kwargs)
|
||||
|
||||
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
@@ -110,8 +110,8 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
image = self.dummy_image
|
||||
residual = 0.1 * image
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
@@ -120,8 +120,8 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_config(tmpdirname)
|
||||
|
||||
output = scheduler.step(residual, image, 1, **kwargs)
|
||||
new_output = new_scheduler.step(residual, image, 1, **kwargs)
|
||||
output = scheduler.step(residual, sample, 1, **kwargs)
|
||||
new_output = new_scheduler.step(residual, sample, 1, **kwargs)
|
||||
|
||||
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
@@ -132,34 +132,34 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
image = self.dummy_image
|
||||
residual = 0.1 * image
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
|
||||
output_0 = scheduler.step(residual, image, 0, **kwargs)
|
||||
output_1 = scheduler.step(residual, image, 1, **kwargs)
|
||||
output_0 = scheduler.step(residual, sample, 0, **kwargs)
|
||||
output_1 = scheduler.step(residual, sample, 1, **kwargs)
|
||||
|
||||
self.assertEqual(output_0.shape, image.shape)
|
||||
self.assertEqual(output_0.shape, sample.shape)
|
||||
self.assertEqual(output_0.shape, output_1.shape)
|
||||
|
||||
def test_pytorch_equal_numpy(self):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
image = self.dummy_image
|
||||
residual = 0.1 * image
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
|
||||
image_pt = torch.tensor(image)
|
||||
residual_pt = 0.1 * image_pt
|
||||
sample_pt = torch.tensor(sample)
|
||||
residual_pt = 0.1 * sample_pt
|
||||
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
|
||||
|
||||
output = scheduler.step(residual, image, 1, **kwargs)
|
||||
output_pt = scheduler_pt.step(residual_pt, image_pt, 1, **kwargs)
|
||||
output = scheduler.step(residual, sample, 1, **kwargs)
|
||||
output_pt = scheduler_pt.step(residual_pt, sample_pt, 1, **kwargs)
|
||||
|
||||
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-5, "Scheduler outputs are not identical"
|
||||
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
|
||||
|
||||
|
||||
class DDPMSchedulerTest(SchedulerCommonTest):
|
||||
@@ -194,7 +194,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
|
||||
for variance in ["fixed_small", "fixed_large", "other"]:
|
||||
self.check_over_configs(variance_type=variance)
|
||||
|
||||
def test_clip_image(self):
|
||||
def test_clip_sample(self):
|
||||
for clip_sample in [True, False]:
|
||||
self.check_over_configs(clip_sample=clip_sample)
|
||||
|
||||
@@ -219,26 +219,26 @@ class DDPMSchedulerTest(SchedulerCommonTest):
|
||||
num_trained_timesteps = len(scheduler)
|
||||
|
||||
model = self.dummy_model()
|
||||
image = self.dummy_image_deter
|
||||
sample = self.dummy_sample_deter
|
||||
|
||||
for t in reversed(range(num_trained_timesteps)):
|
||||
# 1. predict noise residual
|
||||
residual = model(image, t)
|
||||
residual = model(sample, t)
|
||||
|
||||
# 2. predict previous mean of image x_t-1
|
||||
pred_prev_image = scheduler.step(residual, image, t)
|
||||
# 2. predict previous mean of sample x_t-1
|
||||
pred_prev_sample = scheduler.step(residual, sample, t)
|
||||
|
||||
if t > 0:
|
||||
noise = self.dummy_image_deter
|
||||
noise = self.dummy_sample_deter
|
||||
variance = scheduler.get_variance(t) ** (0.5) * noise
|
||||
|
||||
image = pred_prev_image + variance
|
||||
sample = pred_prev_sample + variance
|
||||
|
||||
result_sum = np.sum(np.abs(image))
|
||||
result_mean = np.mean(np.abs(image))
|
||||
result_sum = np.sum(np.abs(sample))
|
||||
result_mean = np.mean(np.abs(sample))
|
||||
|
||||
assert result_sum.item() - 732.9947 < 1e-3
|
||||
assert result_mean.item() - 0.9544 < 1e-3
|
||||
assert abs(result_sum.item() - 732.9947) < 1e-2
|
||||
assert abs(result_mean.item() - 0.9544) < 1e-3
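In equation form, each iteration of the loop above performs the DDPM ancestral update

$x_{t-1} = \mu_\theta(x_t, t) + \sigma_t z$

where `scheduler.step` returns the predicted previous mean $\mu_\theta(x_t, t)$, $\sigma_t^2$ is `scheduler.get_variance(t)`, and $z$ is Gaussian noise in real sampling (here replaced by the deterministic `dummy_sample_deter` so the expected sums stay reproducible).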
|
||||
|
||||
|
||||
class DDIMSchedulerTest(SchedulerCommonTest):
|
||||
@@ -269,7 +269,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
|
||||
for schedule in ["linear", "squaredcos_cap_v2"]:
|
||||
self.check_over_configs(beta_schedule=schedule)
|
||||
|
||||
def test_clip_image(self):
|
||||
def test_clip_sample(self):
|
||||
for clip_sample in [True, False]:
|
||||
self.check_over_configs(clip_sample=clip_sample)
|
||||
|
||||
@@ -308,22 +308,170 @@ class DDIMSchedulerTest(SchedulerCommonTest):
|
||||
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
|
||||
|
||||
model = self.dummy_model()
|
||||
image = self.dummy_image_deter
|
||||
sample = self.dummy_sample_deter
|
||||
|
||||
for t in reversed(range(num_inference_steps)):
|
||||
residual = model(image, inference_step_times[t])
|
||||
residual = model(sample, inference_step_times[t])
|
||||
|
||||
pred_prev_image = scheduler.step(residual, image, t, num_inference_steps, eta)
|
||||
pred_prev_sample = scheduler.step(residual, sample, t, num_inference_steps, eta)
|
||||
|
||||
variance = 0
|
||||
if eta > 0:
|
||||
noise = self.dummy_image_deter
|
||||
noise = self.dummy_sample_deter
|
||||
variance = scheduler.get_variance(t, num_inference_steps) ** (0.5) * eta * noise
|
||||
|
||||
image = pred_prev_image + variance
|
||||
sample = pred_prev_sample + variance
|
||||
|
||||
result_sum = np.sum(np.abs(image))
|
||||
result_mean = np.mean(np.abs(image))
|
||||
result_sum = np.sum(np.abs(sample))
|
||||
result_mean = np.mean(np.abs(sample))
|
||||
|
||||
assert result_sum.item() - 270.6214 < 1e-3
|
||||
assert result_mean.item() - 0.3524 < 1e-3
|
||||
assert abs(result_sum.item() - 270.6214) < 1e-2
|
||||
assert abs(result_mean.item() - 0.3524) < 1e-3
|
||||
|
||||
|
||||
class PNDMSchedulerTest(SchedulerCommonTest):
|
||||
scheduler_classes = (PNDMScheduler,)
|
||||
forward_default_kwargs = (("num_inference_steps", 50),)
|
||||
|
||||
def get_scheduler_config(self, **kwargs):
|
||||
config = {
|
||||
"timesteps": 1000,
|
||||
"beta_start": 0.0001,
|
||||
"beta_end": 0.02,
|
||||
"beta_schedule": "linear",
|
||||
}
|
||||
|
||||
config.update(**kwargs)
|
||||
return config
|
||||
|
||||
def check_over_configs_pmls(self, time_step=0, **config):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
scheduler_config = self.get_scheduler_config(**config)
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
# copy over dummy past residuals
|
||||
scheduler.ets = dummy_past_residuals[:]
|
||||
scheduler.set_plms_mode()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_config(tmpdirname)
|
||||
# copy over dummy past residuals
|
||||
new_scheduler.ets = dummy_past_residuals[:]
|
||||
new_scheduler.set_plms_mode()
|
||||
|
||||
output = scheduler.step(residual, sample, time_step, **kwargs)
|
||||
new_output = new_scheduler.step(residual, sample, time_step, **kwargs)
|
||||
|
||||
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
def check_over_forward_pmls(self, time_step=0, **forward_kwargs):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
kwargs.update(forward_kwargs)
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
# copy over dummy past residuals
|
||||
scheduler.ets = dummy_past_residuals[:]
|
||||
scheduler.set_plms_mode()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_config(tmpdirname)
|
||||
# copy over dummy past residuals
|
||||
new_scheduler.ets = dummy_past_residuals[:]
|
||||
new_scheduler.set_plms_mode()
|
||||
|
||||
output = scheduler.step(residual, sample, time_step, **kwargs)
|
||||
new_output = new_scheduler.step(residual, sample, time_step, **kwargs)
|
||||
|
||||
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
def test_timesteps(self):
|
||||
for timesteps in [100, 1000]:
|
||||
self.check_over_configs(timesteps=timesteps)
|
||||
|
||||
def test_timesteps_pmls(self):
|
||||
for timesteps in [100, 1000]:
|
||||
self.check_over_configs_pmls(timesteps=timesteps)
|
||||
|
||||
def test_betas(self):
|
||||
for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
|
||||
self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
|
||||
|
||||
def test_betas_pmls(self):
|
||||
for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
|
||||
self.check_over_configs_pmls(beta_start=beta_start, beta_end=beta_end)
|
||||
|
||||
def test_schedules(self):
|
||||
for schedule in ["linear", "squaredcos_cap_v2"]:
|
||||
self.check_over_configs(beta_schedule=schedule)
|
||||
|
||||
def test_schedules_pmls(self):
|
||||
for schedule in ["linear", "squaredcos_cap_v2"]:
|
||||
self.check_over_configs_pmls(beta_schedule=schedule)
|
||||
|
||||
def test_time_indices(self):
|
||||
for t in [1, 5, 10]:
|
||||
self.check_over_forward(time_step=t)
|
||||
|
||||
def test_time_indices_pmls(self):
|
||||
for t in [1, 5, 10]:
|
||||
self.check_over_forward_pmls(time_step=t)
|
||||
|
||||
def test_inference_steps(self):
|
||||
for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
|
||||
self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
|
||||
|
||||
def test_inference_steps_pmls(self):
|
||||
for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
|
||||
self.check_over_forward_pmls(time_step=t, num_inference_steps=num_inference_steps)
|
||||
|
||||
def test_inference_pmls_no_past_residuals(self):
|
||||
with self.assertRaises(ValueError):
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
scheduler.set_plms_mode()
|
||||
|
||||
scheduler.step(self.dummy_sample, self.dummy_sample, 1, 50)
|
||||
|
||||
def test_full_loop_no_noise(self):
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
num_inference_steps = 10
|
||||
model = self.dummy_model()
|
||||
sample = self.dummy_sample_deter
|
||||
|
||||
prk_time_steps = scheduler.get_prk_time_steps(num_inference_steps)
|
||||
for t in range(len(prk_time_steps)):
|
||||
t_orig = prk_time_steps[t]
|
||||
residual = model(sample, t_orig)
|
||||
|
||||
sample = scheduler.step_prk(residual, sample, t, num_inference_steps)
|
||||
|
||||
timesteps = scheduler.get_time_steps(num_inference_steps)
|
||||
for t in range(len(timesteps)):
|
||||
t_orig = timesteps[t]
|
||||
residual = model(sample, t_orig)
|
||||
|
||||
sample = scheduler.step_plms(residual, sample, t, num_inference_steps)
|
||||
|
||||
result_sum = np.sum(np.abs(sample))
|
||||
result_mean = np.mean(np.abs(sample))
|
||||
|
||||
assert abs(result_sum.item() - 199.1169) < 1e-2
|
||||
assert abs(result_mean.item() - 0.2593) < 1e-3
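End to end, the loop in `test_full_loop_no_noise` is also how PNDM sampling is expected to look outside the tests. A hedged sketch with a real noise-prediction model (`unet` is a placeholder for any callable taking `(sample, t)`):

import numpy as np

scheduler = PNDMScheduler(timesteps=1000, beta_start=0.0001, beta_end=0.02, beta_schedule="linear")
num_inference_steps = 50
sample = np.random.randn(1, 3, 32, 32)  # start from pure noise

# 1. Runge-Kutta warm-up: fills scheduler.ets with the first residuals
for t, t_orig in enumerate(scheduler.get_prk_time_steps(num_inference_steps)):
    residual = unet(sample, t_orig)  # hypothetical model call, see lead-in
    sample = scheduler.step_prk(residual, sample, t, num_inference_steps)

# 2. plain linear multi-step (PLMS) steps for the remaining timesteps
for t, t_orig in enumerate(scheduler.get_time_steps(num_inference_steps)):
    residual = unet(sample, t_orig)
    sample = scheduler.step_plms(residual, sample, t, num_inference_steps)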
|
||||
|
||||
@@ -20,10 +20,10 @@ import re
|
||||
|
||||
# All paths are set with the intent you should run this script from the root of the repo with the command
|
||||
# python utils/check_dummies.py
|
||||
PATH_TO_TRANSFORMERS = "src/transformers"
|
||||
PATH_TO_DIFFUSERS = "src/diffusers"
|
||||
|
||||
# Matches is_xxx_available()
|
||||
_re_backend = re.compile(r"is\_([a-z_]*)_available()")
|
||||
_re_backend = re.compile(r"if is\_([a-z_]*)_available\(\)")
|
||||
# Matches from xxx import bla
|
||||
_re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
|
||||
_re_test_backend = re.compile(r"^\s+if\s+not\s+is\_[a-z]*\_available\(\)")
|
||||
@@ -50,36 +50,30 @@ def {0}(*args, **kwargs):
|
||||
|
||||
def find_backend(line):
|
||||
"""Find one (or multiple) backend in a code line of the init."""
|
||||
if _re_test_backend.search(line) is None:
|
||||
backends = _re_backend.findall(line)
|
||||
if len(backends) == 0:
|
||||
return None
|
||||
backends = [b[0] for b in _re_backend.findall(line)]
|
||||
backends.sort()
|
||||
return "_and_".join(backends)
|
||||
|
||||
return backends[0]
|
||||
|
||||
|
||||
def read_init():
|
||||
"""Read the init and extracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects."""
|
||||
with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f:
|
||||
with open(os.path.join(PATH_TO_DIFFUSERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
# Get to the point we do the actual imports for type checking
|
||||
line_index = 0
|
||||
while not lines[line_index].startswith("if TYPE_CHECKING"):
|
||||
line_index += 1
|
||||
|
||||
backend_specific_objects = {}
|
||||
# Go through the end of the file
|
||||
while line_index < len(lines):
|
||||
# If the line is an if is_backend_available, we grab all objects associated.
|
||||
backend = find_backend(lines[line_index])
|
||||
if backend is not None:
|
||||
while not lines[line_index].startswith(" else:"):
|
||||
line_index += 1
|
||||
line_index += 1
|
||||
|
||||
objects = []
|
||||
line_index += 1
|
||||
# Until we unindent, add backend objects to the list
|
||||
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8):
|
||||
while not lines[line_index].startswith("else:"):
|
||||
line = lines[line_index]
|
||||
single_line_import_search = _re_single_line_import.search(line)
|
||||
if single_line_import_search is not None:
|
||||
@@ -129,7 +123,7 @@ def check_dummies(overwrite=False):
|
||||
short_names = {"torch": "pt"}
|
||||
|
||||
# Locate actual dummy modules and read their content.
|
||||
path = os.path.join(PATH_TO_TRANSFORMERS, "utils")
|
||||
path = os.path.join(PATH_TO_DIFFUSERS, "utils")
|
||||
dummy_file_paths = {
|
||||
backend: os.path.join(path, f"dummy_{short_names.get(backend, backend)}_objects.py")
|
||||
for backend in dummy_files.keys()
|
||||
@@ -147,7 +141,7 @@ def check_dummies(overwrite=False):
|
||||
if dummy_files[backend] != actual_dummies[backend]:
|
||||
if overwrite:
|
||||
print(
|
||||
f"Updating transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py as the main "
|
||||
f"Updating diffusers.utils.dummy_{short_names.get(backend, backend)}_objects.py as the main "
|
||||
"__init__ has new objects."
|
||||
)
|
||||
with open(dummy_file_paths[backend], "w", encoding="utf-8", newline="\n") as f:
|
||||
@@ -155,7 +149,7 @@ def check_dummies(overwrite=False):
|
||||
else:
|
||||
raise ValueError(
|
||||
"The main __init__ has objects that are not present in "
|
||||
f"transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py. Run `make fix-copies` "
|
||||
f"diffusers.utils.dummy_{short_names.get(backend, backend)}_objects.py. Run `make fix-copies` "
|
||||
"to fix this."
|
||||
)
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ import re
|
||||
|
||||
# All paths are set with the intent you should run this script from the root of the repo with the command
|
||||
# python utils/check_table.py
|
||||
TRANSFORMERS_PATH = "src/transformers"
|
||||
TRANSFORMERS_PATH = "src/diffusers"
|
||||
PATH_TO_DOCS = "docs/source/en"
|
||||
REPO_PATH = "."
|
||||
|
||||
@@ -62,13 +62,13 @@ _re_flax_models = re.compile(r"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGe
|
||||
_re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
|
||||
|
||||
|
||||
# This is to make sure the transformers module imported is the one in the repo.
|
||||
# This is to make sure the diffusers module imported is the one in the repo.
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"transformers",
|
||||
"diffusers",
|
||||
os.path.join(TRANSFORMERS_PATH, "__init__.py"),
|
||||
submodule_search_locations=[TRANSFORMERS_PATH],
|
||||
)
|
||||
transformers_module = spec.loader.load_module()
|
||||
diffusers_module = spec.loader.load_module()
|
||||
|
||||
|
||||
# Thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python
|
||||
@@ -88,10 +88,10 @@ def _center_text(text, width):
|
||||
def get_model_table_from_auto_modules():
|
||||
"""Generates an up-to-date model table from the content of the auto modules."""
|
||||
# Dictionary model names to config.
|
||||
config_maping_names = transformers_module.models.auto.configuration_auto.CONFIG_MAPPING_NAMES
|
||||
config_maping_names = diffusers_module.models.auto.configuration_auto.CONFIG_MAPPING_NAMES
|
||||
model_name_to_config = {
|
||||
name: config_maping_names[code]
|
||||
for code, name in transformers_module.MODEL_NAMES_MAPPING.items()
|
||||
for code, name in diffusers_module.MODEL_NAMES_MAPPING.items()
|
||||
if code in config_maping_names
|
||||
}
|
||||
model_name_to_prefix = {name: config.replace("ConfigMixin", "") for name, config in model_name_to_config.items()}
|
||||
@@ -103,8 +103,8 @@ def get_model_table_from_auto_modules():
|
||||
tf_models = collections.defaultdict(bool)
|
||||
flax_models = collections.defaultdict(bool)
|
||||
|
||||
# Let's lookup through all transformers object (once).
|
||||
for attr_name in dir(transformers_module):
|
||||
# Let's lookup through all diffusers object (once).
|
||||
for attr_name in dir(diffusers_module):
|
||||
lookup_dict = None
|
||||
if attr_name.endswith("Tokenizer"):
|
||||
lookup_dict = slow_tokenizers