mirror of https://github.com/huggingface/diffusers.git

finish refactor

Makefile (2 changes)
@@ -3,7 +3,7 @@
 # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
 export PYTHONPATH = src

-check_dirs := models tests src utils
+check_dirs := tests src utils

 modified_only_fixup:
 	$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
@@ -2,15 +2,16 @@
 # There's no way to ignore "F401 '...' imported but unused" warnings in this
 # module, but to preserve other warnings. So, don't check this module at all.

-__version__ = "0.0.1"
+__version__ = "0.0.3"

 from .modeling_utils import ModelMixin
 from .models.unet import UNetModel
 from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from .models.unet_ldm import UNetLDMModel
 from .pipeline_utils import DiffusionPipeline
-from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
-from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
-from .schedulers.ddim import DDIMScheduler
-from .schedulers.glide_ddim import GlideDDIMScheduler
+from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion
+from .schedulers import SchedulerMixin
+from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
+from .schedulers.ddim import DDIMScheduler
+from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
+from .schedulers.glide_ddim import GlideDDIMScheduler
@@ -213,7 +213,7 @@ class ConfigMixin:

         passed_keys = set(init_dict.keys())
         if len(expected_keys - passed_keys) > 0:
-            logger.warn(
+            logger.warning(
                 f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
             )
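A note on the logic around this hunk: ConfigMixin diffs the keys the __init__ signature expects against the keys the loaded config actually passes, and warns that the remainder falls back to defaults. A tiny illustration of that set arithmetic, with made-up keys:

expected_keys = {"timesteps", "beta_start", "beta_end", "beta_schedule"}
init_dict = {"timesteps": 1000, "beta_start": 0.0001}

passed_keys = set(init_dict.keys())
missing = expected_keys - passed_keys
# missing == {"beta_end", "beta_schedule"} -> these are initialized to their
# default values, which is exactly what the warning above reports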
@@ -490,7 +490,7 @@ class ModelMixin(torch.nn.Module):
            raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")

        if len(unexpected_keys) > 0:
-           logger.warn(
+           logger.warning(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
                f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
                f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"

@@ -502,7 +502,7 @@ class ModelMixin(torch.nn.Module):
        else:
            logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
        if len(missing_keys) > 0:
-           logger.warn(
+           logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
                " TRAIN this model on a down-stream task to be able to use it for predictions and inference."

@@ -521,7 +521,7 @@ class ModelMixin(torch.nn.Module):
                    for key, shape1, shape2 in mismatched_keys
                ]
            )
-           logger.warn(
+           logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
                f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
@@ -1,4 +1,4 @@
 from .pipeline_ddim import DDIM
 from .pipeline_ddpm import DDPM
-from .pipeline_latent_diffusion import LatentDiffusion
 from .pipeline_glide import GLIDE
+from .pipeline_latent_diffusion import LatentDiffusion
@@ -123,7 +123,7 @@ class LDMBertConfig(PretrainedConfig):
         scale_embedding=False,
         use_cache=True,
         pad_token_id=0,
-        **kwargs
+        **kwargs,
     ):
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
@@ -2,10 +2,10 @@
 import math

 import numpy as np
-import tqdm
 import torch
 import torch.nn as nn

+import tqdm
 from diffusers import DiffusionPipeline
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.modeling_utils import ModelMixin

@@ -740,29 +740,30 @@ class DiagonalGaussianDistribution(object):

     def kl(self, other=None):
         if self.deterministic:
-            return torch.Tensor([0.])
+            return torch.Tensor([0.0])
         else:
             if other is None:
-                return 0.5 * torch.sum(torch.pow(self.mean, 2)
-                                       + self.var - 1.0 - self.logvar,
-                                       dim=[1, 2, 3])
+                return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
             else:
                 return 0.5 * torch.sum(
                     torch.pow(self.mean - other.mean, 2) / other.var
-                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
-                    dim=[1, 2, 3])
+                    + self.var / other.var
+                    - 1.0
+                    - self.logvar
+                    + other.logvar,
+                    dim=[1, 2, 3],
+                )

-    def nll(self, sample, dims=[1,2,3]):
+    def nll(self, sample, dims=[1, 2, 3]):
         if self.deterministic:
-            return torch.Tensor([0.])
+            return torch.Tensor([0.0])
         logtwopi = np.log(2.0 * np.pi)
-        return 0.5 * torch.sum(
-            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
-            dim=dims)
+        return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)

     def mode(self):
         return self.mean


 class AutoencoderKL(ModelMixin, ConfigMixin):
     def __init__(
         self,

@@ -834,7 +835,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
             give_pre_end=give_pre_end,
         )

-        self.quant_conv = torch.nn.Conv2d(2*z_channels, 2*embed_dim, 1)
+        self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1)
         self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)

     def encode(self, x):

@@ -855,4 +856,4 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
         else:
             z = posterior.mode()
         dec = self.decode(z)
-        return dec, posterior
+        return dec, posterior
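The kl method reformatted above is the closed-form KL divergence between diagonal Gaussians. A standalone NumPy sketch of the same two formulas (function and argument names are illustrative, not part of the commit):

import numpy as np

def diagonal_gaussian_kl(mean, logvar, other_mean=None, other_logvar=None):
    # summed over all non-batch dimensions, as dim=[1, 2, 3] does above
    var = np.exp(logvar)
    if other_mean is None:
        # KL(N(mean, var) || N(0, I))
        return 0.5 * np.sum(mean**2 + var - 1.0 - logvar, axis=(1, 2, 3))
    other_var = np.exp(other_logvar)
    return 0.5 * np.sum(
        (mean - other_mean) ** 2 / other_var + var / other_var - 1.0 - logvar + other_logvar,
        axis=(1, 2, 3),
    )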
@@ -123,7 +123,7 @@ class LDMBertConfig(PretrainedConfig):
         scale_embedding=False,
         use_cache=True,
         pad_token_id=0,
-        **kwargs
+        **kwargs,
     ):
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
@@ -1,12 +1,14 @@
-import tqdm
 import torch

+import tqdm
 from diffusers import DiffusionPipeline

-from .configuration_ldmbert import LDMBertConfig  # NOQA
-from .modeling_ldmbert import LDMBertModel  # NOQA

 # add these relative imports here, so we can load from hub
-from .modeling_vae import AutoencoderKL  # NOQA
+from .configuration_ldmbert import LDMBertConfig  # NOQA
+from .modeling_ldmbert import LDMBertModel  # NOQA
+from .modeling_vae import AutoencoderKL  # NOQA


 class LatentDiffusion(DiffusionPipeline):
     def __init__(self, vqvae, bert, tokenizer, unet, noise_scheduler):

@@ -14,7 +16,16 @@ class LatentDiffusion(DiffusionPipeline):
         self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, noise_scheduler=noise_scheduler)

     @torch.no_grad()
-    def __call__(self, prompt, batch_size=1, generator=None, torch_device=None, eta=0.0, guidance_scale=1.0, num_inference_steps=50):
+    def __call__(
+        self,
+        prompt,
+        batch_size=1,
+        generator=None,
+        torch_device=None,
+        eta=0.0,
+        guidance_scale=1.0,
+        num_inference_steps=50,
+    ):
         # eta corresponds to η in paper and should be between [0, 1]

         if torch_device is None:

@@ -23,16 +34,18 @@ class LatentDiffusion(DiffusionPipeline):
         self.unet.to(torch_device)
         self.vqvae.to(torch_device)
         self.bert.to(torch_device)

         # get unconditional embeddings for classifier free guidance
         if guidance_scale != 1.0:
-            uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors='pt').to(torch_device)
+            uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(
+                torch_device
+            )
             uncond_embeddings = self.bert(uncond_input.input_ids)[0]

         # get text embedding
-        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors='pt').to(torch_device)
+        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
         text_embedding = self.bert(text_input.input_ids)[0]

         num_trained_timesteps = self.noise_scheduler.timesteps
         inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)

@@ -41,7 +54,7 @@ class LatentDiffusion(DiffusionPipeline):
             device=torch_device,
             generator=generator,
         )

         # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
         # Ideally, read DDIM paper for in-detail understanding

@@ -60,7 +73,7 @@ class LatentDiffusion(DiffusionPipeline):
             timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
         else:
             # for classifier free guidance, we need to do two forward passes
-            # here we concanate embedding and unconditioned embedding in a single batch
+            # here we concatenate embedding and unconditioned embedding in a single batch
             # to avoid doing two forward passes
             image_in = torch.cat([image] * 2)
             context = torch.cat([uncond_embeddings, text_embedding])

@@ -68,12 +81,12 @@ class LatentDiffusion(DiffusionPipeline):

             # 1. predict noise residual
             pred_noise_t = self.unet(image_in, timesteps, context=context)

             # perform guidance
             if guidance_scale != 1.0:
                 pred_noise_t_uncond, pred_noise_t = pred_noise_t.chunk(2)
                 pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond)

             # 2. predict previous mean of image x_t-1
             pred_prev_image = self.noise_scheduler.step(pred_noise_t, image, t, num_inference_steps, eta)

@@ -87,8 +100,8 @@ class LatentDiffusion(DiffusionPipeline):
             image = pred_prev_image + variance

         # scale and decode image with vae
         image = 1 / 0.18215 * image
         image = self.vqvae.decode(image)
-        image = torch.clamp((image+1.0)/2.0, min=0.0, max=1.0)
+        image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)

         return image
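The guidance arithmetic in the loop above is compact enough to restate in isolation. A minimal sketch with made-up tensors (names mirror the pipeline, but nothing here is part of the commit):

import torch

guidance_scale = 5.0

# stand-ins for the two halves of the batched forward pass
pred_noise_t_uncond = torch.randn(1, 3, 32, 32)
pred_noise_t_text = torch.randn(1, 3, 32, 32)

# move away from the unconditional prediction, toward the text-conditional one;
# guidance_scale == 1.0 degenerates to plain conditional sampling
pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t_text - pred_noise_t_uncond)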
@@ -43,6 +43,7 @@ from transformers.utils import (
     logging,
     replace_return_docstrings,
 )
+
 from .configuration_ldmbert import LDMBertConfig


@@ -662,7 +663,7 @@ class LDMBertModel(LDMBertPreTrainedModel):
         super().__init__(config)
         self.model = LDMBertEncoder(config)
         self.to_logits = nn.Linear(config.hidden_size, config.vocab_size)

     def forward(
         self,
         input_ids=None,

@@ -674,7 +675,7 @@ class LDMBertModel(LDMBertPreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
     ):
-
         outputs = self.model(
             input_ids,

@@ -689,15 +690,15 @@ class LDMBertModel(LDMBertPreTrainedModel):
         sequence_output = outputs[0]
         # logits = self.to_logits(sequence_output)
         # outputs = (logits,) + outputs[1:]

         # if labels is not None:
         #     loss_fct = CrossEntropyLoss()
         #     loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
         #     outputs = (loss,) + outputs

         # if not return_dict:
         #     return outputs

         return BaseModelOutput(
             last_hidden_state=sequence_output,
             # hidden_states=outputs[1],
@@ -2,10 +2,10 @@
 import math

 import numpy as np
-import tqdm
 import torch
 import torch.nn as nn

+import tqdm
 from diffusers import DiffusionPipeline
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.modeling_utils import ModelMixin

@@ -740,29 +740,30 @@ class DiagonalGaussianDistribution(object):

     def kl(self, other=None):
         if self.deterministic:
-            return torch.Tensor([0.])
+            return torch.Tensor([0.0])
         else:
             if other is None:
-                return 0.5 * torch.sum(torch.pow(self.mean, 2)
-                                       + self.var - 1.0 - self.logvar,
-                                       dim=[1, 2, 3])
+                return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
             else:
                 return 0.5 * torch.sum(
                     torch.pow(self.mean - other.mean, 2) / other.var
-                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
-                    dim=[1, 2, 3])
+                    + self.var / other.var
+                    - 1.0
+                    - self.logvar
+                    + other.logvar,
+                    dim=[1, 2, 3],
+                )

-    def nll(self, sample, dims=[1,2,3]):
+    def nll(self, sample, dims=[1, 2, 3]):
         if self.deterministic:
-            return torch.Tensor([0.])
+            return torch.Tensor([0.0])
         logtwopi = np.log(2.0 * np.pi)
-        return 0.5 * torch.sum(
-            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
-            dim=dims)
+        return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)

     def mode(self):
         return self.mean


 class AutoencoderKL(ModelMixin, ConfigMixin):
     def __init__(
         self,

@@ -834,7 +835,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
             give_pre_end=give_pre_end,
         )

-        self.quant_conv = torch.nn.Conv2d(2*z_channels, 2*embed_dim, 1)
+        self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1)
         self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)

     def encode(self, x):

@@ -855,4 +856,4 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
         else:
             z = posterior.mode()
         dec = self.decode(z)
-        return dec, posterior
+        return dec, posterior
@@ -17,12 +17,14 @@
-import tqdm
 import torch

+import tqdm
+
 from ..pipeline_utils import DiffusionPipeline


 class DDIM(DiffusionPipeline):
     def __init__(self, unet, noise_scheduler):
         super().__init__()
+        noise_scheduler = noise_scheduler.set_format("pt")
         self.register_modules(unet=unet, noise_scheduler=noise_scheduler)

     def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50):

@@ -36,11 +38,11 @@ class DDIM(DiffusionPipeline):
         self.unet.to(torch_device)

         # Sample gaussian noise to begin loop
-        image = self.noise_scheduler.sample_noise(
+        image = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
-            device=torch_device,
             generator=generator,
         )
+        image = image.to(torch_device)

         # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
         # Ideally, read DDIM paper for in-detail understanding

@@ -63,7 +65,7 @@ class DDIM(DiffusionPipeline):
             # 3. optionally sample variance
             variance = 0
             if eta > 0:
-                noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
+                noise = torch.randn(image.shape, generator=generator).to(image.device)
                 variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise

             # 4. set current image to prev_image: x_t -> x_t-1
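Replacing the scheduler-owned sample_noise with plain torch.randn drops the old "always sample on CPU to be deterministic" helper, so reproducibility now rests on the generator argument. A small sketch of the pattern (shapes and device names are illustrative):

import torch

# a seeded CPU generator makes the draw reproducible; moving the result
# afterwards keeps the values identical regardless of the target device
generator = torch.manual_seed(0)

image = torch.randn((1, 3, 64, 64), generator=generator)
image = image.to("cuda" if torch.cuda.is_available() else "cpu")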
@@ -17,12 +17,14 @@
-import tqdm
 import torch

+import tqdm
+
 from ..pipeline_utils import DiffusionPipeline


 class DDPM(DiffusionPipeline):
     def __init__(self, unet, noise_scheduler):
         super().__init__()
+        noise_scheduler = noise_scheduler.set_format("pt")
         self.register_modules(unet=unet, noise_scheduler=noise_scheduler)

     def __call__(self, batch_size=1, generator=None, torch_device=None):

@@ -32,11 +34,11 @@ class DDPM(DiffusionPipeline):
         self.unet.to(torch_device)

         # Sample gaussian noise to begin loop
-        image = self.noise_scheduler.sample_noise(
+        image = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
-            device=torch_device,
             generator=generator,
         )
+        image = image.to(torch_device)

         num_prediction_steps = len(self.noise_scheduler)
         for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):

@@ -50,7 +52,7 @@ class DDPM(DiffusionPipeline):
             # 3. optionally sample variance
             variance = 0
             if t > 0:
-                noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
+                noise = torch.randn(image.shape, generator=generator).to(image.device)
                 variance = self.noise_scheduler.get_variance(t).sqrt() * noise

             # 4. set current image to prev_image: x_t -> x_t-1
@@ -24,10 +24,6 @@ import torch.utils.checkpoint
 from torch import nn

 import tqdm
-from ..pipeline_utils import DiffusionPipeline
-from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
-from ..schedulers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler
-
 from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig, GPT2Tokenizer
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling

@@ -40,6 +36,10 @@ from transformers.utils import (
     replace_return_docstrings,
 )

+from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
+from ..pipeline_utils import DiffusionPipeline
+from ..schedulers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler
+

 #####################
 # START OF THE CLIP MODEL COPY-PASTE (with a modified attention module)
@@ -2,13 +2,14 @@
 import math

 import numpy as np
-import tqdm
 import torch
 import torch.nn as nn

-from ..pipeline_utils import DiffusionPipeline
+import tqdm
+
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
+from ..pipeline_utils import DiffusionPipeline


 def get_timestep_embedding(timesteps, embedding_dim):

@@ -740,29 +741,30 @@ class DiagonalGaussianDistribution(object):

     def kl(self, other=None):
         if self.deterministic:
-            return torch.Tensor([0.])
+            return torch.Tensor([0.0])
         else:
             if other is None:
-                return 0.5 * torch.sum(torch.pow(self.mean, 2)
-                                       + self.var - 1.0 - self.logvar,
-                                       dim=[1, 2, 3])
+                return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
             else:
                 return 0.5 * torch.sum(
                     torch.pow(self.mean - other.mean, 2) / other.var
-                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
-                    dim=[1, 2, 3])
+                    + self.var / other.var
+                    - 1.0
+                    - self.logvar
+                    + other.logvar,
+                    dim=[1, 2, 3],
+                )

-    def nll(self, sample, dims=[1,2,3]):
+    def nll(self, sample, dims=[1, 2, 3]):
         if self.deterministic:
-            return torch.Tensor([0.])
+            return torch.Tensor([0.0])
         logtwopi = np.log(2.0 * np.pi)
-        return 0.5 * torch.sum(
-            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
-            dim=dims)
+        return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)

     def mode(self):
         return self.mean


 class AutoencoderKL(ModelMixin, ConfigMixin):
     def __init__(
         self,

@@ -834,7 +836,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
             give_pre_end=give_pre_end,
         )

-        self.quant_conv = torch.nn.Conv2d(2*z_channels, 2*embed_dim, 1)
+        self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1)
         self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)

     def encode(self, x):

@@ -861,10 +863,20 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
 class LatentDiffusion(DiffusionPipeline):
     def __init__(self, vqvae, bert, tokenizer, unet, noise_scheduler):
         super().__init__()
+        noise_scheduler = noise_scheduler.set_format("pt")
         self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, noise_scheduler=noise_scheduler)

     @torch.no_grad()
-    def __call__(self, prompt, batch_size=1, generator=None, torch_device=None, eta=0.0, guidance_scale=1.0, num_inference_steps=50):
+    def __call__(
+        self,
+        prompt,
+        batch_size=1,
+        generator=None,
+        torch_device=None,
+        eta=0.0,
+        guidance_scale=1.0,
+        num_inference_steps=50,
+    ):
         # eta corresponds to η in paper and should be between [0, 1]

         if torch_device is None:

@@ -873,25 +885,26 @@ class LatentDiffusion(DiffusionPipeline):
         self.unet.to(torch_device)
         self.vqvae.to(torch_device)
         self.bert.to(torch_device)

         # get unconditional embeddings for classifier free guidance
         if guidance_scale != 1.0:
-            uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors='pt').to(torch_device)
+            uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(
+                torch_device
+            )
             uncond_embeddings = self.bert(uncond_input.input_ids)[0]

         # get text embedding
-        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors='pt').to(torch_device)
+        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
         text_embedding = self.bert(text_input.input_ids)[0]

         num_trained_timesteps = self.noise_scheduler.timesteps
         inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)

-        image = self.noise_scheduler.sample_noise(
+        image = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
-            device=torch_device,
             generator=generator,
         )
+        image = image.to(torch_device)

         # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
         # Ideally, read DDIM paper for in-detail understanding

@@ -910,7 +923,7 @@ class LatentDiffusion(DiffusionPipeline):
             timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
         else:
             # for classifier free guidance, we need to do two forward passes
-            # here we concanate embedding and unconditioned embedding in a single batch
+            # here we concatenate embedding and unconditioned embedding in a single batch
             # to avoid doing two forward passes
             image_in = torch.cat([image] * 2)
             context = torch.cat([uncond_embeddings, text_embedding])

@@ -918,12 +931,12 @@ class LatentDiffusion(DiffusionPipeline):

             # 1. predict noise residual
             pred_noise_t = self.unet(image_in, timesteps, context=context)

             # perform guidance
             if guidance_scale != 1.0:
                 pred_noise_t_uncond, pred_noise_t = pred_noise_t.chunk(2)
                 pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond)

             # 2. get actual t and t-1
             train_step = inference_step_times[t]
             prev_train_step = inference_step_times[t - 1] if t > 0 else -1

@@ -953,7 +966,11 @@ class LatentDiffusion(DiffusionPipeline):
             # 5. Sample x_t-1 image optionally if η > 0.0 by adding noise to pred_prev_image
             # Note: eta = 1.0 essentially corresponds to DDPM
             if eta > 0.0:
-                noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
+                noise = torch.randn(
+                    (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
+                    generator=generator,
+                )
+                noise = noise.to(torch_device)
                 prev_image = pred_prev_image + std_dev_t * noise
             else:
                 prev_image = pred_prev_image

@@ -962,8 +979,8 @@ class LatentDiffusion(DiffusionPipeline):
             image = prev_image

         # scale and decode image with vae
         image = 1 / 0.18215 * image
         image = self.vqvae.decode(image)
-        image = torch.clamp((image+1.0)/2.0, min=0.0, max=1.0)
+        image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)

         return image
@@ -17,6 +17,7 @@
 # limitations under the License.

 from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
-from .gaussian_ddpm import GaussianDDPMScheduler
 from .ddim import DDIMScheduler
+from .gaussian_ddpm import GaussianDDPMScheduler
 from .glide_ddim import GlideDDIMScheduler
+from .schedulers_utils import SchedulerMixin
@@ -13,20 +13,13 @@
 # limitations under the License.
 import math

-import torch
-from torch import nn
+import numpy as np

 from ..configuration_utils import ConfigMixin
-from .schedulers_utils import betas_for_alpha_bar, linear_beta_schedule
+from .schedulers_utils import SchedulerMixin, betas_for_alpha_bar, linear_beta_schedule


-SAMPLING_CONFIG_NAME = "scheduler_config.json"
-
-
-class DDIMScheduler(nn.Module, ConfigMixin):
-
-    config_name = SAMPLING_CONFIG_NAME
-
+class DDIMScheduler(SchedulerMixin, ConfigMixin):
     def __init__(
         self,
         timesteps=1000,

@@ -34,6 +27,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         beta_end=0.02,
         beta_schedule="linear",
         clip_predicted_image=True,
+        tensor_format="np",
     ):
         super().__init__()
         self.register(

@@ -46,35 +40,34 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         self.clip_image = clip_predicted_image

         if beta_schedule == "linear":
-            betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
+            self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
         elif beta_schedule == "squaredcos_cap_v2":
             # GLIDE cosine schedule
-            betas = betas_for_alpha_bar(
+            self.betas = betas_for_alpha_bar(
                 timesteps,
                 lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
             )
         else:
             raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

-        alphas = 1.0 - betas
-        alphas_cumprod = torch.cumprod(alphas, axis=0)
+        self.alphas = 1.0 - self.betas
+        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+        self.one = np.array(1.0)

-        self.register_buffer("betas", betas.to(torch.float32))
-        self.register_buffer("alphas", alphas.to(torch.float32))
-        self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))
+        self.set_format(tensor_format=tensor_format)

         # alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
         # TODO(PVP) - check how much of these is actually necessary!
         # LDM only uses "fixed_small"; glide seems to use a weird mix of the two, ...
         # https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246
         # variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
         # if variance_type == "fixed_small":
         #     log_variance = torch.log(variance.clamp(min=1e-20))
         # elif variance_type == "fixed_large":
         #     log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))
         #
         # self.register_buffer("log_variance", log_variance.to(torch.float32))

     def get_alpha(self, time_step):
         return self.alphas[time_step]

@@ -84,7 +77,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):

     def get_alpha_prod(self, time_step):
         if time_step < 0:
-            return torch.tensor(1.0)
+            return self.one
         return self.alphas_cumprod[time_step]

     def get_orig_t(self, t, num_inference_steps):

@@ -128,28 +121,24 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):

         # 3. compute predicted original image from predicted noise also called
         # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        pred_original_image = (image - beta_prod_t.sqrt() * residual) / alpha_prod_t.sqrt()
+        pred_original_image = (image - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)

         # 4. Clip "predicted x_0"
         if self.clip_image:
-            pred_original_image = torch.clamp(pred_original_image, -1, 1)
+            pred_original_image = self.clip(pred_original_image, -1, 1)

         # 5. compute variance: "sigma_t(η)" -> see formula (16)
         # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
         variance = self.get_variance(t, num_inference_steps)
-        std_dev_t = eta * variance.sqrt()
+        std_dev_t = eta * variance ** (0.5)

         # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2).sqrt() * residual
+        pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * residual

         # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        pred_prev_image = alpha_prod_t_prev.sqrt() * pred_original_image + pred_image_direction
+        pred_prev_image = alpha_prod_t_prev ** (0.5) * pred_original_image + pred_image_direction

         return pred_prev_image

-    def sample_noise(self, shape, device, generator=None):
-        # always sample on CPU to be deterministic
-        return torch.randn(shape, generator=generator).to(device)
-
     def __len__(self):
         return self.timesteps
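As a sanity check on the step rewrite above, the eta = 0 case of formula (12) is easy to replay in plain NumPy. The alpha products below are made up for illustration; only the arithmetic mirrors the hunk:

import numpy as np

alpha_prod_t, alpha_prod_t_prev = 0.5, 0.7  # toy values, not from a real schedule
beta_prod_t = 1.0 - alpha_prod_t

image = np.random.randn(1, 3, 8, 8)     # x_t
residual = np.random.randn(1, 3, 8, 8)  # predicted noise eps(x_t, t)

# formula (12): predicted x_0, then the "direction pointing to x_t" (std_dev_t == 0 when eta == 0)
pred_original_image = (image - beta_prod_t ** 0.5 * residual) / alpha_prod_t ** 0.5
pred_image_direction = (1 - alpha_prod_t_prev) ** 0.5 * residual
pred_prev_image = alpha_prod_t_prev ** 0.5 * pred_original_image + pred_image_direction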
@@ -13,19 +13,13 @@
 # limitations under the License.
 import math

-import torch
-from torch import nn
+import numpy as np

 from ..configuration_utils import ConfigMixin
-from .schedulers_utils import betas_for_alpha_bar, linear_beta_schedule
+from .schedulers_utils import SchedulerMixin, betas_for_alpha_bar, linear_beta_schedule


-SAMPLING_CONFIG_NAME = "scheduler_config.json"
-
-
-class GaussianDDPMScheduler(nn.Module, ConfigMixin):
-    config_name = SAMPLING_CONFIG_NAME
-
+class GaussianDDPMScheduler(SchedulerMixin, ConfigMixin):
     def __init__(
         self,
         timesteps=1000,

@@ -34,6 +28,7 @@ class GaussianDDPMScheduler(SchedulerMixin, ConfigMixin):
         beta_schedule="linear",
         variance_type="fixed_small",
         clip_predicted_image=True,
+        tensor_format="np",
     ):
         super().__init__()
         self.register(

@@ -49,35 +44,38 @@ class GaussianDDPMScheduler(SchedulerMixin, ConfigMixin):
         self.variance_type = variance_type

         if beta_schedule == "linear":
-            betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
+            self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
         elif beta_schedule == "squaredcos_cap_v2":
             # GLIDE cosine schedule
-            betas = betas_for_alpha_bar(
+            self.betas = betas_for_alpha_bar(
                 timesteps,
                 lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
             )
         else:
             raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

-        alphas = 1.0 - betas
-        alphas_cumprod = torch.cumprod(alphas, axis=0)
+        self.alphas = 1.0 - self.betas
+        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+        self.one = np.array(1.0)

-        self.register_buffer("betas", betas.to(torch.float32))
-        self.register_buffer("alphas", alphas.to(torch.float32))
-        self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))
+        self.set_format(tensor_format=tensor_format)
+
+        # self.register_buffer("betas", betas.to(torch.float32))
+        # self.register_buffer("alphas", alphas.to(torch.float32))
+        # self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))

         # alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
         # TODO(PVP) - check how much of these is actually necessary!
         # LDM only uses "fixed_small"; glide seems to use a weird mix of the two, ...
         # https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246
         # variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
         # if variance_type == "fixed_small":
         #     log_variance = torch.log(variance.clamp(min=1e-20))
         # elif variance_type == "fixed_large":
         #     log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))
         #
         # self.register_buffer("log_variance", log_variance.to(torch.float32))

     def get_alpha(self, time_step):
         return self.alphas[time_step]

@@ -87,7 +85,7 @@ class GaussianDDPMScheduler(SchedulerMixin, ConfigMixin):

     def get_alpha_prod(self, time_step):
         if time_step < 0:
-            return torch.tensor(1.0)
+            return self.one
         return self.alphas_cumprod[time_step]

     def get_variance(self, t):

@@ -97,11 +95,11 @@ class GaussianDDPMScheduler(SchedulerMixin, ConfigMixin):
         # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
         # and sample from it to get previous image
         # x_{t-1} ~ N(pred_prev_image, variance) == add variance to pred_image
-        variance = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.get_beta(t))
+        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.get_beta(t)

         # hacks - were probs added for training stability
         if self.variance_type == "fixed_small":
-            variance = variance.clamp(min=1e-20)
+            variance = self.clip(variance, min_value=1e-20)
         elif self.variance_type == "fixed_large":
             variance = self.get_beta(t)

@@ -116,16 +114,16 @@ class GaussianDDPMScheduler(SchedulerMixin, ConfigMixin):

         # 2. compute predicted original image from predicted noise also called
         # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
-        pred_original_image = (image - beta_prod_t.sqrt() * residual) / alpha_prod_t.sqrt()
+        pred_original_image = (image - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)

         # 3. Clip "predicted x_0"
         if self.clip_predicted_image:
-            pred_original_image = torch.clamp(pred_original_image, -1, 1)
+            pred_original_image = self.clip(pred_original_image, -1, 1)

         # 4. Compute coefficients for pred_original_image x_0 and current image x_t
         # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
-        pred_original_image_coeff = (alpha_prod_t_prev.sqrt() * self.get_beta(t)) / beta_prod_t
-        current_image_coeff = self.get_alpha(t).sqrt() * beta_prod_t_prev / beta_prod_t
+        pred_original_image_coeff = (alpha_prod_t_prev ** (0.5) * self.get_beta(t)) / beta_prod_t
+        current_image_coeff = self.get_alpha(t) ** (0.5) * beta_prod_t_prev / beta_prod_t

         # 5. Compute predicted previous image µ_t
         # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf

@@ -133,9 +131,5 @@ class GaussianDDPMScheduler(SchedulerMixin, ConfigMixin):

         return pred_prev_image

-    def sample_noise(self, shape, device, generator=None):
-        # always sample on CPU to be deterministic
-        return torch.randn(shape, generator=generator).to(device)
-
     def __len__(self):
         return self.timesteps
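The posterior-mean coefficients touched in the hunk above come straight from formula (7) of the DDPM paper. A small NumPy replay with toy values, kept internally consistent (nothing below is part of the commit):

import numpy as np

alpha_prod_t_prev = 0.7          # made-up cumulative product at t-1
beta_t = 0.01                    # made-up beta at step t
alpha_t = 1.0 - beta_t
alpha_prod_t = alpha_t * alpha_prod_t_prev
beta_prod_t = 1.0 - alpha_prod_t
beta_prod_t_prev = 1.0 - alpha_prod_t_prev

# formula (7): weights for the clipped "predicted x_0" and the current x_t
pred_original_image_coeff = alpha_prod_t_prev ** 0.5 * beta_t / beta_prod_t
current_image_coeff = alpha_t ** 0.5 * beta_prod_t_prev / beta_prod_t

image = np.random.randn(1, 3, 8, 8)                # x_t
pred_original_image = np.random.randn(1, 3, 8, 8)  # stand-in for "predicted x_0"
pred_prev_image = pred_original_image_coeff * pred_original_image + current_image_coeff * image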
@@ -11,11 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import numpy as np
 import torch


 SCHEDULER_CONFIG_NAME = "scheduler_config.json"


 def linear_beta_schedule(timesteps, beta_start, beta_end):
-    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
+    return np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)


 def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):

@@ -35,4 +39,28 @@ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
         t1 = i / num_diffusion_timesteps
         t2 = (i + 1) / num_diffusion_timesteps
         betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
-    return torch.tensor(betas, dtype=torch.float64)
+    return np.array(betas, dtype=np.float32)
+
+
+class SchedulerMixin:
+
+    config_name = SCHEDULER_CONFIG_NAME
+
+    def set_format(self, tensor_format="pt"):
+        self.tensor_format = tensor_format
+        if tensor_format == "pt":
+            for key, value in vars(self).items():
+                if isinstance(value, np.ndarray):
+                    setattr(self, key, torch.from_numpy(value))
+
+        return self
+
+    def clip(self, tensor, min_value=None, max_value=None):
+        tensor_format = getattr(self, "tensor_format", "pt")
+
+        if tensor_format == "np":
+            return np.clip(tensor, min_value, max_value)
+        elif tensor_format == "pt":
+            return torch.clamp(tensor, min_value, max_value)
+
+        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
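The new SchedulerMixin is the heart of this refactor: scheduler state is created as NumPy arrays, and set_format("pt") converts every ndarray attribute to a torch tensor in place, so one implementation serves both backends. A minimal usage sketch, assuming the import path introduced above; ToyScheduler is hypothetical:

import numpy as np
import torch

from diffusers import SchedulerMixin

class ToyScheduler(SchedulerMixin):
    def __init__(self, tensor_format="np"):
        self.betas = np.linspace(0.0001, 0.02, 10, dtype=np.float32)
        self.set_format(tensor_format=tensor_format)

scheduler = ToyScheduler()
assert isinstance(scheduler.betas, np.ndarray)

scheduler.set_format("pt")  # converts ndarray attributes to torch tensors
assert isinstance(scheduler.betas, torch.Tensor)
print(scheduler.clip(scheduler.betas, 0.0, 0.01))  # dispatches to torch.clamp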
@@ -1,9 +1,10 @@
 import os
 import random
 import unittest
-import torch
 from distutils.util import strtobool

+import torch
+

 global_rng = random.Random()
 torch_device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -270,7 +270,7 @@ def reset_format() -> None:

 def warning_advice(self, *args, **kwargs):
     """
-    This method is identical to `logger.warn()`, but if env var TRANSFORMERS_NO_ADVISORY_WARNINGS=1 is set, this
+    This method is identical to `logger.warning()`, but if env var TRANSFORMERS_NO_ADVISORY_WARNINGS=1 is set, this
     warning will not be printed
     """
     no_advisory_warnings = os.getenv("TRANSFORMERS_NO_ADVISORY_WARNINGS", False)
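For context, warning_advice follows the transformers-style pattern of an advisory warning that a user can silence through an environment variable. A self-contained sketch of that pattern, not the library's exact code:

import logging
import os

def warning_advice(self, *args, **kwargs):
    # matches the docstring above: skip the warning entirely when the
    # advisory-warnings environment variable is set
    if os.getenv("TRANSFORMERS_NO_ADVISORY_WARNINGS", False):
        return
    self.warning(*args, **kwargs)

logging.Logger.warning_advice = warning_advice
logging.getLogger(__name__).warning_advice("model weights were re-initialized")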
@@ -19,11 +19,10 @@ import unittest

 import torch

-from diffusers import GaussianDDPMScheduler, UNetModel, DDIMScheduler
-from diffusers import DDIM, DDPM, LatentDiffusion
+from diffusers import DDIM, DDPM, DDIMScheduler, GaussianDDPMScheduler, LatentDiffusion, UNetModel
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.pipeline_utils import DiffusionPipeline
-from diffusers.testing_utils import floats_tensor, torch_device, slow
+from diffusers.testing_utils import floats_tensor, slow, torch_device


 torch.backends.cuda.matmul.allow_tf32 = False

@@ -149,6 +148,7 @@ class PipelineTesterMixin(unittest.TestCase):

         unet = UNetModel.from_pretrained(model_id)
         noise_scheduler = GaussianDDPMScheduler.from_config(model_id)
+        noise_scheduler = noise_scheduler.set_format("pt")

         ddpm = DDPM(unet=unet, noise_scheduler=noise_scheduler)
         image = ddpm(generator=generator)

@@ -165,7 +165,7 @@ class PipelineTesterMixin(unittest.TestCase):
         model_id = "fusing/ddpm-cifar10"

         unet = UNetModel.from_pretrained(model_id)
-        noise_scheduler = DDIMScheduler()
+        noise_scheduler = DDIMScheduler(tensor_format="pt")

         ddim = DDIM(unet=unet, noise_scheduler=noise_scheduler)
         image = ddim(generator=generator, eta=0.0)
@@ -14,12 +14,13 @@
 # limitations under the License.


-import torch
-import numpy as np
-import unittest
 import tempfile
+import unittest

-from diffusers import GaussianDDPMScheduler, DDIMScheduler
+import numpy as np
+import torch
+
+from diffusers import DDIMScheduler, GaussianDDPMScheduler


 torch.backends.cuda.matmul.allow_tf32 = False

@@ -38,7 +39,7 @@ class SchedulerCommonTest(unittest.TestCase):

         image = np.random.rand(batch_size, num_channels, height, width)

-        return torch.tensor(image)
+        return image

     @property
     def dummy_image_deter(self):

@@ -53,7 +54,7 @@ class SchedulerCommonTest(unittest.TestCase):
         image = image / num_elems
         image = image.transpose(3, 0, 1, 2)

-        return torch.tensor(image)
+        return image

     def get_scheduler_config(self):
         raise NotImplementedError

@@ -82,7 +83,7 @@ class SchedulerCommonTest(unittest.TestCase):
             output = scheduler.step(residual, image, time_step, **kwargs)
             new_output = new_scheduler.step(residual, image, time_step, **kwargs)

-            assert (output - new_output).abs().sum() < 1e-5, "Scheduler outputs are not identical"
+            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

     def check_over_forward(self, time_step=0, **forward_kwargs):
         kwargs = dict(self.forward_default_kwargs)

@@ -103,7 +104,7 @@ class SchedulerCommonTest(unittest.TestCase):
             output = scheduler.step(residual, image, time_step, **kwargs)
             new_output = new_scheduler.step(residual, image, time_step, **kwargs)

-            assert (output - new_output).abs().sum() < 1e-5, "Scheduler outputs are not identical"
+            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

     def test_from_pretrained_save_pretrained(self):
         kwargs = dict(self.forward_default_kwargs)

@@ -122,7 +123,7 @@ class SchedulerCommonTest(unittest.TestCase):
         output = scheduler.step(residual, image, 1, **kwargs)
         new_output = new_scheduler.step(residual, image, 1, **kwargs)

-        assert (output - new_output).abs().sum() < 1e-5, "Scheduler outputs are not identical"
+        assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

     def test_step_shape(self):
         kwargs = dict(self.forward_default_kwargs)

@@ -140,6 +141,26 @@ class SchedulerCommonTest(unittest.TestCase):
         self.assertEqual(output_0.shape, image.shape)
         self.assertEqual(output_0.shape, output_1.shape)

+    def test_pytorch_equal_numpy(self):
+        kwargs = dict(self.forward_default_kwargs)
+
+        for scheduler_class in self.scheduler_classes:
+            image = self.dummy_image
+            residual = 0.1 * image
+
+            image_pt = torch.tensor(image)
+            residual_pt = 0.1 * image_pt
+
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+
+            scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
+
+            output = scheduler.step(residual, image, 1, **kwargs)
+            output_pt = scheduler_pt.step(residual_pt, image_pt, 1, **kwargs)
+
+            assert np.sum(np.abs(output - output_pt.numpy())) < 1e-5, "Scheduler outputs are not identical"
+

 class DDPMSchedulerTest(SchedulerCommonTest):
     scheduler_classes = (GaussianDDPMScheduler,)

@@ -151,7 +172,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
             "beta_end": 0.02,
             "beta_schedule": "linear",
             "variance_type": "fixed_small",
-            "clip_predicted_image": True
+            "clip_predicted_image": True,
         }

         config.update(**kwargs)

@@ -186,9 +207,9 @@ class DDPMSchedulerTest(SchedulerCommonTest):
         scheduler_config = self.get_scheduler_config()
         scheduler = scheduler_class(**scheduler_config)

-        assert (scheduler.get_variance(0) - 0.0).abs().sum() < 1e-5
-        assert (scheduler.get_variance(487) - 0.00979).abs().sum() < 1e-5
-        assert (scheduler.get_variance(999) - 0.02).abs().sum() < 1e-5
+        assert np.sum(np.abs(scheduler.get_variance(0) - 0.0)) < 1e-5
+        assert np.sum(np.abs(scheduler.get_variance(487) - 0.00979)) < 1e-5
+        assert np.sum(np.abs(scheduler.get_variance(999) - 0.02)) < 1e-5

     def test_full_loop_no_noise(self):
         scheduler_class = self.scheduler_classes[0]

@@ -209,12 +230,12 @@ class DDPMSchedulerTest(SchedulerCommonTest):

             if t > 0:
                 noise = self.dummy_image_deter
-                variance = scheduler.get_variance(t).sqrt() * noise
+                variance = scheduler.get_variance(t) ** (0.5) * noise

             image = pred_prev_image + variance

-        result_sum = image.abs().sum()
-        result_mean = image.abs().mean()
+        result_sum = np.sum(np.abs(image))
+        result_mean = np.mean(np.abs(image))

         assert result_sum.item() - 732.9947 < 1e-3
         assert result_mean.item() - 0.9544 < 1e-3

@@ -230,7 +251,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
             "beta_start": 0.0001,
             "beta_end": 0.02,
             "beta_schedule": "linear",
-            "clip_predicted_image": True
+            "clip_predicted_image": True,
         }

         config.update(**kwargs)

@@ -269,12 +290,12 @@ class DDIMSchedulerTest(SchedulerCommonTest):
         scheduler_config = self.get_scheduler_config()
         scheduler = scheduler_class(**scheduler_config)

-        assert (scheduler.get_variance(0, 50) - 0.0).abs().sum() < 1e-5
-        assert (scheduler.get_variance(21, 50) - 0.14771).abs().sum() < 1e-5
-        assert (scheduler.get_variance(49, 50) - 0.32460).abs().sum() < 1e-5
-        assert (scheduler.get_variance(0, 1000) - 0.0).abs().sum() < 1e-5
-        assert (scheduler.get_variance(487, 1000) - 0.00979).abs().sum() < 1e-5
-        assert (scheduler.get_variance(999, 1000) - 0.02).abs().sum() < 1e-5
+        assert np.sum(np.abs(scheduler.get_variance(0, 50) - 0.0)) < 1e-5
+        assert np.sum(np.abs(scheduler.get_variance(21, 50) - 0.14771)) < 1e-5
+        assert np.sum(np.abs(scheduler.get_variance(49, 50) - 0.32460)) < 1e-5
+        assert np.sum(np.abs(scheduler.get_variance(0, 1000) - 0.0)) < 1e-5
+        assert np.sum(np.abs(scheduler.get_variance(487, 1000) - 0.00979)) < 1e-5
+        assert np.sum(np.abs(scheduler.get_variance(999, 1000) - 0.02)) < 1e-5

     def test_full_loop_no_noise(self):
         scheduler_class = self.scheduler_classes[0]

@@ -297,12 +318,12 @@ class DDIMSchedulerTest(SchedulerCommonTest):
             variance = 0
             if eta > 0:
                 noise = self.dummy_image_deter
-                variance = scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
+                variance = scheduler.get_variance(t, num_inference_steps) ** (0.5) * eta * noise

             image = pred_prev_image + variance

-        result_sum = image.abs().sum()
-        result_mean = image.abs().mean()
+        result_sum = np.sum(np.abs(image))
+        result_mean = np.mean(np.abs(image))

         assert result_sum.item() - 270.6214 < 1e-3
         assert result_mean.item() - 0.3524 < 1e-3