diff --git a/Makefile b/Makefile index f5c3573084..dad0611769 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!) export PYTHONPATH = src -check_dirs := models tests src utils +check_dirs := tests src utils modified_only_fixup: $(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs))) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 99d856094d..9a6132a433 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -2,15 +2,16 @@ # There's no way to ignore "F401 '...' imported but unused" warnings in this # module, but to preserve other warnings. So, don't check this module at all. -__version__ = "0.0.1" +__version__ = "0.0.3" from .modeling_utils import ModelMixin from .models.unet import UNetModel from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel from .models.unet_ldm import UNetLDMModel from .pipeline_utils import DiffusionPipeline -from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler -from .schedulers.gaussian_ddpm import GaussianDDPMScheduler -from .schedulers.ddim import DDIMScheduler -from .schedulers.glide_ddim import GlideDDIMScheduler from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion +from .schedulers import SchedulerMixin +from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler +from .schedulers.ddim import DDIMScheduler +from .schedulers.gaussian_ddpm import GaussianDDPMScheduler +from .schedulers.glide_ddim import GlideDDIMScheduler diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 51c27e339e..ce0b7d0ea5 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -213,7 +213,7 @@ class ConfigMixin: passed_keys = set(init_dict.keys()) if len(expected_keys - passed_keys) > 0: - logger.warn( + logger.warning( f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values." ) diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py index dd3c6e9e04..13a4c2efdc 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/modeling_utils.py @@ -490,7 +490,7 @@ class ModelMixin(torch.nn.Module): raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") if len(unexpected_keys) > 0: - logger.warn( + logger.warning( f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or" @@ -502,7 +502,7 @@ class ModelMixin(torch.nn.Module): else: logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") if len(missing_keys) > 0: - logger.warn( + logger.warning( f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
@@ -521,7 +521,7 @@ class ModelMixin(torch.nn.Module): for key, shape1, shape2 in mismatched_keys ] ) - logger.warn( + logger.warning( f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able" diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 437d95da13..18d0e80e1a 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -1,4 +1,4 @@ from .pipeline_ddim import DDIM from .pipeline_ddpm import DDPM -from .pipeline_latent_diffusion import LatentDiffusion from .pipeline_glide import GLIDE +from .pipeline_latent_diffusion import LatentDiffusion diff --git a/src/diffusers/pipelines/configuration_ldmbert.py b/src/diffusers/pipelines/configuration_ldmbert.py index a00e6ca3de..00d3ac907e 100644 --- a/src/diffusers/pipelines/configuration_ldmbert.py +++ b/src/diffusers/pipelines/configuration_ldmbert.py @@ -123,7 +123,7 @@ class LDMBertConfig(PretrainedConfig): scale_embedding=False, use_cache=True, pad_token_id=0, - **kwargs + **kwargs, ): self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings diff --git a/src/diffusers/pipelines/modeling_vae.py b/src/diffusers/pipelines/modeling_vae.py index c7be0018cd..7b299eee5e 100644 --- a/src/diffusers/pipelines/modeling_vae.py +++ b/src/diffusers/pipelines/modeling_vae.py @@ -2,10 +2,10 @@ import math import numpy as np -import tqdm import torch import torch.nn as nn +import tqdm from diffusers import DiffusionPipeline from diffusers.configuration_utils import ConfigMixin from diffusers.modeling_utils import ModelMixin @@ -740,29 +740,30 @@ class DiagonalGaussianDistribution(object): def kl(self, other=None): if self.deterministic: - return torch.Tensor([0.]) + return torch.Tensor([0.0]) else: if other is None: - return 0.5 * torch.sum(torch.pow(self.mean, 2) - + self.var - 1.0 - self.logvar, - dim=[1, 2, 3]) + return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3]) else: return 0.5 * torch.sum( torch.pow(self.mean - other.mean, 2) / other.var - + self.var / other.var - 1.0 - self.logvar + other.logvar, - dim=[1, 2, 3]) + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) - def nll(self, sample, dims=[1,2,3]): + def nll(self, sample, dims=[1, 2, 3]): if self.deterministic: - return torch.Tensor([0.]) + return torch.Tensor([0.0]) logtwopi = np.log(2.0 * np.pi) - return 0.5 * torch.sum( - logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, - dim=dims) + return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims) def mode(self): return self.mean + class AutoencoderKL(ModelMixin, ConfigMixin): def __init__( self, @@ -834,7 +835,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin): give_pre_end=give_pre_end, ) - self.quant_conv = torch.nn.Conv2d(2*z_channels, 2*embed_dim, 1) + self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1) self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1) def encode(self, x): @@ -855,4 +856,4 @@ class AutoencoderKL(ModelMixin, ConfigMixin): else: z = posterior.mode() dec = self.decode(z) - return dec, posterior \ No newline at end of file + return dec, posterior
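For readers tracing the reformatted `kl()` above: for a diagonal Gaussian, the KL divergence from a standard normal reduces to the closed form `0.5 * sum(mean^2 + var - 1 - logvar)`. A minimal, self-contained check against `torch.distributions` (not part of the diff, just a verification sketch):

```python
import torch
from torch.distributions import Normal, kl_divergence

# KL(N(mean, var) || N(0, 1)) = 0.5 * sum(mean^2 + var - 1 - logvar),
# summed over the non-batch dimensions, as in DiagonalGaussianDistribution.kl().
mean = torch.randn(2, 4, 8, 8)
logvar = torch.randn(2, 4, 8, 8).clamp(-5.0, 5.0)
var = logvar.exp()

closed_form = 0.5 * torch.sum(torch.pow(mean, 2) + var - 1.0 - logvar, dim=[1, 2, 3])
reference = kl_divergence(
    Normal(mean, var.sqrt()), Normal(torch.zeros_like(mean), torch.ones_like(mean))
).sum(dim=[1, 2, 3])

assert torch.allclose(closed_form, reference, atol=1e-4)
```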
diff --git a/src/diffusers/pipelines/old/latent_diffusion/configuration_ldmbert.py b/src/diffusers/pipelines/old/latent_diffusion/configuration_ldmbert.py index a00e6ca3de..00d3ac907e 100644 --- a/src/diffusers/pipelines/old/latent_diffusion/configuration_ldmbert.py +++ b/src/diffusers/pipelines/old/latent_diffusion/configuration_ldmbert.py @@ -123,7 +123,7 @@ class LDMBertConfig(PretrainedConfig): scale_embedding=False, use_cache=True, pad_token_id=0, - **kwargs + **kwargs, ): self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings diff --git a/src/diffusers/pipelines/old/latent_diffusion/modeling_latent_diffusion.py b/src/diffusers/pipelines/old/latent_diffusion/modeling_latent_diffusion.py index 9dd778e51d..cf4e90fe7a 100644 --- a/src/diffusers/pipelines/old/latent_diffusion/modeling_latent_diffusion.py +++ b/src/diffusers/pipelines/old/latent_diffusion/modeling_latent_diffusion.py @@ -1,12 +1,14 @@ -import tqdm import torch +import tqdm from diffusers import DiffusionPipeline +from .configuration_ldmbert import LDMBertConfig # NOQA +from .modeling_ldmbert import LDMBertModel # NOQA + # add these relative imports here, so we can load from hub -from .modeling_vae import AutoencoderKL # NOQA -from .configuration_ldmbert import LDMBertConfig # NOQA -from .modeling_ldmbert import LDMBertModel # NOQA +from .modeling_vae import AutoencoderKL # NOQA + class LatentDiffusion(DiffusionPipeline): def __init__(self, vqvae, bert, tokenizer, unet, noise_scheduler): @@ -14,7 +16,16 @@ class LatentDiffusion(DiffusionPipeline): self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, noise_scheduler=noise_scheduler) @torch.no_grad() - def __call__(self, prompt, batch_size=1, generator=None, torch_device=None, eta=0.0, guidance_scale=1.0, num_inference_steps=50): + def __call__( + self, + prompt, + batch_size=1, + generator=None, + torch_device=None, + eta=0.0, + guidance_scale=1.0, + num_inference_steps=50, + ): # eta corresponds to η in paper and should be between [0, 1] if torch_device is None: @@ -23,16 +34,18 @@ class LatentDiffusion(DiffusionPipeline): self.unet.to(torch_device) self.vqvae.to(torch_device) self.bert.to(torch_device) - + # get unconditional embeddings for classifier free guidance if guidance_scale != 1.0: - uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors='pt').to(torch_device) + uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to( torch_device ) uncond_embeddings = self.bert(uncond_input.input_ids)[0] - + # get text embedding - text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors='pt').to(torch_device) + text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device) text_embedding = self.bert(text_input.input_ids)[0] - + num_trained_timesteps = self.noise_scheduler.timesteps inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps) @@ -41,7 +54,7 @@ device=torch_device, generator=generator, ) - + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf # Ideally, read the DDIM paper in detail for a full understanding @@ -60,7 +73,7 @@ timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device) else: # for classifier free guidance, we need to do two forward passes - # here we concatenate embedding and unconditioned embedding in a single batch + # here we concatenate embedding and unconditioned embedding in a single batch # to avoid doing two forward passes image_in = torch.cat([image] * 2) context = torch.cat([uncond_embeddings, text_embedding]) @@ -68,12 +81,12 @@ # 1. predict noise residual pred_noise_t = self.unet(image_in, timesteps, context=context) - + # perform guidance if guidance_scale != 1.0: pred_noise_t_uncond, pred_noise_t = pred_noise_t.chunk(2) pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond) - + # 2. predict previous mean of image x_t-1 pred_prev_image = self.noise_scheduler.step(pred_noise_t, image, t, num_inference_steps, eta) @@ -87,8 +100,8 @@ image = pred_prev_image + variance # scale and decode image with vae - image = 1 / 0.18215 * image + image = 1 / 0.18215 * image image = self.vqvae.decode(image) - image = torch.clamp((image+1.0)/2.0, min=0.0, max=1.0) + image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) return image
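The `__call__` body above packs the unconditional and conditional branches of classifier-free guidance into one UNet batch. A hedged sketch of that trick in isolation (`unet` here is a stand-in callable, not the actual diffusers model):

```python
import torch

def guided_noise(unet, image, timesteps, uncond_emb, text_emb, guidance_scale):
    # One forward pass predicts both branches: duplicate the latents and stack
    # the two embeddings along the batch dimension (timesteps must already
    # match the doubled batch, as in the pipeline above).
    image_in = torch.cat([image] * 2)
    context = torch.cat([uncond_emb, text_emb])
    pred = unet(image_in, timesteps, context=context)

    # Split the branches and push the prediction away from the unconditional one.
    pred_uncond, pred_text = pred.chunk(2)
    return pred_uncond + guidance_scale * (pred_text - pred_uncond)
```

With `guidance_scale == 1.0` the result reduces to `pred_text`, which is why the pipeline skips the doubled batch entirely in that case.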
diff --git a/src/diffusers/pipelines/old/latent_diffusion/modeling_ldmbert.py b/src/diffusers/pipelines/old/latent_diffusion/modeling_ldmbert.py index 4dbe74ada2..e36ff6455d 100644 --- a/src/diffusers/pipelines/old/latent_diffusion/modeling_ldmbert.py +++ b/src/diffusers/pipelines/old/latent_diffusion/modeling_ldmbert.py @@ -43,6 +43,7 @@ from transformers.utils import ( logging, replace_return_docstrings, ) + from .configuration_ldmbert import LDMBertConfig @@ -662,7 +663,7 @@ class LDMBertModel(LDMBertPreTrainedModel): super().__init__(config) self.model = LDMBertEncoder(config) self.to_logits = nn.Linear(config.hidden_size, config.vocab_size) - + def forward( self, input_ids=None, @@ -674,7 +675,7 @@ output_attentions=None, output_hidden_states=None, return_dict=None, - ): + ): outputs = self.model( input_ids, @@ -689,15 +690,15 @@ sequence_output = outputs[0] # logits = self.to_logits(sequence_output) # outputs = (logits,) + outputs[1:] - + # if labels is not None: # loss_fct = CrossEntropyLoss() # loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) # outputs = (loss,) + outputs - + # if not return_dict: # return outputs - + return BaseModelOutput( last_hidden_state=sequence_output, # hidden_states=outputs[1], diff --git a/src/diffusers/pipelines/old/latent_diffusion/modeling_vae.py b/src/diffusers/pipelines/old/latent_diffusion/modeling_vae.py index c7be0018cd..7b299eee5e 100644 --- a/src/diffusers/pipelines/old/latent_diffusion/modeling_vae.py +++ b/src/diffusers/pipelines/old/latent_diffusion/modeling_vae.py @@ -2,10 +2,10 @@ import math import numpy as np -import tqdm import torch import torch.nn as nn +import tqdm from diffusers import DiffusionPipeline from diffusers.configuration_utils import ConfigMixin from diffusers.modeling_utils import ModelMixin @@ -740,29 +740,30 @@ class DiagonalGaussianDistribution(object): def kl(self, other=None): if self.deterministic: - return torch.Tensor([0.]) + return torch.Tensor([0.0]) else: if other is None: - return 0.5 * torch.sum(torch.pow(self.mean, 2) - + self.var - 1.0 - self.logvar, - dim=[1, 2, 3]) + return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3]) else: return 0.5 * torch.sum( torch.pow(self.mean - other.mean, 2) / other.var - + self.var / other.var - 1.0 - self.logvar + other.logvar, - dim=[1, 2, 3]) + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + )
- def nll(self, sample, dims=[1,2,3]): + def nll(self, sample, dims=[1, 2, 3]): if self.deterministic: - return torch.Tensor([0.]) + return torch.Tensor([0.0]) logtwopi = np.log(2.0 * np.pi) - return 0.5 * torch.sum( - logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, - dim=dims) + return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims) def mode(self): return self.mean + class AutoencoderKL(ModelMixin, ConfigMixin): def __init__( self, @@ -834,7 +835,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin): give_pre_end=give_pre_end, ) - self.quant_conv = torch.nn.Conv2d(2*z_channels, 2*embed_dim, 1) + self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1) self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1) def encode(self, x): @@ -855,4 +856,4 @@ class AutoencoderKL(ModelMixin, ConfigMixin): else: z = posterior.mode() dec = self.decode(z) - return dec, posterior \ No newline at end of file + return dec, posterior diff --git a/src/diffusers/pipelines/pipeline_ddim.py b/src/diffusers/pipelines/pipeline_ddim.py index 530945238b..f7bdfc807e 100644 --- a/src/diffusers/pipelines/pipeline_ddim.py +++ b/src/diffusers/pipelines/pipeline_ddim.py @@ -17,12 +17,14 @@ import torch import tqdm + from ..pipeline_utils import DiffusionPipeline class DDIM(DiffusionPipeline): def __init__(self, unet, noise_scheduler): super().__init__() + noise_scheduler = noise_scheduler.set_format("pt") self.register_modules(unet=unet, noise_scheduler=noise_scheduler) def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50): @@ -36,11 +38,11 @@ self.unet.to(torch_device) # Sample gaussian noise to begin loop - image = self.noise_scheduler.sample_noise( + image = torch.randn( (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), - device=torch_device, generator=generator, ) + image = image.to(torch_device) # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf # Ideally, read the DDIM paper in detail for a full understanding @@ -63,7 +65,7 @@ # 3. optionally sample variance variance = 0 if eta > 0: - noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator) + noise = torch.randn(image.shape, generator=generator).to(image.device) variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise # 4. set current image to prev_image: x_t -> x_t-1
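Throughout the pipeline changes, `self.noise_scheduler.sample_noise(shape, device, generator)` is replaced by a plain `torch.randn(...)` followed by `.to(device)`. The removed helper's own comment ("always sample on CPU to be deterministic") explains the pattern: drawing from a CPU generator first keeps runs reproducible regardless of the target device. A small illustration:

```python
import torch

shape = (1, 3, 32, 32)

# Two identically seeded CPU generators produce the same stream; moving the
# result to an accelerator afterwards does not change the sampled values.
gen_a = torch.Generator().manual_seed(0)
gen_b = torch.Generator().manual_seed(0)

noise_a = torch.randn(shape, generator=gen_a)
noise_b = torch.randn(shape, generator=gen_b).to("cpu")  # would be .to("cuda") on GPU runs

assert torch.equal(noise_a, noise_b.cpu())
```

Sampling directly with `device="cuda"` would instead consume the CUDA RNG stream and give different values than a CPU run with the same seed.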
diff --git a/src/diffusers/pipelines/pipeline_ddpm.py b/src/diffusers/pipelines/pipeline_ddpm.py index cc1fad5794..ebcce77337 100644 --- a/src/diffusers/pipelines/pipeline_ddpm.py +++ b/src/diffusers/pipelines/pipeline_ddpm.py @@ -17,12 +17,14 @@ import torch import tqdm + from ..pipeline_utils import DiffusionPipeline class DDPM(DiffusionPipeline): def __init__(self, unet, noise_scheduler): super().__init__() + noise_scheduler = noise_scheduler.set_format("pt") self.register_modules(unet=unet, noise_scheduler=noise_scheduler) def __call__(self, batch_size=1, generator=None, torch_device=None): @@ -32,11 +34,11 @@ self.unet.to(torch_device) # Sample gaussian noise to begin loop - image = self.noise_scheduler.sample_noise( + image = torch.randn( (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), - device=torch_device, generator=generator, ) + image = image.to(torch_device) num_prediction_steps = len(self.noise_scheduler) for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps): @@ -50,7 +52,7 @@ # 3. optionally sample variance variance = 0 if t > 0: - noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator) + noise = torch.randn(image.shape, generator=generator).to(image.device) variance = self.noise_scheduler.get_variance(t).sqrt() * noise # 4. set current image to prev_image: x_t -> x_t-1 diff --git a/src/diffusers/pipelines/pipeline_glide.py b/src/diffusers/pipelines/pipeline_glide.py index d9d19899c6..1b483023e1 100644 --- a/src/diffusers/pipelines/pipeline_glide.py +++ b/src/diffusers/pipelines/pipeline_glide.py @@ -24,10 +24,6 @@ import torch.utils.checkpoint from torch import nn import tqdm -from ..pipeline_utils import DiffusionPipeline -from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel -from ..schedulers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler - from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig, GPT2Tokenizer from transformers.activations import ACT2FN from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling @@ -40,6 +36,10 @@ from transformers.utils import ( replace_return_docstrings, ) +from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel +from ..pipeline_utils import DiffusionPipeline +from ..schedulers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler + ##################### # START OF THE CLIP MODEL COPY-PASTE (with a modified attention module)
diff --git a/src/diffusers/pipelines/pipeline_latent_diffusion.py b/src/diffusers/pipelines/pipeline_latent_diffusion.py index 2d57f88968..20153aea40 100644 --- a/src/diffusers/pipelines/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/pipeline_latent_diffusion.py @@ -2,13 +2,14 @@ import math import numpy as np -import tqdm import torch import torch.nn as nn -from ..pipeline_utils import DiffusionPipeline +import tqdm + from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin +from ..pipeline_utils import DiffusionPipeline def get_timestep_embedding(timesteps, embedding_dim): @@ -740,29 +741,30 @@ class DiagonalGaussianDistribution(object): def kl(self, other=None): if self.deterministic: - return torch.Tensor([0.]) + return torch.Tensor([0.0]) else: if other is None: - return 0.5 * torch.sum(torch.pow(self.mean, 2) - + self.var - 1.0 - self.logvar, - dim=[1, 2, 3]) + return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3]) else: return 0.5 * torch.sum( torch.pow(self.mean - other.mean, 2) / other.var - + self.var / other.var - 1.0 - self.logvar + other.logvar, - dim=[1, 2, 3]) + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) - def nll(self, sample, dims=[1,2,3]): + def nll(self, sample, dims=[1, 2, 3]): if self.deterministic: - return torch.Tensor([0.]) + return torch.Tensor([0.0]) logtwopi = np.log(2.0 * np.pi) - return 0.5 * torch.sum( - logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, - dim=dims) + return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims) def mode(self): return self.mean + class AutoencoderKL(ModelMixin, ConfigMixin): def __init__( self, @@ -834,7 +836,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin): give_pre_end=give_pre_end, ) - self.quant_conv = torch.nn.Conv2d(2*z_channels, 2*embed_dim, 1) + self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1) self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1) def encode(self, x): @@ -861,10 +863,20 @@ class LatentDiffusion(DiffusionPipeline): def __init__(self, vqvae, bert, tokenizer, unet, noise_scheduler): super().__init__() + noise_scheduler = noise_scheduler.set_format("pt") self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, noise_scheduler=noise_scheduler) @torch.no_grad() - def __call__(self, prompt, batch_size=1, generator=None, torch_device=None, eta=0.0, guidance_scale=1.0, num_inference_steps=50): + def __call__( + self, + prompt, + batch_size=1, + generator=None, + torch_device=None, + eta=0.0, + guidance_scale=1.0, + num_inference_steps=50, + ): # eta corresponds to η in paper and should be between [0, 1] if torch_device is None: @@ -873,25 +885,26 @@ self.unet.to(torch_device) self.vqvae.to(torch_device) self.bert.to(torch_device) - + # get unconditional embeddings for classifier free guidance if guidance_scale != 1.0: - uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors='pt').to(torch_device) + uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to( torch_device ) uncond_embeddings = self.bert(uncond_input.input_ids)[0] - + # get text embedding - text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors='pt').to(torch_device) + text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device) text_embedding = self.bert(text_input.input_ids)[0] - + num_trained_timesteps = self.noise_scheduler.timesteps inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps) - image = self.noise_scheduler.sample_noise( + image = torch.randn( (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size), - device=torch_device, generator=generator, ) - + image = image.to(torch_device) # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf # Ideally, read the DDIM paper in detail for a full understanding @@ -910,7 +923,7 @@ timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device) else: # for classifier free guidance, we need to do two forward passes - # here we concatenate embedding and unconditioned embedding in a single batch + # here we concatenate embedding and unconditioned embedding in a single batch # to avoid doing two forward passes image_in = torch.cat([image] * 2) context = torch.cat([uncond_embeddings, text_embedding]) @@ -918,12 +931,12 @@ # 1. predict noise residual pred_noise_t = self.unet(image_in, timesteps, context=context) - + # perform guidance if guidance_scale != 1.0: pred_noise_t_uncond, pred_noise_t = pred_noise_t.chunk(2) pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond) - + # 2. get actual t and t-1 train_step = inference_step_times[t] prev_train_step = inference_step_times[t - 1] if t > 0 else -1 @@ -953,7 +966,11 @@ # 5. Sample x_t-1 image optionally if η > 0.0 by adding noise to pred_prev_image # Note: eta = 1.0 essentially corresponds to DDPM if eta > 0.0: - noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator) + noise = torch.randn( (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), generator=generator, ) + noise = noise.to(torch_device) prev_image = pred_prev_image + std_dev_t * noise else: prev_image = pred_prev_image @@ -962,8 +979,8 @@ image = prev_image # scale and decode image with vae - image = 1 / 0.18215 * image + image = 1 / 0.18215 * image image = self.vqvae.decode(image) - image = torch.clamp((image+1.0)/2.0, min=0.0, max=1.0) + image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) return image
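The sampling loop above mirrors what the refactored DDIM scheduler below does internally. A hedged numpy sketch of formulas (12) and (16) from https://arxiv.org/pdf/2010.02502.pdf; the beta schedule matches the scheduler defaults, but the function itself is illustrative, not the library API:

```python
import numpy as np

betas = np.linspace(0.0001, 0.02, 1000, dtype=np.float32)
alphas_cumprod = np.cumprod(1.0 - betas)

def ddim_step(image, residual, t, t_prev, eta=0.0, noise=None):
    alpha_prod_t = alphas_cumprod[t]
    alpha_prod_t_prev = alphas_cumprod[t_prev] if t_prev >= 0 else 1.0

    # "predicted x_0" of formula (12)
    pred_original = (image - (1 - alpha_prod_t) ** 0.5 * residual) / alpha_prod_t ** 0.5

    # sigma_t of formula (16); eta = 0 is deterministic DDIM, eta = 1 recovers DDPM
    variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
    std_dev_t = eta * variance ** 0.5

    # "direction pointing to x_t" plus optional noise, also formula (12)
    direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * residual
    prev_image = alpha_prod_t_prev ** 0.5 * pred_original + direction
    if eta > 0 and noise is not None:
        prev_image = prev_image + std_dev_t * noise
    return prev_image
```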
diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 460851bca3..6eb952234f 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -17,6 +17,7 @@ # limitations under the License. from .classifier_free_guidance import ClassifierFreeGuidanceScheduler -from .gaussian_ddpm import GaussianDDPMScheduler from .ddim import DDIMScheduler +from .gaussian_ddpm import GaussianDDPMScheduler from .glide_ddim import GlideDDIMScheduler +from .schedulers_utils import SchedulerMixin
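The scheduler diffs that follow are the core of this change: schedulers stop being `nn.Module`s with registered buffers and become framework-agnostic objects that hold numpy arrays by default, with the newly exported `SchedulerMixin` handling conversion. Based on the constructor and `set_format` changes below, usage looks roughly like:

```python
from diffusers import DDIMScheduler

# Default is numpy: scheduler_np.betas / alphas_cumprod are np.ndarrays.
scheduler_np = DDIMScheduler()

# Ask for torch tensors up front ...
scheduler_pt = DDIMScheduler(tensor_format="pt")

# ... or convert later, as the DDPM/DDIM pipelines now do in __init__:
scheduler_pt = scheduler_np.set_format("pt")  # returns self with tensor attributes
```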
diff --git a/src/diffusers/schedulers/ddim.py b/src/diffusers/schedulers/ddim.py index 42b8f0d029..eda4db4501 100644 --- a/src/diffusers/schedulers/ddim.py +++ b/src/diffusers/schedulers/ddim.py @@ -13,20 +13,13 @@ # limitations under the License. import math -import torch -from torch import nn +import numpy as np from ..configuration_utils import ConfigMixin -from .schedulers_utils import betas_for_alpha_bar, linear_beta_schedule +from .schedulers_utils import SchedulerMixin, betas_for_alpha_bar, linear_beta_schedule -SAMPLING_CONFIG_NAME = "scheduler_config.json" - - -class DDIMScheduler(nn.Module, ConfigMixin): - - config_name = SAMPLING_CONFIG_NAME - +class DDIMScheduler(SchedulerMixin, ConfigMixin): def __init__( self, timesteps=1000, @@ -34,6 +27,7 @@ beta_end=0.02, beta_schedule="linear", clip_predicted_image=True, + tensor_format="np", ): super().__init__() self.register( @@ -46,35 +40,34 @@ self.clip_image = clip_predicted_image if beta_schedule == "linear": - betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end) + self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end) elif beta_schedule == "squaredcos_cap_v2": # GLIDE cosine schedule - betas = betas_for_alpha_bar( + self.betas = betas_for_alpha_bar( timesteps, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, ) else: raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") - alphas = 1.0 - betas - alphas_cumprod = torch.cumprod(alphas, axis=0) + self.alphas = 1.0 - self.betas + self.alphas_cumprod = np.cumprod(self.alphas, axis=0) + self.one = np.array(1.0) - self.register_buffer("betas", betas.to(torch.float32)) - self.register_buffer("alphas", alphas.to(torch.float32)) - self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32)) + self.set_format(tensor_format=tensor_format) -# alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0) - # TODO(PVP) - check how much of these is actually necessary! - # LDM only uses "fixed_small"; glide seems to use a weird mix of the two, ... -# https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246 -# variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) -# if variance_type == "fixed_small": -# log_variance = torch.log(variance.clamp(min=1e-20)) -# elif variance_type == "fixed_large": -# log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0)) -# -# -# self.register_buffer("log_variance", log_variance.to(torch.float32)) + # alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0) + # TODO(PVP) - check how much of these is actually necessary! + # LDM only uses "fixed_small"; glide seems to use a weird mix of the two, ...
+ # https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246 + # variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) + # if variance_type == "fixed_small": + # log_variance = torch.log(variance.clamp(min=1e-20)) + # elif variance_type == "fixed_large": + # log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0)) + # + # + # self.register_buffer("log_variance", log_variance.to(torch.float32)) def get_alpha(self, time_step): return self.alphas[time_step] @@ -84,7 +77,7 @@ class DDIMScheduler(nn.Module, ConfigMixin): def get_alpha_prod(self, time_step): if time_step < 0: - return torch.tensor(1.0) + return self.one return self.alphas_cumprod[time_step] def get_orig_t(self, t, num_inference_steps): @@ -128,28 +121,24 @@ class DDIMScheduler(nn.Module, ConfigMixin): # 3. compute predicted original image from predicted noise also called # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_original_image = (image - beta_prod_t.sqrt() * residual) / alpha_prod_t.sqrt() + pred_original_image = (image - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5) # 4. Clip "predicted x_0" if self.clip_image: - pred_original_image = torch.clamp(pred_original_image, -1, 1) + pred_original_image = self.clip(pred_original_image, -1, 1) # 5. compute variance: "sigma_t(η)" -> see formula (16) # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) variance = self.get_variance(t, num_inference_steps) - std_dev_t = eta * variance.sqrt() + std_dev_t = eta * variance ** (0.5) # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2).sqrt() * residual + pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * residual # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_prev_image = alpha_prod_t_prev.sqrt() * pred_original_image + pred_image_direction + pred_prev_image = alpha_prod_t_prev ** (0.5) * pred_original_image + pred_image_direction return pred_prev_image - def sample_noise(self, shape, device, generator=None): - # always sample on CPU to be deterministic - return torch.randn(shape, generator=generator).to(device) - def __len__(self): return self.timesteps diff --git a/src/diffusers/schedulers/gaussian_ddpm.py b/src/diffusers/schedulers/gaussian_ddpm.py index c3e1b1fad1..f64dcaf32a 100644 --- a/src/diffusers/schedulers/gaussian_ddpm.py +++ b/src/diffusers/schedulers/gaussian_ddpm.py @@ -13,19 +13,13 @@ # limitations under the License. 
import math -import torch -from torch import nn +import numpy as np from ..configuration_utils import ConfigMixin -from .schedulers_utils import betas_for_alpha_bar, linear_beta_schedule +from .schedulers_utils import SchedulerMixin, betas_for_alpha_bar, linear_beta_schedule -SAMPLING_CONFIG_NAME = "scheduler_config.json" - - -class GaussianDDPMScheduler(nn.Module, ConfigMixin): - config_name = SAMPLING_CONFIG_NAME - +class GaussianDDPMScheduler(SchedulerMixin, ConfigMixin): def __init__( self, timesteps=1000, @@ -34,6 +28,7 @@ beta_schedule="linear", variance_type="fixed_small", clip_predicted_image=True, + tensor_format="np", ): super().__init__() self.register( @@ -49,35 +44,38 @@ self.variance_type = variance_type if beta_schedule == "linear": - betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end) + self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end) elif beta_schedule == "squaredcos_cap_v2": # GLIDE cosine schedule - betas = betas_for_alpha_bar( + self.betas = betas_for_alpha_bar( timesteps, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, ) else: raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") - alphas = 1.0 - betas - alphas_cumprod = torch.cumprod(alphas, axis=0) + self.alphas = 1.0 - self.betas + self.alphas_cumprod = np.cumprod(self.alphas, axis=0) + self.one = np.array(1.0) - self.register_buffer("betas", betas.to(torch.float32)) - self.register_buffer("alphas", alphas.to(torch.float32)) - self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32)) + self.set_format(tensor_format=tensor_format) -# alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0) - # TODO(PVP) - check how much of these is actually necessary! - # LDM only uses "fixed_small"; glide seems to use a weird mix of the two, ... -# https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246 -# variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) -# if variance_type == "fixed_small": -# log_variance = torch.log(variance.clamp(min=1e-20)) -# elif variance_type == "fixed_large": -# log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0)) -# -# -# self.register_buffer("log_variance", log_variance.to(torch.float32)) + # self.register_buffer("betas", betas.to(torch.float32)) + # self.register_buffer("alphas", alphas.to(torch.float32)) + # self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32)) + + # alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0) + # TODO(PVP) - check how much of these is actually necessary! + # LDM only uses "fixed_small"; glide seems to use a weird mix of the two, ... + # https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246 + # variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) + # if variance_type == "fixed_small": + # log_variance = torch.log(variance.clamp(min=1e-20)) + # elif variance_type == "fixed_large": + # log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0)) + # + # + # self.register_buffer("log_variance", log_variance.to(torch.float32)) def get_alpha(self, time_step): return self.alphas[time_step] @@ -87,7 +85,7 @@ def get_alpha_prod(self, time_step): if time_step < 0: - return torch.tensor(1.0) + return self.one return self.alphas_cumprod[time_step] def get_variance(self, t): @@ -97,11 +95,11 @@ # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) # and sample from it to get previous image # x_{t-1} ~ N(pred_prev_image, variance) == add variance to pred_image - variance = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.get_beta(t)) + variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.get_beta(t) # hacks - were probs added for training stability if self.variance_type == "fixed_small": - variance = variance.clamp(min=1e-20) + variance = self.clip(variance, min_value=1e-20) elif self.variance_type == "fixed_large": variance = self.get_beta(t) @@ -116,16 +114,16 @@ # 2. compute predicted original image from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf - pred_original_image = (image - beta_prod_t.sqrt() * residual) / alpha_prod_t.sqrt() + pred_original_image = (image - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5) # 3. Clip "predicted x_0" if self.clip_predicted_image: - pred_original_image = torch.clamp(pred_original_image, -1, 1) + pred_original_image = self.clip(pred_original_image, -1, 1) # 4. Compute coefficients for pred_original_image x_0 and current image x_t # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf - pred_original_image_coeff = (alpha_prod_t_prev.sqrt() * self.get_beta(t)) / beta_prod_t - current_image_coeff = self.get_alpha(t).sqrt() * beta_prod_t_prev / beta_prod_t + pred_original_image_coeff = (alpha_prod_t_prev ** (0.5) * self.get_beta(t)) / beta_prod_t + current_image_coeff = self.get_alpha(t) ** (0.5) * beta_prod_t_prev / beta_prod_t # 5. Compute predicted previous image µ_t # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf @@ -133,9 +131,5 @@ return pred_prev_image - def sample_noise(self, shape, device, generator=None): - # always sample on CPU to be deterministic - return torch.randn(shape, generator=generator).to(device) - def __len__(self): return self.timesteps
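The coefficients computed in `step()` above implement the DDPM posterior mean. A hedged numpy sketch of formula (7) from https://arxiv.org/pdf/2006.11239.pdf, using the same default schedule as `GaussianDDPMScheduler` (illustrative only, not the library API):

```python
import numpy as np

betas = np.linspace(0.0001, 0.02, 1000, dtype=np.float32)
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)

def ddpm_posterior_mean(image, pred_original, t):
    alpha_prod_t = alphas_cumprod[t]
    alpha_prod_t_prev = alphas_cumprod[t - 1] if t > 0 else 1.0
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev

    # coefficients of formula (7): one for "predicted x_0", one for x_t
    pred_original_coeff = (alpha_prod_t_prev ** 0.5 * betas[t]) / beta_prod_t
    current_image_coeff = alphas[t] ** 0.5 * beta_prod_t_prev / beta_prod_t
    return pred_original_coeff * pred_original + current_image_coeff * image
```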
diff --git a/src/diffusers/schedulers/schedulers_utils.py b/src/diffusers/schedulers/schedulers_utils.py index 582adfd07f..e2e73691cd 100644 --- a/src/diffusers/schedulers/schedulers_utils.py +++ b/src/diffusers/schedulers/schedulers_utils.py @@ -11,11 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np import torch +SCHEDULER_CONFIG_NAME = "scheduler_config.json" + + def linear_beta_schedule(timesteps, beta_start, beta_end): - return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64) + return np.linspace(beta_start, beta_end, timesteps, dtype=np.float32) def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): @@ -35,4 +39,28 @@ t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return torch.tensor(betas, dtype=torch.float64) + return np.array(betas, dtype=np.float32) + + +class SchedulerMixin: + + config_name = SCHEDULER_CONFIG_NAME + + def set_format(self, tensor_format="pt"): + self.tensor_format = tensor_format + if tensor_format == "pt": + for key, value in vars(self).items(): + if isinstance(value, np.ndarray): + setattr(self, key, torch.from_numpy(value)) + + return self + + def clip(self, tensor, min_value=None, max_value=None): + tensor_format = getattr(self, "tensor_format", "pt") + + if tensor_format == "np": + return np.clip(tensor, min_value, max_value) + elif tensor_format == "pt": + return torch.clamp(tensor, min_value, max_value) + + raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
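`SchedulerMixin.clip` is what lets the scheduler `step()` methods stay framework-agnostic: it dispatches to `np.clip` or `torch.clamp` depending on the configured `tensor_format`. A short usage sketch based on the code above:

```python
import numpy as np
import torch

from diffusers import GaussianDDPMScheduler

scheduler = GaussianDDPMScheduler()  # tensor_format="np" by default
x_np = scheduler.clip(np.array([-3.0, 0.5, 3.0]), -1, 1)  # np.clip under the hood

scheduler = scheduler.set_format("pt")  # ndarray attributes become torch tensors
x_pt = scheduler.clip(torch.tensor([-3.0, 0.5, 3.0]), -1, 1)  # torch.clamp under the hood
```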
diff --git a/src/diffusers/testing_utils.py b/src/diffusers/testing_utils.py index 867ed4aedf..13f6332a94 100644 --- a/src/diffusers/testing_utils.py +++ b/src/diffusers/testing_utils.py @@ -1,9 +1,10 @@ import os import random import unittest -import torch from distutils.util import strtobool +import torch + global_rng = random.Random() torch_device = "cuda" if torch.cuda.is_available() else "cpu" diff --git a/src/diffusers/utils/logging.py b/src/diffusers/utils/logging.py index 09abb5706a..bbd90894b5 100644 --- a/src/diffusers/utils/logging.py +++ b/src/diffusers/utils/logging.py @@ -270,7 +270,7 @@ def reset_format() -> None: def warning_advice(self, *args, **kwargs): """ - This method is identical to `logger.warn()`, but if env var TRANSFORMERS_NO_ADVISORY_WARNINGS=1 is set, this + This method is identical to `logger.warning()`, but if env var TRANSFORMERS_NO_ADVISORY_WARNINGS=1 is set, this warning will not be printed """ no_advisory_warnings = os.getenv("TRANSFORMERS_NO_ADVISORY_WARNINGS", False) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 40ed0b5da4..6b84a728d1 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -19,11 +19,10 @@ import unittest import torch -from diffusers import GaussianDDPMScheduler, UNetModel, DDIMScheduler -from diffusers import DDIM, DDPM, LatentDiffusion +from diffusers import DDIM, DDPM, DDIMScheduler, GaussianDDPMScheduler, LatentDiffusion, UNetModel from diffusers.configuration_utils import ConfigMixin from diffusers.pipeline_utils import DiffusionPipeline -from diffusers.testing_utils import floats_tensor, torch_device, slow +from diffusers.testing_utils import floats_tensor, slow, torch_device torch.backends.cuda.matmul.allow_tf32 = False @@ -149,6 +148,7 @@ class PipelineTesterMixin(unittest.TestCase): unet = UNetModel.from_pretrained(model_id) noise_scheduler = GaussianDDPMScheduler.from_config(model_id) + noise_scheduler = noise_scheduler.set_format("pt") ddpm = DDPM(unet=unet, noise_scheduler=noise_scheduler) image = ddpm(generator=generator) @@ -165,7 +165,7 @@ class PipelineTesterMixin(unittest.TestCase): model_id = "fusing/ddpm-cifar10"
unet = UNetModel.from_pretrained(model_id) - noise_scheduler = DDIMScheduler() + noise_scheduler = DDIMScheduler(tensor_format="pt") ddim = DDIM(unet=unet, noise_scheduler=noise_scheduler) image = ddim(generator=generator, eta=0.0) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 20943aa525..4b4540a286 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -14,12 +14,13 @@ # limitations under the License. -import torch -import numpy as np -import unittest import tempfile +import unittest -from diffusers import GaussianDDPMScheduler, DDIMScheduler +import numpy as np +import torch + +from diffusers import DDIMScheduler, GaussianDDPMScheduler torch.backends.cuda.matmul.allow_tf32 = False @@ -38,7 +39,7 @@ class SchedulerCommonTest(unittest.TestCase): image = np.random.rand(batch_size, num_channels, height, width) - return torch.tensor(image) + return image @property def dummy_image_deter(self): @@ -53,7 +54,7 @@ class SchedulerCommonTest(unittest.TestCase): image = image / num_elems image = image.transpose(3, 0, 1, 2) - return torch.tensor(image) + return image def get_scheduler_config(self): raise NotImplementedError @@ -82,7 +83,7 @@ class SchedulerCommonTest(unittest.TestCase): output = scheduler.step(residual, image, time_step, **kwargs) new_output = new_scheduler.step(residual, image, time_step, **kwargs) - assert (output - new_output).abs().sum() < 1e-5, "Scheduler outputs are not identical" + assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" def check_over_forward(self, time_step=0, **forward_kwargs): kwargs = dict(self.forward_default_kwargs) @@ -103,7 +104,7 @@ class SchedulerCommonTest(unittest.TestCase): output = scheduler.step(residual, image, time_step, **kwargs) new_output = new_scheduler.step(residual, image, time_step, **kwargs) - assert (output - new_output).abs().sum() < 1e-5, "Scheduler outputs are not identical" + assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" def test_from_pretrained_save_pretrained(self): kwargs = dict(self.forward_default_kwargs) @@ -122,7 +123,7 @@ class SchedulerCommonTest(unittest.TestCase): output = scheduler.step(residual, image, 1, **kwargs) new_output = new_scheduler.step(residual, image, 1, **kwargs) - assert (output - new_output).abs().sum() < 1e-5, "Scheduler outputs are not identical" + assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" def test_step_shape(self): kwargs = dict(self.forward_default_kwargs) @@ -140,6 +141,26 @@ class SchedulerCommonTest(unittest.TestCase): self.assertEqual(output_0.shape, image.shape) self.assertEqual(output_0.shape, output_1.shape) + def test_pytorch_equal_numpy(self): + kwargs = dict(self.forward_default_kwargs) + + for scheduler_class in self.scheduler_classes: + image = self.dummy_image + residual = 0.1 * image + + image_pt = torch.tensor(image) + residual_pt = 0.1 * image_pt + + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config) + + output = scheduler.step(residual, image, 1, **kwargs) + output_pt = scheduler_pt.step(residual_pt, image_pt, 1, **kwargs) + + assert np.sum(np.abs(output - output_pt.numpy())) < 1e-5, "Scheduler outputs are not identical" + class DDPMSchedulerTest(SchedulerCommonTest): scheduler_classes = (GaussianDDPMScheduler,) @@ -151,7 +172,7 @@ class DDPMSchedulerTest(SchedulerCommonTest): "beta_end": 0.02, 
"beta_schedule": "linear", "variance_type": "fixed_small", - "clip_predicted_image": True + "clip_predicted_image": True, } config.update(**kwargs) @@ -186,9 +207,9 @@ class DDPMSchedulerTest(SchedulerCommonTest): scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) - assert (scheduler.get_variance(0) - 0.0).abs().sum() < 1e-5 - assert (scheduler.get_variance(487) - 0.00979).abs().sum() < 1e-5 - assert (scheduler.get_variance(999) - 0.02).abs().sum() < 1e-5 + assert np.sum(np.abs(scheduler.get_variance(0) - 0.0)) < 1e-5 + assert np.sum(np.abs(scheduler.get_variance(487) - 0.00979)) < 1e-5 + assert np.sum(np.abs(scheduler.get_variance(999) - 0.02)) < 1e-5 def test_full_loop_no_noise(self): scheduler_class = self.scheduler_classes[0] @@ -209,12 +230,12 @@ class DDPMSchedulerTest(SchedulerCommonTest): if t > 0: noise = self.dummy_image_deter - variance = scheduler.get_variance(t).sqrt() * noise + variance = scheduler.get_variance(t) ** (0.5) * noise image = pred_prev_image + variance - result_sum = image.abs().sum() - result_mean = image.abs().mean() + result_sum = np.sum(np.abs(image)) + result_mean = np.mean(np.abs(image)) assert result_sum.item() - 732.9947 < 1e-3 assert result_mean.item() - 0.9544 < 1e-3 @@ -230,7 +251,7 @@ class DDIMSchedulerTest(SchedulerCommonTest): "beta_start": 0.0001, "beta_end": 0.02, "beta_schedule": "linear", - "clip_predicted_image": True + "clip_predicted_image": True, } config.update(**kwargs) @@ -269,12 +290,12 @@ class DDIMSchedulerTest(SchedulerCommonTest): scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) - assert (scheduler.get_variance(0, 50) - 0.0).abs().sum() < 1e-5 - assert (scheduler.get_variance(21, 50) - 0.14771).abs().sum() < 1e-5 - assert (scheduler.get_variance(49, 50) - 0.32460).abs().sum() < 1e-5 - assert (scheduler.get_variance(0, 1000) - 0.0).abs().sum() < 1e-5 - assert (scheduler.get_variance(487, 1000) - 0.00979).abs().sum() < 1e-5 - assert (scheduler.get_variance(999, 1000) - 0.02).abs().sum() < 1e-5 + assert np.sum(np.abs(scheduler.get_variance(0, 50) - 0.0)) < 1e-5 + assert np.sum(np.abs(scheduler.get_variance(21, 50) - 0.14771)) < 1e-5 + assert np.sum(np.abs(scheduler.get_variance(49, 50) - 0.32460)) < 1e-5 + assert np.sum(np.abs(scheduler.get_variance(0, 1000) - 0.0)) < 1e-5 + assert np.sum(np.abs(scheduler.get_variance(487, 1000) - 0.00979)) < 1e-5 + assert np.sum(np.abs(scheduler.get_variance(999, 1000) - 0.02)) < 1e-5 def test_full_loop_no_noise(self): scheduler_class = self.scheduler_classes[0] @@ -297,12 +318,12 @@ class DDIMSchedulerTest(SchedulerCommonTest): variance = 0 if eta > 0: noise = self.dummy_image_deter - variance = scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise + variance = scheduler.get_variance(t, num_inference_steps) ** (0.5) * eta * noise image = pred_prev_image + variance - result_sum = image.abs().sum() - result_mean = image.abs().mean() + result_sum = np.sum(np.abs(image)) + result_mean = np.mean(np.abs(image)) assert result_sum.item() - 270.6214 < 1e-3 assert result_mean.item() - 0.3524 < 1e-3