From 0dbc4779c8bf396d48170dda52befc83288e109f Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 1 Jul 2022 12:50:34 +0200 Subject: [PATCH] add centered back --- src/diffusers/models/unet_sde_score_estimation.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py index d9a4732f0b..1c2a2d10ff 100644 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ b/src/diffusers/models/unet_sde_score_estimation.py @@ -229,6 +229,7 @@ class NCSNpp(ModelMixin, ConfigMixin): self, image_size=1024, num_channels=3, + centered=False, attn_resolutions=(16,), ch_mult=(1, 2, 4, 8, 16, 32, 32, 32), conditional=True, @@ -253,6 +254,7 @@ class NCSNpp(ModelMixin, ConfigMixin): self.register_to_config( image_size=image_size, num_channels=num_channels, + centered=centered, attn_resolutions=attn_resolutions, ch_mult=ch_mult, conditional=conditional, @@ -457,7 +459,8 @@ class NCSNpp(ModelMixin, ConfigMixin): temb = None # If input data is in [0, 1] - x = 2 * x - 1.0 + if not self.config.centered: + x = 2 * x - 1.0 # Downsampling block input_pyramid = None