Merge branch 'main' of https://github.com/huggingface/diffusers into main
@@ -34,49 +34,68 @@ class DDIM(DiffusionPipeline):
        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)

        self.unet.to(torch_device)
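For intuition, the stride-based subsampling above visits every `(num_trained_timesteps // num_inference_steps)`-th training step. A quick illustration with made-up step counts (the values are not taken from the diff):

```python
# Illustrative only: which training timesteps a sub-sampled DDIM schedule visits.
num_trained_timesteps = 1000   # assumed length of the training schedule
num_inference_steps = 50       # assumed number of sampling steps

inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
steps = list(inference_step_times)
print(len(steps), steps[:5], "...", steps[-1])   # 50 [0, 20, 40, 60, 80] ... 980
```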
        # Sample gaussian noise to begin loop
        image = self.noise_scheduler.sample_noise(
            (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
            device=torch_device,
            generator=generator,
        )
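The `sample_noise` call above draws the starting sample x_T from a standard Gaussian of the UNet's image shape. A hypothetical stand-alone equivalent, with illustrative shape values that are not part of the diff:

```python
import torch

# Hypothetical equivalent of noise_scheduler.sample_noise(...) above:
# draw x_T ~ N(0, I) with the UNet's (batch, channels, height, width) shape.
batch_size, in_channels, resolution = 1, 3, 256   # illustrative values
generator = torch.manual_seed(0)
image = torch.randn(
    (batch_size, in_channels, resolution, resolution),
    generator=generator,
)
```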
        # See formulas (9), (10) and (7) of the DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read the DDIM paper in detail to understand the steps below

        # Notation (<variable name> -> <name in paper>)
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_image -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_image_direction -> "direction pointing to x_t"
        # - pred_prev_image -> "x_t-1"
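Written out, the update these names refer to is equation (12) of the DDIM paper, with the variance σ_t from equation (16). Note that α_t in the DDIM paper denotes the cumulative product (alpha_prod_t here), i.e. ᾱ_t in DDPM notation:

```latex
x_{t-1} = \sqrt{\alpha_{t-1}}\,
          \underbrace{\left(\frac{x_t - \sqrt{1-\alpha_t}\,\epsilon_\theta(x_t,t)}{\sqrt{\alpha_t}}\right)}_{\text{pred\_original\_image}}
        + \underbrace{\sqrt{1-\alpha_{t-1}-\sigma_t^2}\;\epsilon_\theta(x_t,t)}_{\text{pred\_image\_direction}}
        + \underbrace{\sigma_t\,\epsilon_t}_{\text{random noise}},
\qquad
\sigma_t(\eta) = \eta\,\sqrt{\frac{1-\alpha_{t-1}}{1-\alpha_t}}\,\sqrt{1-\frac{\alpha_t}{\alpha_{t-1}}}
```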
        for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
            # get actual t and t-1
            # 1. predict noise residual
            with torch.no_grad():
                pred_noise_t = self.unet(image, inference_step_times[t])

            # 2. get actual t and t-1
            train_step = inference_step_times[t]
            prev_train_step = inference_step_times[t - 1] if t > 0 else -1
            # compute alphas
            # 3. compute alphas, betas
            alpha_prod_t = self.noise_scheduler.get_alpha_prod(train_step)
            alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(prev_train_step)
            alpha_prod_t_rsqrt = 1 / alpha_prod_t.sqrt()
            alpha_prod_t_prev_rsqrt = 1 / alpha_prod_t_prev.sqrt()
            beta_prod_t_sqrt = (1 - alpha_prod_t).sqrt()
            beta_prod_t_prev_sqrt = (1 - alpha_prod_t_prev).sqrt()
            beta_prod_t = (1 - alpha_prod_t)
            beta_prod_t_prev = (1 - alpha_prod_t_prev)

            # compute relevant coefficients
            coeff_1 = (
                (alpha_prod_t_prev - alpha_prod_t).sqrt()
                * alpha_prod_t_prev_rsqrt
                * beta_prod_t_prev_sqrt
                / beta_prod_t_sqrt
                * eta
            )
            coeff_2 = ((1 - alpha_prod_t_prev) - coeff_1**2).sqrt()
            # 4. Compute predicted previous image from predicted noise

            # model forward
            with torch.no_grad():
                noise_residual = self.unet(image, train_step)

            # First: compute predicted original image from predicted noise also called
            # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
            pred_original_image = (image - beta_prod_t.sqrt() * pred_noise_t) / alpha_prod_t.sqrt()

            # predict mean of prev image
            pred_mean = alpha_prod_t_rsqrt * (image - beta_prod_t_sqrt * noise_residual)
            pred_mean = torch.clamp(pred_mean, -1, 1)
            pred_mean = (1 / alpha_prod_t_prev_rsqrt) * pred_mean + coeff_2 * noise_residual

            # Second: Clip "predicted x_0"
            pred_original_image = torch.clamp(pred_original_image, -1, 1)

            # if eta > 0.0 add noise. Note eta = 1.0 essentially corresponds to DDPM
            # Third: Compute variance: "sigma_t(η)" -> see formula (16)
            # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
            std_dev_t = (beta_prod_t_prev / beta_prod_t).sqrt() * (1 - alpha_prod_t / alpha_prod_t_prev).sqrt()
            std_dev_t = eta * std_dev_t

            # Fourth: Compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
            pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2).sqrt() * pred_noise_t

            # Fifth: Compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
            pred_prev_image = alpha_prod_t_prev.sqrt() * pred_original_image + pred_image_direction
            # 5. Sample x_t-1 image optionally if η > 0.0 by adding noise to pred_prev_image
            # Note: eta = 1.0 essentially corresponds to DDPM
            if eta > 0.0:
                noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
                image = pred_mean + coeff_1 * noise
                prev_image = pred_prev_image + std_dev_t * noise
            else:
                image = pred_mean
                prev_image = pred_prev_image

            # 6. Set current image to prev_image: x_t -> x_t-1
            image = prev_image

        return image
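Outside the pipeline, the same update can be written as a standalone function. A minimal sketch under assumptions: the function name, argument layout, and the idea of passing the cumulative alpha products as scalar tensors are illustrative, not part of the diff:

```python
import torch

def ddim_step(x_t, pred_noise, alpha_prod_t, alpha_prod_t_prev, eta=0.0, generator=None):
    """One DDIM update x_t -> x_{t-1}, following eq. (12) of https://arxiv.org/pdf/2010.02502.pdf.

    alpha_prod_t and alpha_prod_t_prev are the cumulative alpha products at the
    current and previous (sub-sampled) timesteps, given as scalar tensors.
    """
    # predicted x_0 (eq. (9)), clipped to the image range as in the pipeline above
    pred_original = (x_t - (1 - alpha_prod_t).sqrt() * pred_noise) / alpha_prod_t.sqrt()
    pred_original = pred_original.clamp(-1, 1)

    # sigma_t from eq. (16); eta = 0 gives deterministic DDIM, eta = 1 is DDPM-like
    std_dev_t = eta * ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)).sqrt() * (1 - alpha_prod_t / alpha_prod_t_prev).sqrt()

    # "direction pointing to x_t"
    direction = (1 - alpha_prod_t_prev - std_dev_t**2).sqrt() * pred_noise

    prev_sample = alpha_prod_t_prev.sqrt() * pred_original + direction
    if eta > 0.0:
        noise = torch.randn(x_t.shape, generator=generator, device=x_t.device)
        prev_sample = prev_sample + std_dev_t * noise
    return prev_sample
```

Called in a loop over the sub-sampled timesteps from newest to oldest, with `alpha_prod_t_prev` taken from the previous entry of the schedule, this reproduces the update performed inside the pipeline's loop.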
@@ -30,43 +30,43 @@ class DDPM(DiffusionPipeline):
        torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        self.unet.to(torch_device)
        # 1. Sample gaussian noise

        # Sample gaussian noise to begin loop
        image = self.noise_scheduler.sample_noise(
            (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
            device=torch_device,
            generator=generator,
        )
        for t in tqdm.tqdm(reversed(range(len(self.noise_scheduler))), total=len(self.noise_scheduler)):
            # i) define coefficients for time step t
            clipped_image_coeff = 1 / torch.sqrt(self.noise_scheduler.get_alpha_prod(t))
            clipped_noise_coeff = torch.sqrt(1 / self.noise_scheduler.get_alpha_prod(t) - 1)
            image_coeff = (
                (1 - self.noise_scheduler.get_alpha_prod(t - 1))
                * torch.sqrt(self.noise_scheduler.get_alpha(t))
                / (1 - self.noise_scheduler.get_alpha_prod(t))
            )
            clipped_coeff = (
                torch.sqrt(self.noise_scheduler.get_alpha_prod(t - 1))
                * self.noise_scheduler.get_beta(t)
                / (1 - self.noise_scheduler.get_alpha_prod(t))
            )

            # ii) predict noise residual
        for t in tqdm.tqdm(reversed(range(len(self.noise_scheduler))), total=len(self.noise_scheduler)):
            # 1. predict noise residual
            with torch.no_grad():
                noise_residual = self.unet(image, t)

            # iii) compute predicted image from residual
            # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
            pred_mean = clipped_image_coeff * image - clipped_noise_coeff * noise_residual
            pred_mean = torch.clamp(pred_mean, -1, 1)
            prev_image = clipped_coeff * pred_mean + image_coeff * image

            # 2. compute alphas, betas
            alpha_prod_t = self.noise_scheduler.get_alpha_prod(t)
            alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(t - 1)
            beta_prod_t = 1 - alpha_prod_t
            beta_prod_t_prev = 1 - alpha_prod_t_prev

            # iv) sample variance
            # 3. compute predicted image from residual
            # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
            # First: Compute inner formula
            pred_mean = (1 / alpha_prod_t.sqrt()) * (image - beta_prod_t.sqrt() * noise_residual)
            # Second: Clip
            pred_mean = torch.clamp(pred_mean, -1, 1)
            # Third: Compute outer coefficients
            pred_mean_coeff = (alpha_prod_t_prev.sqrt() * self.noise_scheduler.get_beta(t)) / beta_prod_t
            image_coeff = (beta_prod_t_prev * self.noise_scheduler.get_alpha(t).sqrt()) / beta_prod_t
            # Fourth: Compute outer formula
            prev_image = pred_mean_coeff * pred_mean + image_coeff * image

            # 4. sample variance
            prev_variance = self.noise_scheduler.sample_variance(
                t, prev_image.shape, device=torch_device, generator=generator
            )

            # v) sample x_{t-1} ~ N(prev_image, prev_variance)
            # 5. sample x_{t-1} ~ N(prev_image, prev_variance) = add variance to predicted image
            sampled_prev_image = prev_image + prev_variance
            image = sampled_prev_image
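For reference, the rewritten loop computes the posterior mean μ̃_t of q(x_{t-1} | x_t, x_0) from equation (7) of the DDPM paper (https://arxiv.org/abs/2006.11239), using the clipped x̂_0 predicted from the noise residual:

```latex
\hat{x}_0 = \frac{x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon_\theta(x_t,t)}{\sqrt{\bar{\alpha}_t}},
\qquad
\tilde{\mu}_t(x_t, \hat{x}_0)
  = \underbrace{\frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1-\bar{\alpha}_t}}_{\text{pred\_mean\_coeff}}\hat{x}_0
  + \underbrace{\frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}}_{\text{image\_coeff}}\,x_t
```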
@@ -40,6 +40,7 @@ LOADABLE_CLASSES = {
    },
    "transformers": {
        "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
        "PreTrainedModel": ["save_pretrained", "from_pretrained"],
    },
}
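Each entry maps a loadable base class to a `[save_method, load_method]` pair. The `save_pretrained` change in the next hunk walks this mapping to find the right save method for every pipeline component; a self-contained sketch of that lookup (the standalone function and its name are illustrative, not part of the diff):

```python
import importlib

def resolve_save_method_name(sub_model, loadable_classes):
    # Walk the mapping until we find a base class the component subclasses,
    # then return the first entry of its ["save", "load"] method-name pair.
    model_cls = sub_model.__class__
    for library_name, library_classes in loadable_classes.items():
        library = importlib.import_module(library_name)
        for base_class, save_load_methods in library_classes.items():
            class_candidate = getattr(library, base_class)
            if issubclass(model_cls, class_candidate):
                return save_load_methods[0]
    return None
```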
@@ -82,24 +83,25 @@ class DiffusionPipeline(ConfigMixin):
        model_index_dict.pop("_diffusers_version")
        model_index_dict.pop("_module")

        for name, (library_name, class_name) in model_index_dict.items():
            importable_classes = LOADABLE_CLASSES[library_name]

            # TODO: Suraj
            if library_name == self.__module__:
                library_name = self

            library = importlib.import_module(library_name)
            class_obj = getattr(library, class_name)
            class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
        for pipeline_component_name in model_index_dict.keys():
            sub_model = getattr(self, pipeline_component_name)
            model_cls = sub_model.__class__

            save_method_name = None
            for class_name, class_candidate in class_candidates.items():
                if issubclass(class_obj, class_candidate):
                    save_method_name = importable_classes[class_name][0]
            # search for the model's base class in LOADABLE_CLASSES
            for library_name, library_classes in LOADABLE_CLASSES.items():
                library = importlib.import_module(library_name)
                for base_class, save_load_methods in library_classes.items():
                    class_candidate = getattr(library, base_class)
                    if issubclass(model_cls, class_candidate):
                        # if we found a suitable base class in LOADABLE_CLASSES then grab its save method
                        save_method_name = save_load_methods[0]
                        break
                if save_method_name is not None:
                    break

            save_method = getattr(getattr(self, name), save_method_name)
            save_method(os.path.join(save_directory, name))
            save_method = getattr(sub_model, save_method_name)
            save_method(os.path.join(save_directory, pipeline_component_name))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
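A hedged example of the round trip this enables, assuming `pipeline` is an already-constructed DiffusionPipeline subclass (such as the DDPM pipeline above) and using an arbitrary local path:

```python
# Each component (unet, noise_scheduler, tokenizer, ...) is written to its own
# subfolder via the save method resolved from LOADABLE_CLASSES, together with the
# pipeline's configuration; from_pretrained then rebuilds the pipeline from that folder.
pipeline.save_pretrained("./my-pipeline")
restored = pipeline.__class__.from_pretrained("./my-pipeline")
```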