From 78e99a997bb29bbaa7b91fa0ff233e46bee95e9c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 24 Jun 2022 18:48:26 +0000 Subject: [PATCH] adapt run.py --- run.py | 110 +++++++++++++++++++++++++++++++++++---------------------- 1 file changed, 67 insertions(+), 43 deletions(-) diff --git a/run.py b/run.py index b2ec6eea29..7a55acbab2 100755 --- a/run.py +++ b/run.py @@ -11,6 +11,8 @@ from models import ddpm as ddpm_model from models import layerspp from models import layers from models import normalization +from models.ema import ExponentialMovingAverage +from losses import get_optimizer from utils import restore_checkpoint @@ -27,6 +29,7 @@ import datasets import torch +torch.backends.cuda.matmul.allow_tf32 = False torch.manual_seed(0) @@ -81,7 +84,6 @@ torch.manual_seed(0) class NewReverseDiffusionPredictor: - def __init__(self, sde, score_fn, probability_flow=False): super().__init__() self.sde = sde @@ -112,7 +114,6 @@ class NewReverseDiffusionPredictor: class NewLangevinCorrector: - def __init__(self, sde, score_fn, snr, n_steps): super().__init__() self.sde = sde @@ -146,28 +147,19 @@ class NewLangevinCorrector: def save_image(x): -# image_processed = x.cpu().permute(0, 2, 3, 1) -# image_processed = (image_processed + 1.0) * 127.5 -# image_processed = image_processed.numpy().astype(np.uint8) image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8) image_pil = PIL.Image.fromarray(image_processed[0]) - - # 6. save image image_pil.save("../images/hey.png") -#x = np.load("cifar10.npy") -# -#save_image(x) -# @title Load the score-based model sde = 'VESDE' #@param ['VESDE', 'VPSDE', 'subVPSDE'] {"type": "string"} if sde.lower() == 'vesde': - from configs.ve import cifar10_ncsnpp_continuous as configs - ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth" -# from configs.ve import ffhq_ncsnpp_continuous as configs -# ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth" +# from configs.ve import cifar10_ncsnpp_continuous as configs +# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth" + from configs.ve import ffhq_ncsnpp_continuous as configs + ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth" config = configs.get_config() - config.model.num_scales = 1000 + config.model.num_scales = 2 sde = VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales) sampling_eps = 1e-5 elif sde.lower() == 'vpsde': @@ -189,32 +181,53 @@ config.eval.batch_size = batch_size random_seed = 0 #@param {"type": "integer"} -score_model = mutils.create_model(config) +#sigmas = mutils.get_sigmas(config) +#scaler = datasets.get_data_scaler(config) +#inverse_scaler = datasets.get_data_inverse_scaler(config) +#score_model = mutils.create_model(config) +# +#optimizer = get_optimizer(config, score_model.parameters()) +#ema = ExponentialMovingAverage(score_model.parameters(), +# decay=config.model.ema_rate) +#state = dict(step=0, optimizer=optimizer, +# model=score_model, ema=ema) +# +#state = restore_checkpoint(ckpt_filename, state, config.device) +#ema.copy_to(score_model.parameters()) -loaded_state = torch.load(ckpt_filename) -score_model.load_state_dict(loaded_state["model"], strict=False) +#score_model = mutils.create_model(config) + +from diffusers import NCSNpp +score_model = NCSNpp(config).to(config.device) +score_model = torch.nn.DataParallel(score_model) + +loaded_state = torch.load("./ffhq_1024_ncsnpp_continuous_ema.pt") +del loaded_state["module.sigmas"] +score_model.load_state_dict(loaded_state, strict=False) inverse_scaler = datasets.get_data_inverse_scaler(config) predictor = ReverseDiffusionPredictor #@param ["EulerMaruyamaPredictor", "AncestralSamplingPredictor", "ReverseDiffusionPredictor", "None"] {"type": "raw"} corrector = LangevinCorrector #@param ["LangevinCorrector", "AnnealedLangevinDynamics", "None"] {"type": "raw"} -def image_grid(x): - size = config.data.image_size - channels = config.data.num_channels - img = x.reshape(-1, size, size, channels) - w = int(np.sqrt(img.shape[0])) - img = img.reshape((w, w, size, size, channels)).transpose((0, 2, 1, 3, 4)).reshape((w * size, w * size, channels)) - return img - #@title PC sampling img_size = config.data.image_size channels = config.data.num_channels shape = (batch_size, channels, img_size, img_size) probability_flow = False -snr = 0.16 #@param {"type": "number"} +snr = 0.15 #@param {"type": "number"} n_steps = 1#@param {"type": "integer"} +#sampling_fn = sampling.get_pc_sampler(sde, shape, predictor, corrector, +# inverse_scaler, snr, n_steps=n_steps, +# probability_flow=probability_flow, +# continuous=config.training.continuous, +# eps=sampling_eps, device=config.device) +# +#x, n = sampling_fn(score_model) +#save_image(x) + + def shared_predictor_update_fn(x, t, sde, model, predictor, probability_flow, continuous): """A wrapper that configures and returns the update function of predictors.""" score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous) @@ -253,14 +266,14 @@ corrector_update_fn = functools.partial(shared_corrector_update_fn, snr=snr, n_steps=n_steps) -device = "cuda" -model = score_model.to(device) -denoise = False +device = config.device +model = score_model +denoise = True new_corrector = NewLangevinCorrector(sde=sde, score_fn=model, snr=snr, n_steps=n_steps) new_predictor = NewReverseDiffusionPredictor(sde=sde, score_fn=model) - +# with torch.no_grad(): # Initial sample x = sde.prior_sampling(shape).to(device) @@ -269,21 +282,32 @@ with torch.no_grad(): for i in range(sde.N): t = timesteps[i] vec_t = torch.ones(shape[0], device=t.device) * t - x, x_mean = corrector_update_fn(x, vec_t, model=model) - x, x_mean = predictor_update_fn(x, vec_t, model=model) -# x, x_mean = new_corrector.update_fn(x, vec_t) -# x, x_mean = new_predictor.update_fn(x, vec_t) +# x, x_mean = corrector_update_fn(x, vec_t, model=model) +# x, x_mean = predictor_update_fn(x, vec_t, model=model) + x, x_mean = new_corrector.update_fn(x, vec_t) + x, x_mean = new_predictor.update_fn(x, vec_t) x, n = inverse_scaler(x_mean if denoise else x), sde.N * (n_steps + 1) -save_image(x) -# for 5 -#assert (x.abs().sum() - 106114.90625).cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}" -#assert (x.abs().mean() - 34.5426139831543).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}" +#save_image(x) -# for 1000 -assert (x.abs().sum() - 436.5811).abs().sum().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}" -assert (x.abs().mean() - 0.1421).abs().mean().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}" +# for 5 cifar10 +x_sum = 106071.9922 +x_mean = 34.52864456176758 +# for 1000 cifar10 +x_sum = 461.9700 +x_mean = 0.1504 + +# for 2 for 1024 +x_sum = 3382810112.0 +x_mean = 1075.366455078125 + +def check_x_sum_x_mean(x, x_sum, x_mean): + assert (x.abs().sum() - x_sum).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}" + assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}" + + +check_x_sum_x_mean(x, x_sum, x_mean)