1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Merge remote-tracking branch 'origin/main'

This commit is contained in:
anton-l
2022-06-14 12:37:38 +02:00
2 changed files with 62 additions and 95 deletions

View File

@@ -32,9 +32,6 @@ class PNDM(DiffusionPipeline):
if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
num_trained_timesteps = self.noise_scheduler.timesteps
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
self.unet.to(torch_device)
# Sample gaussian noise to begin loop
@@ -44,55 +41,22 @@ class PNDM(DiffusionPipeline):
)
image = image.to(torch_device)
seq = list(inference_step_times)
seq_next = [-1] + list(seq[:-1])
model = self.unet
warmup_time_steps = list(reversed([(t + 5) // 10 * 10 for t in range(seq[-4], seq[-1], 5)]))
cur_residual = 0
warmup_time_steps = self.noise_scheduler.get_warmup_time_steps(num_inference_steps)
prev_image = image
ets = []
for i in range(len(warmup_time_steps)):
t = warmup_time_steps[i] * torch.ones(image.shape[0])
t_next = (warmup_time_steps[i + 1] if i < len(warmup_time_steps) - 1 else warmup_time_steps[-1]) * torch.ones(image.shape[0])
for t in tqdm.tqdm(range(len(warmup_time_steps))):
t_orig = warmup_time_steps[t]
residual = self.unet(image, t_orig)
residual = model(image.to("cuda"), t.to("cuda"))
residual = residual.to("cpu")
if i % 4 == 0:
cur_residual += 1 / 6 * residual
ets.append(residual)
if t % 4 == 0:
prev_image = image
elif (i - 1) % 4 == 0:
cur_residual += 1 / 3 * residual
elif (i - 2) % 4 == 0:
cur_residual += 1 / 3 * residual
elif (i - 3) % 4 == 0:
cur_residual += 1 / 6 * residual
residual = cur_residual
cur_residual = 0
image = image.to("cpu")
t_2 = warmup_time_steps[4 * (i // 4)] * torch.ones(image.shape[0])
image = self.noise_scheduler.transfer(prev_image.to("cpu"), t_2, t_next, residual)
image = self.noise_scheduler.step_warm_up(residual, prev_image, t, num_inference_steps)
step_idx = len(seq) - 4
while step_idx >= 0:
i = seq[step_idx]
j = seq_next[step_idx]
timesteps = self.noise_scheduler.get_time_steps(num_inference_steps)
for t in tqdm.tqdm(range(len(timesteps))):
t_orig = timesteps[t]
residual = self.unet(image, t_orig)
t = (torch.ones(image.shape[0]) * i)
t_next = (torch.ones(image.shape[0]) * j)
residual = model(image.to("cuda"), t.to("cuda"))
residual = residual.to("cpu")
ets.append(residual)
residual = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])
img_next = self.noise_scheduler.transfer(image.to("cpu"), t, t_next, residual)
image = img_next
step_idx = step_idx - 1
image = self.noise_scheduler.step(residual, image, t, num_inference_steps)
return image

View File

@@ -55,22 +55,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
self.set_format(tensor_format=tensor_format)
# self.register_buffer("betas", betas.to(torch.float32))
# self.register_buffer("alphas", alphas.to(torch.float32))
# self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))
# for now we only support F-PNDM, i.e. the runge-kutta method
self.pndm_order = 4
# alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
# TODO(PVP) - check how much of these is actually necessary!
# LDM only uses "fixed_small"; glide seems to use a weird mix of the two, ...
# https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246
# variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
# if variance_type == "fixed_small":
# log_variance = torch.log(variance.clamp(min=1e-20))
# elif variance_type == "fixed_large":
# log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))
#
#
# self.register_buffer("log_variance", log_variance.to(torch.float32))
# running values
self.cur_residual = 0
self.ets = []
self.warmup_time_steps = {}
self.time_steps = {}
def get_alpha(self, time_step):
return self.alphas[time_step]
@@ -83,51 +75,62 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
return self.one
return self.alphas_cumprod[time_step]
def step(self, img, t_start, t_end, model, ets):
# img_next = self.method(img_n, t_start, t_end, model, self.alphas_cump, self.ets)
#def gen_order_4(img, t, t_next, model, alphas_cump, ets):
t_next, t = t_start, t_end
def get_warmup_time_steps(self, num_inference_steps):
if num_inference_steps in self.warmup_time_steps:
return self.warmup_time_steps[num_inference_steps]
noise_ = model(img.to("cuda"), t.to("cuda"))
noise_ = noise_.to("cpu")
inference_step_times = list(range(0, self.timesteps, self.timesteps // num_inference_steps))
t_list = [t, (t+t_next)/2, t_next]
if len(ets) > 2:
ets.append(noise_)
noise = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])
else:
noise = self.runge_kutta(img, t_list, model, ets, noise_)
warmup_time_steps = np.array(inference_step_times[-self.pndm_order:]).repeat(2) + np.tile(np.array([0, self.timesteps // num_inference_steps // 2]), self.pndm_order)
self.warmup_time_steps[num_inference_steps] = list(reversed(warmup_time_steps[:-1].repeat(2)[1:-1]))
img_next = self.transfer(img.to("cpu"), t, t_next, noise)
return img_next, ets
return self.warmup_time_steps[num_inference_steps]
def runge_kutta(self, x, t_list, model, ets, noise_):
model = model.to("cuda")
x = x.to("cpu")
def get_time_steps(self, num_inference_steps):
if num_inference_steps in self.time_steps:
return self.time_steps[num_inference_steps]
e_1 = noise_
ets.append(e_1)
x_2 = self.transfer(x, t_list[0], t_list[1], e_1)
inference_step_times = list(range(0, self.timesteps, self.timesteps // num_inference_steps))
self.time_steps[num_inference_steps] = list(reversed(inference_step_times[:-3]))
e_2 = model(x_2.to("cuda"), t_list[1].to("cuda"))
e_2 = e_2.to("cpu")
x_3 = self.transfer(x, t_list[0], t_list[1], e_2)
return self.time_steps[num_inference_steps]
e_3 = model(x_3.to("cuda"), t_list[1].to("cuda"))
e_3 = e_3.to("cpu")
x_4 = self.transfer(x, t_list[0], t_list[2], e_3)
def step_warm_up(self, residual, image, t, num_inference_steps):
warmup_time_steps = self.get_warmup_time_steps(num_inference_steps)
e_4 = model(x_4.to("cuda"), t_list[2].to("cuda"))
e_4 = e_4.to("cpu")
t_prev = warmup_time_steps[t // 4 * 4]
t_next = warmup_time_steps[min(t + 1, len(warmup_time_steps) - 1)]
et = (1 / 6) * (e_1 + 2 * e_2 + 2 * e_3 + e_4)
if t % 4 == 0:
self.cur_residual += 1 / 6 * residual
self.ets.append(residual)
elif (t - 1) % 4 == 0:
self.cur_residual += 1 / 3 * residual
elif (t - 2) % 4 == 0:
self.cur_residual += 1 / 3 * residual
elif (t - 3) % 4 == 0:
residual = self.cur_residual + 1 / 6 * residual
self.cur_residual = 0
return et
return self.transfer(image, t_prev, t_next, residual)
def step(self, residual, image, t, num_inference_steps):
timesteps = self.get_time_steps(num_inference_steps)
t_prev = timesteps[t]
t_next = timesteps[min(t + 1, len(timesteps) - 1)]
self.ets.append(residual)
residual = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
return self.transfer(image, t_prev, t_next, residual)
def transfer(self, x, t, t_next, et):
alphas_cump = self.alphas_cumprod
at = alphas_cump[t.long() + 1].view(-1, 1, 1, 1)
at_next = alphas_cump[t_next.long() + 1].view(-1, 1, 1, 1)
# TODO(Patrick): clean up to be compatible with numpy and give better names
alphas_cump = self.alphas_cumprod.to(x.device)
at = alphas_cump[t + 1].view(-1, 1, 1, 1)
at_next = alphas_cump[t_next + 1].view(-1, 1, 1, 1)
x_delta = (at_next - at) * ((1 / (at.sqrt() * (at.sqrt() + at_next.sqrt()))) * x - 1 / (at.sqrt() * (((1 - at_next) * at).sqrt() + ((1 - at) * at_next).sqrt())) * et)