Mirror of https://github.com/huggingface/diffusers.git
Commit 07ffe73f79 (parent bb98a5b709) by anton-l, 2022-06-08 11:53:12 +02:00
11 changed files with 91 additions and 96 deletions

View File (GLIDE checkpoint conversion script)

@@ -1,9 +1,10 @@
import torch
from torch import nn
from transformers import CLIPTextConfig, GPT2Tokenizer
from diffusers import UNetGLIDEModel, ClassifierFreeGuidanceScheduler, CLIPTextModel
from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, UNetGLIDEModel
from modeling_glide import GLIDE
from transformers import CLIPTextConfig, GPT2Tokenizer
# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt
state_dict = torch.load("base.pt", map_location="cpu")
@@ -22,7 +23,7 @@ config = CLIPTextConfig(
)
model = CLIPTextModel(config).eval()
tokenizer = GPT2Tokenizer("./glide-base/vocab.json", "./glide-base/merges.txt", pad_token="<|endoftext|>")
#tokenizer.save_pretrained("./glide-base")
# tokenizer.save_pretrained("./glide-base")
hf_encoder = model.text_model
@@ -51,11 +52,11 @@ for layer_idx in range(config.num_hidden_layers):
hf_layer.mlp.fc2.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.weight"]
hf_layer.mlp.fc2.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.bias"]
#inputs = tokenizer(["an oil painting of a corgi", ""], padding="max_length", max_length=128, return_tensors="pt")
#with torch.no_grad():
# inputs = tokenizer(["an oil painting of a corgi", ""], padding="max_length", max_length=128, return_tensors="pt")
# with torch.no_grad():
# outputs = model(**inputs)
#model.save_pretrained("./glide-base")
# model.save_pretrained("./glide-base")
### Convert the UNet
@@ -80,4 +81,4 @@ scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squar
glide = GLIDE(unet=unet_model, noise_scheduler=scheduler, text_encoder=model, tokenizer=tokenizer)
glide.save_pretrained("./glide-base")
glide.save_pretrained("./glide-base")
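For orientation (not part of the diff): the script's last step registers the four converted components on the GLIDE pipeline and serializes them with save_pretrained. A minimal sketch of how the resulting directory might be reloaded, assuming the DiffusionPipeline base class exposes a from_pretrained counterpart to save_pretrained (as later diffusers versions do):

# Hypothetical usage sketch; from_pretrained mirroring save_pretrained is an assumption.
from modeling_glide import GLIDE

glide = GLIDE.from_pretrained("./glide-base")
# register_modules() makes the converted components available as attributes:
print(type(glide.unet), type(glide.noise_scheduler), type(glide.text_encoder), type(glide.tokenizer))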

View File (modeling_glide.py)

@@ -14,12 +14,12 @@
# limitations under the License.
from diffusers import DiffusionPipeline, UNetGLIDEModel, ClassifierFreeGuidanceScheduler, CLIPTextModel
from transformers import GPT2Tokenizer
import numpy as np
import torch
import tqdm
import torch
import numpy as np
from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, DiffusionPipeline, UNetGLIDEModel
from transformers import GPT2Tokenizer
def _extract_into_tensor(arr, timesteps, broadcast_shape):
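The body of _extract_into_tensor is cut off by the hunk boundary. For reference, a sketch of the standard guided-diffusion helper this signature matches (a reconstruction, not necessarily the exact code in this file): it gathers per-timestep coefficients from a 1-D numpy array and broadcasts them over a batch of images.

import torch

def _extract_into_tensor(arr, timesteps, broadcast_shape):
    # arr: 1-D numpy array of per-timestep coefficients (e.g. posterior variances)
    # timesteps: integer tensor with one timestep index per batch element
    # broadcast_shape: target shape, e.g. (batch, channels, height, width)
    res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]  # append singleton dims until broadcastable
    return res.expand(broadcast_shape)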
@@ -40,14 +40,16 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
class GLIDE(DiffusionPipeline):
def __init__(
self,
unet: UNetGLIDEModel,
noise_scheduler: ClassifierFreeGuidanceScheduler,
text_encoder: CLIPTextModel,
tokenizer: GPT2Tokenizer
self,
unet: UNetGLIDEModel,
noise_scheduler: ClassifierFreeGuidanceScheduler,
text_encoder: CLIPTextModel,
tokenizer: GPT2Tokenizer,
):
super().__init__()
self.register_modules(unet=unet, noise_scheduler=noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer)
self.register_modules(
unet=unet, noise_scheduler=noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer
)
def q_posterior_mean_variance(self, x_start, x_t, t):
"""
@@ -129,7 +131,9 @@ class GLIDE(DiffusionPipeline):
self.text_encoder.to(torch_device)
# 1. Sample gaussian noise
image = self.noise_scheduler.sample_noise((1, self.unet.in_channels, 64, 64), device=torch_device, generator=generator)
image = self.noise_scheduler.sample_noise(
(1, self.unet.in_channels, 64, 64), device=torch_device, generator=generator
)
# 2. Encode tokens
# an empty input is needed to guide the model away from (
@@ -141,9 +145,7 @@ class GLIDE(DiffusionPipeline):
t = torch.tensor([i] * image.shape[0], device=torch_device)
mean, variance, log_variance, pred_xstart = self.p_mean_variance(self.unet, transformer_out, image, t)
noise = self.noise_scheduler.sample_noise(image.shape)
nonzero_mask = (
(t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))
) # no noise when t == 0
nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1))) # no noise when t == 0
image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise
return image
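The reformatted nonzero_mask line above is the last step of ancestral sampling: variance-scaled noise is added everywhere except at t == 0, where the mean is returned unchanged. A small self-contained illustration with dummy tensors (not code from the diff):

import torch

t = torch.tensor([3, 0])                 # two batch elements; the second is at the final step
image = torch.randn(2, 3, 64, 64)
mean = torch.zeros_like(image)
log_variance = torch.zeros_like(image)
noise = torch.randn_like(image)

# 1.0 where t != 0, reshaped to (batch, 1, 1, 1) so it broadcasts over the image dims
nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))
prev_image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise
assert torch.equal(prev_image[1], mean[1])  # no noise is added where t == 0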

View File (sampling script)

@@ -1,6 +1,8 @@
import torch
from modeling_glide import GLIDE
generator = torch.Generator()
generator = generator.manual_seed(0)