From dc6324d44bc189a0bf63018145617a736e7a38ff Mon Sep 17 00:00:00 2001 From: anton-l Date: Thu, 9 Jun 2022 10:53:53 +0200 Subject: [PATCH] end-to-end glide pipeline with DDIM scheduler for upscaling --- models/vision/glide/convert_weights.py | 10 +- models/vision/glide/modeling_glide.py | 114 +++++++++--- models/vision/glide/run_glide.py | 1 - src/diffusers/__init__.py | 1 + src/diffusers/models/unet_glide.py | 173 ++++++++++++++---- src/diffusers/pipeline_utils.py | 1 + src/diffusers/schedulers/__init__.py | 1 + .../schedulers/{ddim.py => glide_ddim.py} | 24 +-- 8 files changed, 238 insertions(+), 87 deletions(-) rename src/diffusers/schedulers/{ddim.py => glide_ddim.py} (85%) diff --git a/models/vision/glide/convert_weights.py b/models/vision/glide/convert_weights.py index 3609008681..5bcc68ffca 100644 --- a/models/vision/glide/convert_weights.py +++ b/models/vision/glide/convert_weights.py @@ -1,7 +1,7 @@ import torch from torch import nn -from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, GLIDETextToImageUNetModel, GLIDESuperResUNetModel +from diffusers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler, CLIPTextModel, GLIDETextToImageUNetModel, GLIDESuperResUNetModel from modeling_glide import GLIDE from transformers import CLIPTextConfig, GPT2Tokenizer @@ -76,7 +76,7 @@ text_scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule=" ### Convert the Super-Resolution UNet # wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample.pt -state_dict = torch.load("upsample.pt", map_location="cpu") +ups_state_dict = torch.load("upsample.pt", map_location="cpu") superres_model = GLIDESuperResUNetModel( in_channels=6, @@ -93,12 +93,12 @@ superres_model = GLIDESuperResUNetModel( resblock_updown=True, ) -superres_model.load_state_dict(state_dict) +superres_model.load_state_dict(ups_state_dict, strict=False) -upscale_scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2") +upscale_scheduler = GlideDDIMScheduler(timesteps=1000, beta_schedule="linear") glide = GLIDE(text_unet=text2im_model, text_noise_scheduler=text_scheduler, text_encoder=model, tokenizer=tokenizer, - upscale_unet=superres_model, upscale_noise_scheduler=scheduler) + upscale_unet=superres_model, upscale_noise_scheduler=upscale_scheduler) glide.save_pretrained("./glide-base") diff --git a/models/vision/glide/modeling_glide.py b/models/vision/glide/modeling_glide.py index 7341dc0ad9..9ccb662564 100644 --- a/models/vision/glide/modeling_glide.py +++ b/models/vision/glide/modeling_glide.py @@ -18,7 +18,7 @@ import numpy as np import torch import tqdm -from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, DiffusionPipeline, GLIDETextToImageUNetModel, GLIDESuperResUNetModel +from diffusers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler, CLIPTextModel, DiffusionPipeline, GLIDETextToImageUNetModel, GLIDESuperResUNetModel from transformers import GPT2Tokenizer @@ -41,17 +41,20 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape): class GLIDE(DiffusionPipeline): def __init__( self, - unet: GLIDETextToImageUNetModel, - noise_scheduler: ClassifierFreeGuidanceScheduler, + text_unet: GLIDETextToImageUNetModel, + text_noise_scheduler: ClassifierFreeGuidanceScheduler, text_encoder: CLIPTextModel, tokenizer: GPT2Tokenizer, + upscale_unet: GLIDESuperResUNetModel, + upscale_noise_scheduler: GlideDDIMScheduler ): super().__init__() self.register_modules( - unet=unet, noise_scheduler=noise_scheduler, 
text_encoder=text_encoder, tokenizer=tokenizer + text_unet=text_unet, text_noise_scheduler=text_noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer, + upscale_unet=upscale_unet, upscale_noise_scheduler=upscale_noise_scheduler ) - def q_posterior_mean_variance(self, x_start, x_t, t): + def q_posterior_mean_variance(self, scheduler, x_start, x_t, t): """ Compute the mean and variance of the diffusion posterior: @@ -60,12 +63,12 @@ class GLIDE(DiffusionPipeline): """ assert x_start.shape == x_t.shape posterior_mean = ( - _extract_into_tensor(self.noise_scheduler.posterior_mean_coef1, t, x_t.shape) * x_start - + _extract_into_tensor(self.noise_scheduler.posterior_mean_coef2, t, x_t.shape) * x_t + _extract_into_tensor(scheduler.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t ) - posterior_variance = _extract_into_tensor(self.noise_scheduler.posterior_variance, t, x_t.shape) + posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape) posterior_log_variance_clipped = _extract_into_tensor( - self.noise_scheduler.posterior_log_variance_clipped, t, x_t.shape + scheduler.posterior_log_variance_clipped, t, x_t.shape ) assert ( posterior_mean.shape[0] @@ -75,7 +78,7 @@ class GLIDE(DiffusionPipeline): ) return posterior_mean, posterior_variance, posterior_log_variance_clipped - def p_mean_variance(self, model, x, t, transformer_out, clip_denoised=True, model_kwargs=None): + def p_mean_variance(self, model, scheduler, x, t, transformer_out=None, low_res=None, clip_denoised=True): """ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of the initial x, x_0. @@ -93,51 +96,60 @@ class GLIDE(DiffusionPipeline): - 'log_variance': the log of 'variance'. - 'pred_xstart': the prediction for x_0. """ - if model_kwargs is None: - model_kwargs = {} B, C = x.shape[:2] assert t.shape == (B,) - model_output = model(x, t, transformer_out) + if transformer_out is None: + # super-res model + model_output = model(x, t, low_res) + else: + # text2image model + model_output = model(x, t, transformer_out) assert model_output.shape == (B, C * 2, *x.shape[2:]) model_output, model_var_values = torch.split(model_output, C, dim=1) - min_log = _extract_into_tensor(self.noise_scheduler.posterior_log_variance_clipped, t, x.shape) - max_log = _extract_into_tensor(np.log(self.noise_scheduler.betas), t, x.shape) + min_log = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x.shape) + max_log = _extract_into_tensor(np.log(scheduler.betas), t, x.shape) # The model_var_values is [-1, 1] for [min_var, max_var]. 
frac = (model_var_values + 1) / 2 model_log_variance = frac * max_log + (1 - frac) * min_log model_variance = torch.exp(model_log_variance) - pred_xstart = self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) + pred_xstart = self._predict_xstart_from_eps(scheduler, x_t=x, t=t, eps=model_output) if clip_denoised: pred_xstart = pred_xstart.clamp(-1, 1) - model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) + model_mean, _, _ = self.q_posterior_mean_variance(scheduler, x_start=pred_xstart, x_t=x, t=t) assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape return model_mean, model_variance, model_log_variance, pred_xstart - def _predict_xstart_from_eps(self, x_t, t, eps): + def _predict_xstart_from_eps(self, scheduler, x_t, t, eps): assert x_t.shape == eps.shape return ( - _extract_into_tensor(self.noise_scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - _extract_into_tensor(self.noise_scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps ) + def _predict_eps_from_xstart(self, scheduler, x_t, t, pred_xstart): + return ( + _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart + ) / _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + @torch.no_grad() def __call__(self, prompt, generator=None, torch_device=None): torch_device = "cuda" if torch.cuda.is_available() else "cpu" - self.unet.to(torch_device) + self.text_unet.to(torch_device) self.text_encoder.to(torch_device) + self.upscale_unet.to(torch_device) # Create a classifier-free guidance sampling function guidance_scale = 3.0 - def model_fn(x_t, ts, transformer_out, **kwargs): + def text_model_fn(x_t, ts, transformer_out, **kwargs): half = x_t[: len(x_t) // 2] combined = torch.cat([half, half], dim=0) - model_out = self.unet(combined, ts, transformer_out, **kwargs) + model_out = self.text_unet(combined, ts, transformer_out, **kwargs) eps, rest = model_out[:, :3], model_out[:, 3:] cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps) @@ -146,8 +158,8 @@ class GLIDE(DiffusionPipeline): # 1. Sample gaussian noise batch_size = 2 # second image is empty for classifier-free guidance - image = self.noise_scheduler.sample_noise( - (batch_size, self.unet.in_channels, 64, 64), device=torch_device, generator=generator + image = self.text_noise_scheduler.sample_noise( + (batch_size, self.text_unet.in_channels, 64, 64), device=torch_device, generator=generator ) # 2. Encode tokens @@ -157,14 +169,60 @@ class GLIDE(DiffusionPipeline): attention_mask = inputs["attention_mask"].to(torch_device) transformer_out = self.text_encoder(input_ids, attention_mask).last_hidden_state - num_timesteps = len(self.noise_scheduler) + # 3. 
Run the text2image generation step + num_timesteps = len(self.text_noise_scheduler) for i in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps): t = torch.tensor([i] * image.shape[0], device=torch_device) - mean, variance, log_variance, pred_xstart = self.p_mean_variance(model_fn, image, t, transformer_out) - noise = self.noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator) + mean, variance, log_variance, pred_xstart = self.p_mean_variance( + text_model_fn, self.text_noise_scheduler, image, t, transformer_out=transformer_out + ) + noise = self.text_noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator) nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1))) # no noise when t == 0 image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise + # 4. Run the upscaling step + batch_size = 1 + image = image[:1] + low_res = ((image + 1) * 127.5).round() / 127.5 - 1 + eta = 0.0 + + # Tune this parameter to control the sharpness of 256x256 images. + # A value of 1.0 is sharper, but sometimes results in grainy artifacts. + upsample_temp = 0.997 + + image = self.upscale_noise_scheduler.sample_noise( + (batch_size, 3, 256, 256), device=torch_device, generator=generator + ) * upsample_temp + + num_timesteps = len(self.upscale_noise_scheduler) + for t in tqdm.tqdm(reversed(range(len(self.upscale_noise_scheduler))), total=len(self.upscale_noise_scheduler)): + # i) define coefficients for time step t + clipped_image_coeff = 1 / torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t)) + clipped_noise_coeff = torch.sqrt(1 / self.upscale_noise_scheduler.get_alpha_prod(t) - 1) + image_coeff = (1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1)) * torch.sqrt( + self.upscale_noise_scheduler.get_alpha(t)) / (1 - self.upscale_noise_scheduler.get_alpha_prod(t)) + clipped_coeff = torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1)) * self.upscale_noise_scheduler.get_beta( + t) / (1 - self.upscale_noise_scheduler.get_alpha_prod(t)) + + # ii) predict noise residual + time_input = torch.tensor([t] * image.shape[0], device=torch_device) + model_output = self.upscale_unet(image, time_input, low_res) + noise_residual, pred_variance = torch.split(model_output, 3, dim=1) + + # iii) compute predicted image from residual + # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison + pred_mean = clipped_image_coeff * image - clipped_noise_coeff * noise_residual + pred_mean = torch.clamp(pred_mean, -1, 1) + prev_image = clipped_coeff * pred_mean + image_coeff * image + + # iv) sample variance + prev_variance = self.upscale_noise_scheduler.sample_variance(t, prev_image.shape, device=torch_device, + generator=generator) + + # v) sample x_{t-1} ~ N(prev_image, prev_variance) + sampled_prev_image = prev_image + prev_variance + image = sampled_prev_image + image = image[0].permute(1, 2, 0) return image diff --git a/models/vision/glide/run_glide.py b/models/vision/glide/run_glide.py index d63d620b9e..7648b39f36 100644 --- a/models/vision/glide/run_glide.py +++ b/models/vision/glide/run_glide.py @@ -9,7 +9,6 @@ matplotlib.rcParams['interactive'] = True generator = torch.Generator() generator = generator.manual_seed(0) -# 1. 
Load models pipeline = GLIDE.from_pretrained("fusing/glide-base") img = pipeline("a pencil sketch of a corgi", generator) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index ce99498310..c9285df3e9 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -13,3 +13,4 @@ from .models.vqvae import VQModel from .pipeline_utils import DiffusionPipeline from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler from .schedulers.gaussian_ddpm import GaussianDDPMScheduler +from .schedulers.glide_ddim import GlideDDIMScheduler diff --git a/src/diffusers/models/unet_glide.py b/src/diffusers/models/unet_glide.py index b9b6afd072..24ef868bba 100644 --- a/src/diffusers/models/unet_glide.py +++ b/src/diffusers/models/unet_glide.py @@ -419,11 +419,11 @@ class GLIDEUNetModel(ModelMixin, ConfigMixin): def __init__( self, - in_channels, - model_channels, - out_channels, - num_res_blocks, - attention_resolutions, + in_channels=3, + model_channels=192, + out_channels=6, + num_res_blocks=3, + attention_resolutions=(2, 4, 8), dropout=0, channel_mult=(1, 2, 4, 8), conv_resample=True, @@ -438,24 +438,6 @@ class GLIDEUNetModel(ModelMixin, ConfigMixin): transformer_dim=None, ): super().__init__() - self.register( - in_channels=in_channels, - model_channels=model_channels, - out_channels=out_channels, - num_res_blocks=num_res_blocks, - attention_resolutions=attention_resolutions, - dropout=dropout, - channel_mult=channel_mult, - conv_resample=conv_resample, - dims=dims, - use_checkpoint=use_checkpoint, - use_fp16=use_fp16, - num_heads=num_heads, - num_head_channels=num_head_channels, - num_heads_upsample=num_heads_upsample, - use_scale_shift_norm=use_scale_shift_norm, - resblock_updown=resblock_updown, - ) if num_heads_upsample == -1: num_heads_upsample = num_heads @@ -632,7 +614,7 @@ class GLIDEUNetModel(ModelMixin, ConfigMixin): self.middle_block.apply(convert_module_to_f32) self.output_blocks.apply(convert_module_to_f32) - def forward(self, x, timesteps, y=None): + def forward(self, x, timesteps): """ Apply the model to an input batch. @@ -641,17 +623,10 @@ class GLIDEUNetModel(ModelMixin, ConfigMixin): :param y: an [N] Tensor of labels, if class-conditional. :return: an [N x C x ...] Tensor of outputs. """ - assert (y is not None) == ( - self.num_classes is not None - ), "must specify y if and only if the model is class-conditional" hs = [] emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) - if self.num_classes is not None: - assert y.shape == (x.shape[0],) - emb = emb + self.label_emb(y) - h = x.type(self.dtype) for module in self.input_blocks: h = module(h, emb) @@ -671,10 +646,66 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel): Expects an extra kwarg `low_res` to condition on a low-resolution image. 
""" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__( + self, + in_channels=3, + model_channels=192, + out_channels=6, + num_res_blocks=3, + attention_resolutions=(2, 4, 8), + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + transformer_dim=512 + ): + super().__init__( + in_channels=in_channels, + model_channels=model_channels, + out_channels=out_channels, + num_res_blocks=num_res_blocks, + attention_resolutions=attention_resolutions, + dropout=dropout, + channel_mult=channel_mult, + conv_resample=conv_resample, + dims=dims, + use_checkpoint=use_checkpoint, + use_fp16=use_fp16, + num_heads=num_heads, + num_head_channels=num_head_channels, + num_heads_upsample=num_heads_upsample, + use_scale_shift_norm=use_scale_shift_norm, + resblock_updown=resblock_updown, + transformer_dim=transformer_dim + ) + self.register( + in_channels=in_channels, + model_channels=model_channels, + out_channels=out_channels, + num_res_blocks=num_res_blocks, + attention_resolutions=attention_resolutions, + dropout=dropout, + channel_mult=channel_mult, + conv_resample=conv_resample, + dims=dims, + use_checkpoint=use_checkpoint, + use_fp16=use_fp16, + num_heads=num_heads, + num_head_channels=num_head_channels, + num_heads_upsample=num_heads_upsample, + use_scale_shift_norm=use_scale_shift_norm, + resblock_updown=resblock_updown, + transformer_dim=transformer_dim + ) - self.transformer_proj = nn.Linear(kwargs["transformer_dim"], self.model_channels * 4) + self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4) def forward(self, x, timesteps, transformer_out=None): hs = [] @@ -705,11 +736,77 @@ class GLIDESuperResUNetModel(GLIDEUNetModel): Expects an extra kwarg `low_res` to condition on a low-resolution image. 
""" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__( + self, + in_channels=3, + model_channels=192, + out_channels=6, + num_res_blocks=3, + attention_resolutions=(2, 4, 8), + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + ): + super().__init__( + in_channels=in_channels, + model_channels=model_channels, + out_channels=out_channels, + num_res_blocks=num_res_blocks, + attention_resolutions=attention_resolutions, + dropout=dropout, + channel_mult=channel_mult, + conv_resample=conv_resample, + dims=dims, + use_checkpoint=use_checkpoint, + use_fp16=use_fp16, + num_heads=num_heads, + num_head_channels=num_head_channels, + num_heads_upsample=num_heads_upsample, + use_scale_shift_norm=use_scale_shift_norm, + resblock_updown=resblock_updown, + ) + self.register( + in_channels=in_channels, + model_channels=model_channels, + out_channels=out_channels, + num_res_blocks=num_res_blocks, + attention_resolutions=attention_resolutions, + dropout=dropout, + channel_mult=channel_mult, + conv_resample=conv_resample, + dims=dims, + use_checkpoint=use_checkpoint, + use_fp16=use_fp16, + num_heads=num_heads, + num_head_channels=num_head_channels, + num_heads_upsample=num_heads_upsample, + use_scale_shift_norm=use_scale_shift_norm, + resblock_updown=resblock_updown, + ) - def forward(self, x, timesteps, low_res=None, **kwargs): + def forward(self, x, timesteps, low_res=None): _, _, new_height, new_width = x.shape upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear") x = torch.cat([x, upsampled], dim=1) - return super().forward(x, timesteps, **kwargs) \ No newline at end of file + + hs = [] + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + h = x + for module in self.input_blocks: + h = module(h, emb) + hs.append(h) + h = self.middle_block(h, emb) + for module in self.output_blocks: + h = torch.cat([h, hs.pop()], dim=1) + h = module(h, emb) + + return self.out(h) \ No newline at end of file diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index ccc688c37c..db80218060 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -39,6 +39,7 @@ LOADABLE_CLASSES = { "CLIPTextModel": ["save_pretrained", "from_pretrained"], # TODO (Anton): move to transformers "GaussianDDPMScheduler": ["save_config", "from_config"], "ClassifierFreeGuidanceScheduler": ["save_config", "from_config"], + "GlideDDIMScheduler": ["save_config", "from_config"], }, "transformers": { "GPT2Tokenizer": ["save_pretrained", "from_pretrained"], diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 7311088ccc..b35571f75b 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -18,3 +18,4 @@ from .classifier_free_guidance import ClassifierFreeGuidanceScheduler from .gaussian_ddpm import GaussianDDPMScheduler +from .glide_ddim import GlideDDIMScheduler diff --git a/src/diffusers/schedulers/ddim.py b/src/diffusers/schedulers/glide_ddim.py similarity index 85% rename from src/diffusers/schedulers/ddim.py rename to src/diffusers/schedulers/glide_ddim.py index 0bcf59d263..91f62ea356 100644 --- a/src/diffusers/schedulers/ddim.py +++ b/src/diffusers/schedulers/glide_ddim.py @@ -12,7 +12,7 @@ # See the License for the specific language governing 
permissions and
 # limitations under the License.

 import torch
-import math
+import numpy as np
 from torch import nn

 from ..configuration_utils import ConfigMixin
@@ -22,36 +22,31 @@ from .schedulers_utils import linear_beta_schedule, betas_for_alpha_bar

 SAMPLING_CONFIG_NAME = "scheduler_config.json"


-class GaussianDDPMScheduler(nn.Module, ConfigMixin):
+class GlideDDIMScheduler(nn.Module, ConfigMixin):
     config_name = SAMPLING_CONFIG_NAME

     def __init__(
         self,
         timesteps=1000,
-        beta_start=0.0001,
-        beta_end=0.02,
         beta_schedule="linear",
-        variance_type="fixed_small",
+        variance_type="fixed_large",
     ):
         super().__init__()
         self.register(
             timesteps=timesteps,
-            beta_start=beta_start,
-            beta_end=beta_end,
             beta_schedule=beta_schedule,
             variance_type=variance_type,
         )
         self.num_timesteps = int(timesteps)

         if beta_schedule == "linear":
+            # Linear schedule from Ho et al., extended to work for any number of
+            # diffusion steps.
+            scale = 1000 / self.num_timesteps
+            beta_start = scale * 0.0001
+            beta_end = scale * 0.02
             betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
-        elif beta_schedule == "squaredcos_cap_v2":
-            # GLIDE cosine schedule
-            betas = betas_for_alpha_bar(
-                timesteps,
-                lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
-            )
         else:
             raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
@@ -99,4 +93,4 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
         return torch.randn(shape, generator=generator).to(device)

     def __len__(self):
-        return self.num_timesteps
+        return self.num_timesteps
\ No newline at end of file
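
Note on the guidance step: text_model_fn in modeling_glide.py implements classifier-free
guidance by running the UNet on a doubled batch (prompt-conditioned and unconditional
halves) and recombining the two noise predictions. A minimal standalone sketch of the
combination rule; guided_eps and the tensor shapes are illustrative, not part of this patch:

import torch

def guided_eps(eps_cond, eps_uncond, guidance_scale=3.0):
    # eps_hat = eps_uncond + s * (eps_cond - eps_uncond)
    # s = 1 recovers the purely conditional prediction; s > 1 trades sample
    # diversity for prompt fidelity (the pipeline above uses s = 3.0).
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)

# The doubled batch is split exactly as in text_model_fn:
eps = torch.randn(4, 3, 64, 64)                        # predicted noise, doubled batch
eps_cond, eps_uncond = torch.split(eps, len(eps) // 2, dim=0)
eps_hat = guided_eps(eps_cond, eps_uncond)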
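
Note on the super-resolution conditioning: GLIDESuperResUNetModel.forward bilinearly
upsamples the low-resolution conditioning image to the target size and concatenates it
with the noisy input along the channel axis, which is why convert_weights.py constructs
the model with in_channels=6 (3 noisy channels + 3 conditioning channels). In miniature,
standalone and with illustrative shapes:

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 256, 256)        # noisy sample at the target resolution
low_res = torch.randn(1, 3, 64, 64)    # output of the base text2image stage
upsampled = F.interpolate(low_res, (256, 256), mode="bilinear")
unet_input = torch.cat([x, upsampled], dim=1)
assert unet_input.shape[1] == 6        # matches in_channels=6 above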
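
Note on the upscaling loop coefficients: steps i) through iii) of the loop in __call__
evaluate the DDPM posterior mean q(x_{t-1} | x_t, x_0), with x_0 replaced by its clipped
prediction from the noise residual. A self-contained sketch of one step in the same
notation, assuming alphas, betas, and alphas_cumprod are 1-D tensors of the schedule;
posterior_mean_step is an illustrative helper, not the scheduler API:

import torch

def posterior_mean_step(x_t, eps, t, alphas, betas, alphas_cumprod):
    a_bar_t = alphas_cumprod[t]
    a_bar_prev = alphas_cumprod[t - 1] if t > 0 else torch.tensor(1.0)
    # Predicted x_0: x_t / sqrt(a_bar_t) - sqrt(1 / a_bar_t - 1) * eps,
    # matching clipped_image_coeff and clipped_noise_coeff in the loop.
    pred_x0 = x_t / torch.sqrt(a_bar_t) - torch.sqrt(1 / a_bar_t - 1) * eps
    pred_x0 = pred_x0.clamp(-1, 1)
    # Posterior mean coefficients (Ho et al. 2020, Eq. 7), matching
    # clipped_coeff and image_coeff in the loop.
    coef_x0 = torch.sqrt(a_bar_prev) * betas[t] / (1 - a_bar_t)
    coef_xt = (1 - a_bar_prev) * torch.sqrt(alphas[t]) / (1 - a_bar_t)
    return coef_x0 * pred_x0 + coef_xt * x_t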
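
Note on the linear schedule: GlideDDIMScheduler rescales the (0.0001, 0.02) beta
endpoints from Ho et al. by 1000 / num_timesteps, so shorter schedules use
proportionally larger per-step betas and keep a comparable total noise budget. A
quick sketch of the same computation, assuming linear_beta_schedule is a plain
linspace over the endpoints, which matches its use here:

import numpy as np

def scaled_linear_betas(timesteps):
    scale = 1000 / timesteps
    return np.linspace(scale * 0.0001, scale * 0.02, timesteps)

# timesteps=1000 reproduces the original DDPM schedule; timesteps=100 gives
# per-step betas 10x larger over 10x fewer steps.
print(scaled_linear_betas(1000)[:3])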