Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
adapt run.py
run.py (110 lines changed)
@@ -11,6 +11,8 @@ from models import ddpm as ddpm_model
from models import layerspp
from models import layers
from models import normalization
from models.ema import ExponentialMovingAverage
from losses import get_optimizer

from utils import restore_checkpoint

@@ -27,6 +29,7 @@ import datasets
import torch


torch.backends.cuda.matmul.allow_tf32 = False
torch.manual_seed(0)


@@ -81,7 +84,6 @@ torch.manual_seed(0)


class NewReverseDiffusionPredictor:

    def __init__(self, sde, score_fn, probability_flow=False):
        super().__init__()
        self.sde = sde
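The diff truncates the class here, showing only the constructor. For orientation, here is a minimal sketch of the reverse-diffusion update step in the style of the score_sde codebase this script is adapted from; the `sde.discretize(x, t)` API (returning drift `f` and diffusion `G`) is assumed from that repo, not shown in this diff:

```python
import torch

def reverse_diffusion_update(sde, score_fn, x, t, probability_flow=False):
    # Discretization of the forward SDE: x_{i+1} = x_i + f_i(x_i) + G_i z_i.
    f, G = sde.discretize(x, t)
    score = score_fn(x, t)
    # The reverse-time drift adds a score term; the probability-flow ODE
    # halves the score contribution and removes the stochastic part entirely.
    rev_f = f - G[:, None, None, None] ** 2 * score * (0.5 if probability_flow else 1.0)
    rev_G = torch.zeros_like(G) if probability_flow else G
    z = torch.randn_like(x)
    x_mean = x - rev_f                            # deterministic part of the step
    x = x_mean + rev_G[:, None, None, None] * z   # noise injection (skipped for ODE)
    return x, x_mean
```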
@@ -112,7 +114,6 @@ class NewReverseDiffusionPredictor:


class NewLangevinCorrector:

    def __init__(self, sde, score_fn, snr, n_steps):
        super().__init__()
        self.sde = sde
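The corrector body is likewise elided. A minimal sketch of one Langevin MCMC step in the style of score_sde, with the step size derived from the target signal-to-noise ratio `snr` (for VE SDEs the timestep-dependent `alpha` factor is 1, which is what this simplified version assumes):

```python
import torch

def langevin_step(score_fn, x, t, snr):
    grad = score_fn(x, t)                  # estimate of grad log p_t(x)
    noise = torch.randn_like(x)
    grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
    noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
    step_size = (snr * noise_norm / grad_norm) ** 2 * 2   # alpha = 1 for VE SDEs
    x_mean = x + step_size * grad
    x = x_mean + noise * torch.sqrt(step_size * 2)
    return x, x_mean
```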
@@ -146,28 +147,19 @@ class NewLangevinCorrector:


def save_image(x):
    # image_processed = x.cpu().permute(0, 2, 3, 1)
    # image_processed = (image_processed + 1.0) * 127.5
    # image_processed = image_processed.numpy().astype(np.uint8)
    image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8)
    image_pil = PIL.Image.fromarray(image_processed[0])

    # 6. save image
    image_pil.save("../images/hey.png")


#x = np.load("cifar10.npy")
#
#save_image(x)
# @title Load the score-based model
sde = 'VESDE' #@param ['VESDE', 'VPSDE', 'subVPSDE'] {"type": "string"}
if sde.lower() == 'vesde':
    from configs.ve import cifar10_ncsnpp_continuous as configs
    ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
    # from configs.ve import ffhq_ncsnpp_continuous as configs
    # ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
    # from configs.ve import cifar10_ncsnpp_continuous as configs
    # ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
    from configs.ve import ffhq_ncsnpp_continuous as configs
    ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
    config = configs.get_config()
    config.model.num_scales = 1000
    config.model.num_scales = 2
    sde = VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales)
    sampling_eps = 1e-5
elif sde.lower() == 'vpsde':
@@ -189,32 +181,53 @@ config.eval.batch_size = batch_size

random_seed = 0 #@param {"type": "integer"}

score_model = mutils.create_model(config)
#sigmas = mutils.get_sigmas(config)
#scaler = datasets.get_data_scaler(config)
#inverse_scaler = datasets.get_data_inverse_scaler(config)
#score_model = mutils.create_model(config)
#
#optimizer = get_optimizer(config, score_model.parameters())
#ema = ExponentialMovingAverage(score_model.parameters(),
#                               decay=config.model.ema_rate)
#state = dict(step=0, optimizer=optimizer,
#             model=score_model, ema=ema)
#
#state = restore_checkpoint(ckpt_filename, state, config.device)
#ema.copy_to(score_model.parameters())

loaded_state = torch.load(ckpt_filename)
score_model.load_state_dict(loaded_state["model"], strict=False)
#score_model = mutils.create_model(config)

from diffusers import NCSNpp
score_model = NCSNpp(config).to(config.device)
score_model = torch.nn.DataParallel(score_model)

loaded_state = torch.load("./ffhq_1024_ncsnpp_continuous_ema.pt")
# Keys carry a "module." prefix because the checkpoint was saved from a
# DataParallel wrapper; the stored sigmas buffer is dropped, and
# strict=False tolerates any remaining key mismatches.
del loaded_state["module.sigmas"]
score_model.load_state_dict(loaded_state, strict=False)

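Wrapping in DataParallel above conveniently makes the module's parameter names match the checkpoint's "module."-prefixed keys. A common alternative (hypothetical here, not what this commit does) is to strip the prefix and load into the bare module:

```python
# Hypothetical alternative: load into the bare module without DataParallel
# by stripping the "module." prefix from every checkpoint key.
bare_model = NCSNpp(config).to(config.device)
state = {k[len("module."):]: v for k, v in loaded_state.items()}
state.pop("sigmas", None)  # drop the stored sigmas buffer, mirroring the del above
bare_model.load_state_dict(state, strict=False)
```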
inverse_scaler = datasets.get_data_inverse_scaler(config)
predictor = ReverseDiffusionPredictor #@param ["EulerMaruyamaPredictor", "AncestralSamplingPredictor", "ReverseDiffusionPredictor", "None"] {"type": "raw"}
corrector = LangevinCorrector #@param ["LangevinCorrector", "AnnealedLangevinDynamics", "None"] {"type": "raw"}

def image_grid(x):
    size = config.data.image_size
    channels = config.data.num_channels
    img = x.reshape(-1, size, size, channels)
    w = int(np.sqrt(img.shape[0]))
    img = img.reshape((w, w, size, size, channels)).transpose((0, 2, 1, 3, 4)).reshape((w * size, w * size, channels))
    return img

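image_grid assumes a square batch (batch_size = w * w) laid out NHWC after the first reshape. A hypothetical use of it to save the whole batch rather than just sample 0, as save_image does:

```python
# Convert NCHW torch output in [0, 1] to NHWC uint8, tile into a w-by-w
# grid, and save one PNG (path is illustrative).
batch = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8)
PIL.Image.fromarray(image_grid(batch)).save("../images/grid.png")
```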
#@title PC sampling
img_size = config.data.image_size
channels = config.data.num_channels
shape = (batch_size, channels, img_size, img_size)
probability_flow = False
snr = 0.16 #@param {"type": "number"}
snr = 0.15 #@param {"type": "number"}
n_steps = 1 #@param {"type": "integer"}


#sampling_fn = sampling.get_pc_sampler(sde, shape, predictor, corrector,
#                                      inverse_scaler, snr, n_steps=n_steps,
#                                      probability_flow=probability_flow,
#                                      continuous=config.training.continuous,
#                                      eps=sampling_eps, device=config.device)
#
#x, n = sampling_fn(score_model)
#save_image(x)


def shared_predictor_update_fn(x, t, sde, model, predictor, probability_flow, continuous):
    """A wrapper that configures and returns the update function of predictors."""
    score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous)
@@ -253,14 +266,14 @@ corrector_update_fn = functools.partial(shared_corrector_update_fn,
                                        snr=snr,
                                        n_steps=n_steps)

device = "cuda"
model = score_model.to(device)
denoise = False
device = config.device
model = score_model
denoise = True

new_corrector = NewLangevinCorrector(sde=sde, score_fn=model, snr=snr, n_steps=n_steps)
new_predictor = NewReverseDiffusionPredictor(sde=sde, score_fn=model)


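The loop below indexes into `timesteps`, which is defined in a part of the file this diff does not show; in score_sde's PC sampler it is a decreasing time grid along these lines:

```python
# Assumed from score_sde (not visible in this diff): sde.N time points
# descending from sde.T to the sampling cutoff sampling_eps.
timesteps = torch.linspace(sde.T, sampling_eps, sde.N, device=device)
```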
#
with torch.no_grad():
    # Initial sample
    x = sde.prior_sampling(shape).to(device)
@@ -269,21 +282,32 @@ with torch.no_grad():
    for i in range(sde.N):
        t = timesteps[i]
        vec_t = torch.ones(shape[0], device=t.device) * t
        x, x_mean = corrector_update_fn(x, vec_t, model=model)
        x, x_mean = predictor_update_fn(x, vec_t, model=model)
        # x, x_mean = new_corrector.update_fn(x, vec_t)
        # x, x_mean = new_predictor.update_fn(x, vec_t)
        # x, x_mean = corrector_update_fn(x, vec_t, model=model)
        # x, x_mean = predictor_update_fn(x, vec_t, model=model)
        x, x_mean = new_corrector.update_fn(x, vec_t)
        x, x_mean = new_predictor.update_fn(x, vec_t)

    x, n = inverse_scaler(x_mean if denoise else x), sde.N * (n_steps + 1)


save_image(x)

# for 5
#assert (x.abs().sum() - 106114.90625).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}"
#assert (x.abs().mean() - 34.5426139831543).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"
#save_image(x)

# for 1000
assert (x.abs().sum() - 436.5811).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}"
assert (x.abs().mean() - 0.1421).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"
# for 5 cifar10
x_sum = 106071.9922
x_mean = 34.52864456176758

# for 1000 cifar10
x_sum = 461.9700
x_mean = 0.1504

# for 2 steps, ffhq 1024
x_sum = 3382810112.0
x_mean = 1075.366455078125

def check_x_sum_x_mean(x, x_sum, x_mean):
    assert (x.abs().sum() - x_sum).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}"
    assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"


check_x_sum_x_mean(x, x_sum, x_mean)
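For what it's worth, the same tolerance checks could be written with `torch.testing.assert_close` (a hypothetical equivalent, not part of this commit), which reports both values on failure:

```python
import torch

def check_x_sum_x_mean_v2(x, x_sum, x_mean):
    # Same absolute tolerances as the asserts above.
    torch.testing.assert_close(x.abs().sum().cpu(), torch.tensor(x_sum), atol=1e-2, rtol=0.0)
    torch.testing.assert_close(x.abs().mean().cpu(), torch.tensor(x_mean), atol=1e-4, rtol=0.0)
```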