From ebc99a77aad647c5d33eb36a33c23f7b3949cb40 Mon Sep 17 00:00:00 2001 From: btlorch Date: Fri, 26 Apr 2024 02:44:53 +0200 Subject: [PATCH] Convert RGB to BGR for the SDXL watermark encoder (#7013) * Convert channel order to BGR for the watermark encoder. Convert the watermarked BGR images back to RGB. Fixes #6292 * Revert channel order before stacking images to work around the limitation that negative strides are currently not supported --------- Co-authored-by: Sayak Paul --- .../pipelines/stable_diffusion_xl/watermark.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/watermark.py b/src/diffusers/pipelines/stable_diffusion_xl/watermark.py index 5b6e36d9f4..f457cdbdb1 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/watermark.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/watermark.py @@ -28,9 +28,15 @@ class StableDiffusionXLWatermarker: images = (255 * (images / 2 + 0.5)).cpu().permute(0, 2, 3, 1).float().numpy() - images = [self.encoder.encode(image, "dwtDct") for image in images] + # Convert RGB to BGR, which is the channel order expected by the watermark encoder. + images = images[:, :, :, ::-1] - images = torch.from_numpy(np.array(images)).permute(0, 3, 1, 2) + # Add watermark and convert BGR back to RGB + images = [self.encoder.encode(image, "dwtDct")[:, :, ::-1] for image in images] + + images = np.array(images) + + images = torch.from_numpy(images).permute(0, 3, 1, 2) images = torch.clamp(2 * (images / 255 - 0.5), min=-1.0, max=1.0) return images