From 2df84a57da3080347b6203fdad3742f17ca70426 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 2 Dec 2022 14:25:07 +0100 Subject: [PATCH] delta border --- src/diffusers/models/vae.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index b537d90861..931d38af3c 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -616,6 +616,20 @@ class AutoencoderKL(ModelMixin, ConfigMixin): return DecoderOutput(sample=decoded) + def delta_border(self, h, w): + """ + :param h: height + :param w: width + :return: normalized distance to image border, + wtith min distance = 0 at border and max dist = 0.5 at image center + """ + lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) + arr = self.meshgrid(h, w) / lower_right_corner + dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] + dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0] + edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] + return edge_dist + def get_weighting(self, h, w, Ly, Lx, device): weighting = self.delta_border(h, w) weighting = torch.clip(