From 07ef4855cd2b2fe9d72bc0479f15959333f11068 Mon Sep 17 00:00:00 2001 From: takuoko Date: Tue, 30 May 2023 20:38:16 +0900 Subject: [PATCH] [Community, Enhancement] Add reference tricks in README (#3589) add reference tricks --- examples/community/README.md | 5 +++++ .../stable_diffusion_controlnet_reference.py | 16 ++++++++-------- examples/community/stable_diffusion_reference.py | 16 ++++++++-------- 3 files changed, 21 insertions(+), 16 deletions(-) diff --git a/examples/community/README.md b/examples/community/README.md index f3af034100..21fba38e69 100755 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -1326,6 +1326,8 @@ image.save('tensorrt_img2img_new_zealand_hills.png') This pipeline uses the Reference Control. Refer to the [sd-webui-controlnet discussion: Reference-only Control](https://github.com/Mikubill/sd-webui-controlnet/discussions/1236)[sd-webui-controlnet discussion: Reference-adain Control](https://github.com/Mikubill/sd-webui-controlnet/discussions/1280). +Based on [this issue](https://github.com/huggingface/diffusers/issues/3566), +- `EulerAncestralDiscreteScheduler` got poor results. ```py import torch @@ -1369,6 +1371,9 @@ Output Image of `reference_attn=True` and `reference_adain=True` This pipeline uses the Reference Control with ControlNet. Refer to the [sd-webui-controlnet discussion: Reference-only Control](https://github.com/Mikubill/sd-webui-controlnet/discussions/1236)[sd-webui-controlnet discussion: Reference-adain Control](https://github.com/Mikubill/sd-webui-controlnet/discussions/1280). +Based on [this issue](https://github.com/huggingface/diffusers/issues/3566), +- `EulerAncestralDiscreteScheduler` got poor results. +- `guess_mode=True` works well for ControlNet v1.1 ```py import cv2 diff --git a/examples/community/stable_diffusion_controlnet_reference.py b/examples/community/stable_diffusion_controlnet_reference.py index 606fe09c68..ca06136d78 100644 --- a/examples/community/stable_diffusion_controlnet_reference.py +++ b/examples/community/stable_diffusion_controlnet_reference.py @@ -505,8 +505,8 @@ class StableDiffusionControlNetReferencePipeline(StableDiffusionControlNetPipeli if MODE == "write": if gn_auto_machine_weight >= self.gn_weight: var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) - self.mean_bank.append(mean) - self.var_bank.append(var) + self.mean_bank.append([mean]) + self.var_bank.append([var]) if MODE == "read": if len(self.mean_bank) > 0 and len(self.var_bank) > 0: var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) @@ -545,8 +545,8 @@ class StableDiffusionControlNetReferencePipeline(StableDiffusionControlNetPipeli if MODE == "write": if gn_auto_machine_weight >= self.gn_weight: var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) - self.mean_bank.append(mean) - self.var_bank.append(var) + self.mean_bank.append([mean]) + self.var_bank.append([var]) if MODE == "read": if len(self.mean_bank) > 0 and len(self.var_bank) > 0: var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) @@ -605,8 +605,8 @@ class StableDiffusionControlNetReferencePipeline(StableDiffusionControlNetPipeli if MODE == "write": if gn_auto_machine_weight >= self.gn_weight: var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) - self.mean_bank.append(mean) - self.var_bank.append(var) + self.mean_bank.append([mean]) + self.var_bank.append([var]) if MODE == "read": if len(self.mean_bank) > 0 and len(self.var_bank) > 0: var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) @@ -642,8 +642,8 @@ class StableDiffusionControlNetReferencePipeline(StableDiffusionControlNetPipeli if MODE == "write": if gn_auto_machine_weight >= self.gn_weight: var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) - self.mean_bank.append(mean) - self.var_bank.append(var) + self.mean_bank.append([mean]) + self.var_bank.append([var]) if MODE == "read": if len(self.mean_bank) > 0 and len(self.var_bank) > 0: var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) diff --git a/examples/community/stable_diffusion_reference.py b/examples/community/stable_diffusion_reference.py index 22e0b40f60..dbfb768f8b 100644 --- a/examples/community/stable_diffusion_reference.py +++ b/examples/community/stable_diffusion_reference.py @@ -499,8 +499,8 @@ class StableDiffusionReferencePipeline(StableDiffusionPipeline): if MODE == "write": if gn_auto_machine_weight >= self.gn_weight: var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) - self.mean_bank.append(mean) - self.var_bank.append(var) + self.mean_bank.append([mean]) + self.var_bank.append([var]) if MODE == "read": if len(self.mean_bank) > 0 and len(self.var_bank) > 0: var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) @@ -539,8 +539,8 @@ class StableDiffusionReferencePipeline(StableDiffusionPipeline): if MODE == "write": if gn_auto_machine_weight >= self.gn_weight: var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) - self.mean_bank.append(mean) - self.var_bank.append(var) + self.mean_bank.append([mean]) + self.var_bank.append([var]) if MODE == "read": if len(self.mean_bank) > 0 and len(self.var_bank) > 0: var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) @@ -599,8 +599,8 @@ class StableDiffusionReferencePipeline(StableDiffusionPipeline): if MODE == "write": if gn_auto_machine_weight >= self.gn_weight: var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) - self.mean_bank.append(mean) - self.var_bank.append(var) + self.mean_bank.append([mean]) + self.var_bank.append([var]) if MODE == "read": if len(self.mean_bank) > 0 and len(self.var_bank) > 0: var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) @@ -636,8 +636,8 @@ class StableDiffusionReferencePipeline(StableDiffusionPipeline): if MODE == "write": if gn_auto_machine_weight >= self.gn_weight: var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) - self.mean_bank.append(mean) - self.var_bank.append(var) + self.mean_bank.append([mean]) + self.var_bank.append([var]) if MODE == "read": if len(self.mean_bank) > 0 and len(self.var_bank) > 0: var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)