diff --git a/run.py b/run.py index 61e29603fb..b2ec6eea29 100755 --- a/run.py +++ b/run.py @@ -269,20 +269,21 @@ with torch.no_grad(): for i in range(sde.N): t = timesteps[i] vec_t = torch.ones(shape[0], device=t.device) * t -# x, x_mean = corrector_update_fn(x, vec_t, model=model) -# x, x_mean = predictor_update_fn(x, vec_t, model=model) - x, x_mean = new_corrector.update_fn(x, vec_t) - x, x_mean = new_predictor.update_fn(x, vec_t) + x, x_mean = corrector_update_fn(x, vec_t, model=model) + x, x_mean = predictor_update_fn(x, vec_t, model=model) +# x, x_mean = new_corrector.update_fn(x, vec_t) +# x, x_mean = new_predictor.update_fn(x, vec_t) x, n = inverse_scaler(x_mean if denoise else x), sde.N * (n_steps + 1) +save_image(x) + # for 5 -#assert x.abs().sum().cpu().item() - 106114.90625 < 1e-2, "sum wrong" -#assert x.abs().mean().cpu().item() - 34.5426139831543 < 1e-4, "mean wrong" +#assert (x.abs().sum() - 106114.90625).cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}" +#assert (x.abs().mean() - 34.5426139831543).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}" # for 1000 -assert x.abs().sum().cpu().item() - 436.5811 < 1e-2, "sum wrong" -assert x.abs().mean().cpu().item() - 0.1421 < 1e-4, "mean wrong" +assert (x.abs().sum() - 436.5811).abs().sum().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}" +assert (x.abs().mean() - 0.1421).abs().mean().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}" -save_image(x)