mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Glide tensor format
This commit is contained in:
@@ -97,7 +97,7 @@ superres_model = GLIDESuperResUNetModel(
|
||||
|
||||
superres_model.load_state_dict(ups_state_dict, strict=False)
|
||||
|
||||
upscale_scheduler = DDIMScheduler(timesteps=1000, beta_schedule="linear", beta_start=0.0001, beta_end=0.02)
|
||||
upscale_scheduler = DDIMScheduler(timesteps=1000, beta_schedule="linear", beta_start=0.0001, beta_end=0.02, tensor_format="pt")
|
||||
|
||||
glide = GLIDE(
|
||||
text_unet=text2im_model,
|
||||
|
||||
@@ -30,7 +30,6 @@ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPo
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.utils import (
|
||||
ModelOutput,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
@@ -860,6 +859,9 @@ class GLIDE(DiffusionPipeline):
|
||||
nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1))) # no noise when t == 0
|
||||
image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise
|
||||
|
||||
image = image[:1].permute(0, 2, 3, 1)
|
||||
return image
|
||||
|
||||
# 4. Run the upscaling step
|
||||
batch_size = 1
|
||||
image = image[:1]
|
||||
@@ -872,10 +874,10 @@ class GLIDE(DiffusionPipeline):
|
||||
|
||||
# Sample gaussian noise to begin loop
|
||||
image = torch.randn(
|
||||
(batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
|
||||
(batch_size, self.upscale_unet.in_channels // 2, self.upscale_unet.resolution, self.upscale_unet.resolution),
|
||||
generator=generator,
|
||||
)
|
||||
image = image.to(torch_device)
|
||||
image = image.to(torch_device) * upsample_temp
|
||||
|
||||
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
|
||||
# Ideally, read DDIM paper in-detail understanding
|
||||
@@ -887,10 +889,15 @@ class GLIDE(DiffusionPipeline):
|
||||
# - eta -> η
|
||||
# - pred_image_direction -> "direction pointingc to x_t"
|
||||
# - pred_prev_image -> "x_t-1"
|
||||
|
||||
num_trained_timesteps = self.upscale_noise_scheduler.timesteps
|
||||
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale)
|
||||
self.upscale_noise_scheduler.rescale_betas(num_inference_steps_upscale)
|
||||
|
||||
for t in tqdm.tqdm(reversed(range(num_inference_steps_upscale)), total=num_inference_steps_upscale):
|
||||
# 1. predict noise residual
|
||||
with torch.no_grad():
|
||||
time_input = torch.tensor([t] * image.shape[0], device=torch_device)
|
||||
time_input = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
|
||||
model_output = self.upscale_unet(image, time_input, low_res)
|
||||
noise_residual, pred_variance = torch.split(model_output, 3, dim=1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user