mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Merge remote-tracking branch 'origin/main'
This commit is contained in:
@@ -32,9 +32,6 @@ class PNDM(DiffusionPipeline):
|
||||
if torch_device is None:
|
||||
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
num_trained_timesteps = self.noise_scheduler.timesteps
|
||||
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
|
||||
|
||||
self.unet.to(torch_device)
|
||||
|
||||
# Sample gaussian noise to begin loop
|
||||
@@ -44,55 +41,22 @@ class PNDM(DiffusionPipeline):
|
||||
)
|
||||
image = image.to(torch_device)
|
||||
|
||||
seq = list(inference_step_times)
|
||||
seq_next = [-1] + list(seq[:-1])
|
||||
model = self.unet
|
||||
|
||||
warmup_time_steps = list(reversed([(t + 5) // 10 * 10 for t in range(seq[-4], seq[-1], 5)]))
|
||||
|
||||
cur_residual = 0
|
||||
warmup_time_steps = self.noise_scheduler.get_warmup_time_steps(num_inference_steps)
|
||||
prev_image = image
|
||||
ets = []
|
||||
for i in range(len(warmup_time_steps)):
|
||||
t = warmup_time_steps[i] * torch.ones(image.shape[0])
|
||||
t_next = (warmup_time_steps[i + 1] if i < len(warmup_time_steps) - 1 else warmup_time_steps[-1]) * torch.ones(image.shape[0])
|
||||
for t in tqdm.tqdm(range(len(warmup_time_steps))):
|
||||
t_orig = warmup_time_steps[t]
|
||||
residual = self.unet(image, t_orig)
|
||||
|
||||
residual = model(image.to("cuda"), t.to("cuda"))
|
||||
residual = residual.to("cpu")
|
||||
|
||||
if i % 4 == 0:
|
||||
cur_residual += 1 / 6 * residual
|
||||
ets.append(residual)
|
||||
if t % 4 == 0:
|
||||
prev_image = image
|
||||
elif (i - 1) % 4 == 0:
|
||||
cur_residual += 1 / 3 * residual
|
||||
elif (i - 2) % 4 == 0:
|
||||
cur_residual += 1 / 3 * residual
|
||||
elif (i - 3) % 4 == 0:
|
||||
cur_residual += 1 / 6 * residual
|
||||
residual = cur_residual
|
||||
cur_residual = 0
|
||||
|
||||
image = image.to("cpu")
|
||||
t_2 = warmup_time_steps[4 * (i // 4)] * torch.ones(image.shape[0])
|
||||
image = self.noise_scheduler.transfer(prev_image.to("cpu"), t_2, t_next, residual)
|
||||
image = self.noise_scheduler.step_warm_up(residual, prev_image, t, num_inference_steps)
|
||||
|
||||
step_idx = len(seq) - 4
|
||||
while step_idx >= 0:
|
||||
i = seq[step_idx]
|
||||
j = seq_next[step_idx]
|
||||
timesteps = self.noise_scheduler.get_time_steps(num_inference_steps)
|
||||
for t in tqdm.tqdm(range(len(timesteps))):
|
||||
t_orig = timesteps[t]
|
||||
residual = self.unet(image, t_orig)
|
||||
|
||||
t = (torch.ones(image.shape[0]) * i)
|
||||
t_next = (torch.ones(image.shape[0]) * j)
|
||||
|
||||
residual = model(image.to("cuda"), t.to("cuda"))
|
||||
residual = residual.to("cpu")
|
||||
ets.append(residual)
|
||||
residual = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])
|
||||
|
||||
img_next = self.noise_scheduler.transfer(image.to("cpu"), t, t_next, residual)
|
||||
image = img_next
|
||||
|
||||
step_idx = step_idx - 1
|
||||
image = self.noise_scheduler.step(residual, image, t, num_inference_steps)
|
||||
|
||||
return image
|
||||
|
||||
@@ -55,22 +55,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self.set_format(tensor_format=tensor_format)
|
||||
|
||||
# self.register_buffer("betas", betas.to(torch.float32))
|
||||
# self.register_buffer("alphas", alphas.to(torch.float32))
|
||||
# self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))
|
||||
# for now we only support F-PNDM, i.e. the runge-kutta method
|
||||
self.pndm_order = 4
|
||||
|
||||
# alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
|
||||
# TODO(PVP) - check how much of these is actually necessary!
|
||||
# LDM only uses "fixed_small"; glide seems to use a weird mix of the two, ...
|
||||
# https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246
|
||||
# variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
|
||||
# if variance_type == "fixed_small":
|
||||
# log_variance = torch.log(variance.clamp(min=1e-20))
|
||||
# elif variance_type == "fixed_large":
|
||||
# log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))
|
||||
#
|
||||
#
|
||||
# self.register_buffer("log_variance", log_variance.to(torch.float32))
|
||||
# running values
|
||||
self.cur_residual = 0
|
||||
self.ets = []
|
||||
self.warmup_time_steps = {}
|
||||
self.time_steps = {}
|
||||
|
||||
def get_alpha(self, time_step):
|
||||
return self.alphas[time_step]
|
||||
@@ -83,51 +75,62 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
return self.one
|
||||
return self.alphas_cumprod[time_step]
|
||||
|
||||
def step(self, img, t_start, t_end, model, ets):
|
||||
# img_next = self.method(img_n, t_start, t_end, model, self.alphas_cump, self.ets)
|
||||
#def gen_order_4(img, t, t_next, model, alphas_cump, ets):
|
||||
t_next, t = t_start, t_end
|
||||
def get_warmup_time_steps(self, num_inference_steps):
|
||||
if num_inference_steps in self.warmup_time_steps:
|
||||
return self.warmup_time_steps[num_inference_steps]
|
||||
|
||||
noise_ = model(img.to("cuda"), t.to("cuda"))
|
||||
noise_ = noise_.to("cpu")
|
||||
inference_step_times = list(range(0, self.timesteps, self.timesteps // num_inference_steps))
|
||||
|
||||
t_list = [t, (t+t_next)/2, t_next]
|
||||
if len(ets) > 2:
|
||||
ets.append(noise_)
|
||||
noise = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])
|
||||
else:
|
||||
noise = self.runge_kutta(img, t_list, model, ets, noise_)
|
||||
warmup_time_steps = np.array(inference_step_times[-self.pndm_order:]).repeat(2) + np.tile(np.array([0, self.timesteps // num_inference_steps // 2]), self.pndm_order)
|
||||
self.warmup_time_steps[num_inference_steps] = list(reversed(warmup_time_steps[:-1].repeat(2)[1:-1]))
|
||||
|
||||
img_next = self.transfer(img.to("cpu"), t, t_next, noise)
|
||||
return img_next, ets
|
||||
return self.warmup_time_steps[num_inference_steps]
|
||||
|
||||
def runge_kutta(self, x, t_list, model, ets, noise_):
|
||||
model = model.to("cuda")
|
||||
x = x.to("cpu")
|
||||
def get_time_steps(self, num_inference_steps):
|
||||
if num_inference_steps in self.time_steps:
|
||||
return self.time_steps[num_inference_steps]
|
||||
|
||||
e_1 = noise_
|
||||
ets.append(e_1)
|
||||
x_2 = self.transfer(x, t_list[0], t_list[1], e_1)
|
||||
inference_step_times = list(range(0, self.timesteps, self.timesteps // num_inference_steps))
|
||||
self.time_steps[num_inference_steps] = list(reversed(inference_step_times[:-3]))
|
||||
|
||||
e_2 = model(x_2.to("cuda"), t_list[1].to("cuda"))
|
||||
e_2 = e_2.to("cpu")
|
||||
x_3 = self.transfer(x, t_list[0], t_list[1], e_2)
|
||||
return self.time_steps[num_inference_steps]
|
||||
|
||||
e_3 = model(x_3.to("cuda"), t_list[1].to("cuda"))
|
||||
e_3 = e_3.to("cpu")
|
||||
x_4 = self.transfer(x, t_list[0], t_list[2], e_3)
|
||||
def step_warm_up(self, residual, image, t, num_inference_steps):
|
||||
warmup_time_steps = self.get_warmup_time_steps(num_inference_steps)
|
||||
|
||||
e_4 = model(x_4.to("cuda"), t_list[2].to("cuda"))
|
||||
e_4 = e_4.to("cpu")
|
||||
t_prev = warmup_time_steps[t // 4 * 4]
|
||||
t_next = warmup_time_steps[min(t + 1, len(warmup_time_steps) - 1)]
|
||||
|
||||
et = (1 / 6) * (e_1 + 2 * e_2 + 2 * e_3 + e_4)
|
||||
if t % 4 == 0:
|
||||
self.cur_residual += 1 / 6 * residual
|
||||
self.ets.append(residual)
|
||||
elif (t - 1) % 4 == 0:
|
||||
self.cur_residual += 1 / 3 * residual
|
||||
elif (t - 2) % 4 == 0:
|
||||
self.cur_residual += 1 / 3 * residual
|
||||
elif (t - 3) % 4 == 0:
|
||||
residual = self.cur_residual + 1 / 6 * residual
|
||||
self.cur_residual = 0
|
||||
|
||||
return et
|
||||
return self.transfer(image, t_prev, t_next, residual)
|
||||
|
||||
def step(self, residual, image, t, num_inference_steps):
|
||||
timesteps = self.get_time_steps(num_inference_steps)
|
||||
|
||||
t_prev = timesteps[t]
|
||||
t_next = timesteps[min(t + 1, len(timesteps) - 1)]
|
||||
self.ets.append(residual)
|
||||
|
||||
residual = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
|
||||
|
||||
return self.transfer(image, t_prev, t_next, residual)
|
||||
|
||||
def transfer(self, x, t, t_next, et):
|
||||
alphas_cump = self.alphas_cumprod
|
||||
at = alphas_cump[t.long() + 1].view(-1, 1, 1, 1)
|
||||
at_next = alphas_cump[t_next.long() + 1].view(-1, 1, 1, 1)
|
||||
# TODO(Patrick): clean up to be compatible with numpy and give better names
|
||||
|
||||
alphas_cump = self.alphas_cumprod.to(x.device)
|
||||
at = alphas_cump[t + 1].view(-1, 1, 1, 1)
|
||||
at_next = alphas_cump[t_next + 1].view(-1, 1, 1, 1)
|
||||
|
||||
x_delta = (at_next - at) * ((1 / (at.sqrt() * (at.sqrt() + at_next.sqrt()))) * x - 1 / (at.sqrt() * (((1 - at_next) * at).sqrt() + ((1 - at) * at_next).sqrt())) * et)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user