diff --git a/models/vision/glide/convert_weights.py b/models/vision/glide/convert_weights.py
index a801640634..10369fca60 100644
--- a/models/vision/glide/convert_weights.py
+++ b/models/vision/glide/convert_weights.py
@@ -22,8 +22,7 @@ config = CLIPTextConfig(
     use_padding_embeddings=True,
 )
 model = CLIPTextModel(config).eval()
-tokenizer = GPT2Tokenizer("./glide-base/vocab.json", "./glide-base/merges.txt", pad_token="<|endoftext|>")
-# tokenizer.save_pretrained("./glide-base")
+tokenizer = GPT2Tokenizer("./glide-base/tokenizer/vocab.json", "./glide-base/tokenizer/merges.txt", pad_token="<|endoftext|>")
 
 hf_encoder = model.text_model
 
@@ -52,12 +51,6 @@ for layer_idx in range(config.num_hidden_layers):
     hf_layer.mlp.fc2.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.weight"]
     hf_layer.mlp.fc2.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.bias"]
 
-# inputs = tokenizer(["an oil painting of a corgi", ""], padding="max_length", max_length=128, return_tensors="pt")
-# with torch.no_grad():
-#     outputs = model(**inputs)
-
-# model.save_pretrained("./glide-base")
-
 ### Convert the UNet
 
 unet_model = UNetGLIDEModel(
@@ -73,6 +66,7 @@ unet_model = UNetGLIDEModel(
     num_heads_upsample=1,
     use_scale_shift_norm=True,
     resblock_updown=True,
+    transformer_dim=512,
 )
 
 unet_model.load_state_dict(state_dict, strict=False)
diff --git a/models/vision/glide/modeling_glide.py b/models/vision/glide/modeling_glide.py
index ecd2963785..cc2880d85d 100644
--- a/models/vision/glide/modeling_glide.py
+++ b/models/vision/glide/modeling_glide.py
@@ -130,21 +130,37 @@ class GLIDE(DiffusionPipeline):
         self.unet.to(torch_device)
         self.text_encoder.to(torch_device)
 
+        # Create a classifier-free guidance sampling function
+        guidance_scale = 3.0
+
+        def model_fn(x_t, ts, transformer_out, **kwargs):
+            half = x_t[: len(x_t) // 2]
+            combined = torch.cat([half, half], dim=0)
+            model_out = self.unet(combined, ts, transformer_out, **kwargs)
+            eps, rest = model_out[:, :3], model_out[:, 3:]
+            cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
+            half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
+            eps = torch.cat([half_eps, half_eps], dim=0)
+            return torch.cat([eps, rest], dim=1)
+
         # 1. Sample gaussian noise
+        batch_size = 2  # second image is empty for classifier-free guidance
         image = self.noise_scheduler.sample_noise(
-            (1, self.unet.in_channels, 64, 64), device=torch_device, generator=generator
+            (batch_size, self.unet.in_channels, 64, 64), device=torch_device, generator=generator
         )
 
         # 2. Encode tokens
         # an empty input is needed to guide the model away from (
         inputs = self.tokenizer([prompt, ""], padding="max_length", max_length=128, return_tensors="pt")
-        transformer_out = self.text_encoder(**inputs).last_hidden_state
+        input_ids = inputs["input_ids"].to(torch_device)
+        attention_mask = inputs["attention_mask"].to(torch_device)
+        transformer_out = self.text_encoder(input_ids, attention_mask).last_hidden_state
 
         num_timesteps = len(self.noise_scheduler)
         for i in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
             t = torch.tensor([i] * image.shape[0], device=torch_device)
-            mean, variance, log_variance, pred_xstart = self.p_mean_variance(self.unet, transformer_out, image, t)
-            noise = self.noise_scheduler.sample_noise(image.shape)
+            mean, variance, log_variance, pred_xstart = self.p_mean_variance(model_fn, image, t, transformer_out)
+            noise = self.noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator)
             nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))  # no noise when t == 0
             image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise
 
diff --git a/models/vision/glide/run_glide.py b/models/vision/glide/run_glide.py
index 2c3eafd29c..1bea36fc85 100644
--- a/models/vision/glide/run_glide.py
+++ b/models/vision/glide/run_glide.py
@@ -9,6 +9,6 @@ generator = generator.manual_seed(0)
 
 # 1. Load models
 pipeline = GLIDE.from_pretrained("fusing/glide-base")
-img = pipeline(generator)
+img = pipeline("an oil painting of a corgi", generator)
 
 print(img)
diff --git a/src/diffusers/models/unet_glide.py b/src/diffusers/models/unet_glide.py
index 4764dbf7e5..97f9b56ea4 100644
--- a/src/diffusers/models/unet_glide.py
+++ b/src/diffusers/models/unet_glide.py
@@ -435,7 +435,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
         num_heads_upsample=-1,
         use_scale_shift_norm=False,
         resblock_updown=False,
-        encoder_channels=None,
+        transformer_dim=512,
     ):
         super().__init__()
         self.register(
@@ -455,7 +455,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
             num_heads_upsample=num_heads_upsample,
             use_scale_shift_norm=use_scale_shift_norm,
             resblock_updown=resblock_updown,
-            encoder_channels=encoder_channels,
+            transformer_dim=transformer_dim,
         )
 
         if num_heads_upsample == -1:
@@ -482,6 +482,8 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
             linear(time_embed_dim, time_embed_dim),
         )
 
+        self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4)
+
         ch = input_ch = int(channel_mult[0] * model_channels)
         self.input_blocks = nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))])
         self._feature_size = ch
@@ -508,7 +510,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
                             use_checkpoint=use_checkpoint,
                             num_heads=num_heads,
                             num_head_channels=num_head_channels,
-                            encoder_channels=encoder_channels,
+                            encoder_channels=transformer_dim,
                         )
                     )
                 self.input_blocks.append(TimestepEmbedSequential(*layers))
@@ -551,7 +553,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
                 use_checkpoint=use_checkpoint,
                 num_heads=num_heads,
                 num_head_channels=num_head_channels,
-                encoder_channels=encoder_channels,
+                encoder_channels=transformer_dim,
             ),
             ResBlock(
                 ch,
@@ -587,7 +589,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
                             use_checkpoint=use_checkpoint,
                             num_heads=num_heads_upsample,
                             num_head_channels=num_head_channels,
-                            encoder_channels=encoder_channels,
+                            encoder_channels=transformer_dim,
                         )
                     )
                 if level and i == num_res_blocks:
@@ -642,10 +644,6 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
         :param y: an [N] Tensor of labels, if class-conditional.
         :return: an [N x C x ...] Tensor of outputs.
         """
-        assert (y is not None) == (
-            self.num_classes is not None
-        ), "must specify y if and only if the model is class-conditional"
-
         hs = []
         emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
 
@@ -655,13 +653,13 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
 
         emb = emb + transformer_proj.to(emb)
 
-        h = x.type(self.dtype)
+        h = x
         for module in self.input_blocks:
             h = module(h, emb, transformer_out)
             hs.append(h)
         h = self.middle_block(h, emb, transformer_out)
         for module in self.output_blocks:
-            h = torch.cat([h, hs.pop()], dim=1)
+            other = hs.pop()
+            h = torch.cat([h, other], dim=1)
             h = module(h, emb, transformer_out)
-        h = h.type(x.dtype)
         return self.out(h)
diff --git a/src/diffusers/schedulers/classifier_free_guidance.py b/src/diffusers/schedulers/classifier_free_guidance.py
index 12ec76a221..2cd8152144 100644
--- a/src/diffusers/schedulers/classifier_free_guidance.py
+++ b/src/diffusers/schedulers/classifier_free_guidance.py
@@ -65,14 +65,14 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
 
         if beta_schedule == "squaredcos_cap_v2":
             # GLIDE cosine schedule
-            betas = betas_for_alpha_bar(
+            self.betas = betas_for_alpha_bar(
                 timesteps,
                 lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
             )
         else:
             raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
 
-        alphas = 1.0 - betas
+        alphas = 1.0 - self.betas
         self.alphas_cumprod = np.cumprod(alphas, axis=0)
         self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
 
@@ -81,12 +81,12 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
         self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
 
         # calculations for posterior q(x_{t-1} | x_t, x_0)
-        self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+        self.posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
         # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
         self.posterior_log_variance_clipped = np.log(
             np.append(self.posterior_variance[1], self.posterior_variance[1:])
         )
-        self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+        self.posterior_mean_coef1 = self.betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
         self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
 
     def sample_noise(self, shape, device, generator=None):
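
Note on the `model_fn` added in `modeling_glide.py`: it runs the conditional and unconditional halves of the doubled batch through the UNet in a single pass and blends the two eps predictions. Below is a minimal standalone sketch of that guidance arithmetic; the shapes (batch of 2, 3 eps channels plus 3 learned-variance channels, 64x64) are illustrative assumptions, not values read from this patch.

import torch

guidance_scale = 3.0  # same value hard-coded in model_fn above

# Stand-in for self.unet(combined, ts, transformer_out): the first half of the
# batch holds text-conditional predictions, the second half unconditional ones;
# the first 3 channels are eps, the remaining channels the variance output.
model_out = torch.randn(2, 6, 64, 64)

eps, rest = model_out[:, :3], model_out[:, 3:]
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
# Push the prediction away from the unconditional eps and toward the
# conditional one; guidance_scale > 1 trades diversity for prompt fidelity.
half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
eps = torch.cat([half_eps, half_eps], dim=0)
guided = torch.cat([eps, rest], dim=1)
assert guided.shape == model_out.shape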
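
Note on the new `self.transformer_proj` head in `unet_glide.py`: it maps the text encoder's 512-dim hidden states to the time-embedding width (`model_channels * 4`) so that `emb = emb + transformer_proj.to(emb)` is well-typed. The line that pools `transformer_out` into the `transformer_proj` variable sits outside the hunks shown here; the sketch below assumes final-token pooling in the spirit of the original GLIDE `xf_proj`, and all shapes are hypothetical.

import torch
from torch import nn

# Hypothetical shapes: model_channels=192 matches the GLIDE base config only
# by assumption; the final-token pooling is likewise an assumption.
model_channels, transformer_dim = 192, 512
time_embed_dim = model_channels * 4
transformer_proj = nn.Linear(transformer_dim, time_embed_dim)

transformer_out = torch.randn(2, 128, transformer_dim)  # CLIP last_hidden_state [N, seq, 512]
pooled = transformer_proj(transformer_out[:, -1])       # project final token to [N, time_embed_dim]
emb = torch.randn(2, time_embed_dim) + pooled           # added to the timestep embedding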
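
Note on the scheduler change: promoting `betas` to `self.betas` keeps the schedule available to the posterior computations and to callers. `betas_for_alpha_bar` itself is not part of this diff; the sketch below assumes the standard improved-diffusion definition, where beta_t is derived from the ratio of consecutive alpha-bar values and capped at `max_beta`.

import math
import numpy as np

def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    # beta_t = 1 - alpha_bar(t+1) / alpha_bar(t), capped at max_beta so that
    # the final steps of the chain stay numerically stable.
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)

# The GLIDE cosine schedule from the diff: alpha_bar(t) = cos((t + 0.008) / 1.008 * pi / 2) ** 2
betas = betas_for_alpha_bar(1000, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2)
alphas_cumprod = np.cumprod(1.0 - betas, axis=0)
print(betas[0], betas[-1])  # tiny at t=0, approaches max_beta at t=T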