From aa3c46d99acfaa145bdf620f821de9b409c2e6c6 Mon Sep 17 00:00:00 2001 From: v2ray <60914079+LagPixelLOL@users.noreply.github.com> Date: Thu, 26 Sep 2024 06:26:58 +0800 Subject: [PATCH] [Doc] Improved level of clarity for latents_to_rgb. (#9529) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixed latents_to_rgb doc. Co-authored-by: Álvaro Somoza --- docs/source/en/using-diffusers/callback.md | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/docs/source/en/using-diffusers/callback.md b/docs/source/en/using-diffusers/callback.md index d4d23d6254..68c621ffc5 100644 --- a/docs/source/en/using-diffusers/callback.md +++ b/docs/source/en/using-diffusers/callback.md @@ -171,14 +171,13 @@ def latents_to_rgb(latents): weights = ( (60, -60, 25, -70), (60, -5, 15, -50), - (60, 10, -5, -35) + (60, 10, -5, -35), ) weights_tensor = torch.t(torch.tensor(weights, dtype=latents.dtype).to(latents.device)) biases_tensor = torch.tensor((150, 140, 130), dtype=latents.dtype).to(latents.device) rgb_tensor = torch.einsum("...lxy,lr -> ...rxy", latents, weights_tensor) + biases_tensor.unsqueeze(-1).unsqueeze(-1) - image_array = rgb_tensor.clamp(0, 255)[0].byte().cpu().numpy() - image_array = image_array.transpose(1, 2, 0) + image_array = rgb_tensor.clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0) return Image.fromarray(image_array) ``` @@ -189,7 +188,7 @@ def latents_to_rgb(latents): def decode_tensors(pipe, step, timestep, callback_kwargs): latents = callback_kwargs["latents"] - image = latents_to_rgb(latents) + image = latents_to_rgb(latents[0]) image.save(f"{step}.png") return callback_kwargs