Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-29 07:22:12 +03:00
[SD-XL] Fix sdxl controlnet inference (#4238)
* Fix controlnet xl inference
* correct some sd xl control inference
Committed by GitHub
parent b288684d25
commit 3ba36f97b8
@@ -180,11 +180,19 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
         device = torch.device(f"cuda:{gpu_id}")

+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
+        model_sequence = (
+            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+        )
+        model_sequence.extend([self.unet, self.vae])
+
         hook = None
-        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
+        for cpu_offloaded_model in model_sequence:
             _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

         # the controlnet hook has to be offloaded manually as it alternates with the unet
         cpu_offload_with_hook(self.controlnet, device)

         # We'll offload the last model manually.
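For context, a minimal usage sketch of the code path this hunk fixes. The checkpoint ids and setup below are illustrative assumptions, not part of the commit; the point is that enable_model_cpu_offload() now walks both text encoders instead of skipping text_encoder_2:

import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline

# illustrative checkpoint ids, not taken from this commit
controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    torch_dtype=torch.float16,
)

# after this fix the offload sequence is text_encoder -> text_encoder_2 -> unet -> vae,
# with the controlnet hooked separately since it alternates with the unet
pipe.enable_model_cpu_offload()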
@@ -639,7 +647,7 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 50,
-        guidance_scale: float = 7.5,
+        guidance_scale: float = 5.0,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         negative_prompt_2: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
@@ -657,9 +665,9 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
         guess_mode: bool = False,
         control_guidance_start: Union[float, List[float]] = 0.0,
         control_guidance_end: Union[float, List[float]] = 1.0,
-        original_size: Tuple[int, int] = (1024, 1024),
+        original_size: Tuple[int, int] = None,
         crops_coords_top_left: Tuple[int, int] = (0, 0),
-        target_size: Tuple[int, int] = (1024, 1024),
+        target_size: Tuple[int, int] = None,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
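Taken together, these signature changes mean a plain call now runs with guidance_scale=5.0 and leaves both size-conditioning tuples to be resolved at call time. A hedged sketch, assuming a pipe and a precomputed control image from earlier setup:

# `pipe` and `canny_image` are assumed to exist (hypothetical names)
result = pipe(
    prompt="an astronaut riding a horse",
    image=canny_image,  # control image; its resolution now backs original_size
    num_inference_steps=30,
    # guidance_scale defaults to 5.0 (previously 7.5);
    # original_size / target_size default to None and are filled in at runtime
).images[0]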
@@ -875,6 +883,9 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
             ]
             controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)

+        original_size = original_size or image.shape[-2:]
+        target_size = target_size or (height, width)
+
         # 7.2 Prepare added time ids & embeddings
         add_text_embeds = pooled_prompt_embeds
         add_time_ids = self._get_add_time_ids(
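The fallback added here leans on Python's `or` short-circuiting over the new None defaults. A standalone sketch of the same resolution logic (the function name is hypothetical):

import torch

def resolve_size_conditioning(original_size, target_size, image, height, width):
    # None is falsy, so an omitted size falls through to the control image
    # resolution or the requested output size, mirroring the two added lines
    original_size = original_size or tuple(image.shape[-2:])
    target_size = target_size or (height, width)
    return original_size, target_size

control = torch.zeros(1, 3, 768, 768)
print(resolve_size_conditioning(None, None, control, 1024, 1024))
# -> ((768, 768), (1024, 1024))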
@@ -28,9 +28,7 @@ from diffusers import (
 )
 from diffusers.utils import randn_tensor, torch_device
 from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
-    enable_full_determinism,
-)
+from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu

 from ..pipeline_params import (
     IMAGE_TO_IMAGE_IMAGE_PARAMS,
@@ -125,10 +123,10 @@ class ControlNetPipelineSDXLFastTests(
             projection_dim=32,
         )
         text_encoder = CLIPTextModel(text_encoder_config)
-        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True)
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

         text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
-        tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True)
+        tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

         components = {
             "unet": unet,
@@ -179,6 +177,35 @@ class ControlNetPipelineSDXLFastTests(
     def test_inference_batch_single_identical(self):
         self._test_inference_batch_single_identical(expected_max_diff=2e-3)

+    @require_torch_gpu
+    def test_stable_diffusion_xl_offloads(self):
+        pipes = []
+        components = self.get_dummy_components()
+        sd_pipe = self.pipeline_class(**components).to(torch_device)
+        pipes.append(sd_pipe)
+
+        components = self.get_dummy_components()
+        sd_pipe = self.pipeline_class(**components)
+        sd_pipe.enable_model_cpu_offload()
+        pipes.append(sd_pipe)
+
+        components = self.get_dummy_components()
+        sd_pipe = self.pipeline_class(**components)
+        sd_pipe.enable_sequential_cpu_offload()
+        pipes.append(sd_pipe)
+
+        image_slices = []
+        for pipe in pipes:
+            pipe.unet.set_default_attn_processor()
+
+            inputs = self.get_dummy_inputs(torch_device)
+            image = pipe(**inputs).images
+
+            image_slices.append(image[0, -3:, -3:, -1].flatten())
+
+        assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
+        assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3
+
     def test_stable_diffusion_xl_multi_prompts(self):
         components = self.get_dummy_components()
         sd_pipe = self.pipeline_class(**components).to(torch_device)
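The new test renders once on the GPU, once with model offload, and once with sequential offload, then asserts all three outputs agree. The comparison reduces to a 3x3 corner slice of the last channel; a self-contained sketch of that check with stand-in arrays:

import numpy as np

# stand-ins for pipeline outputs of shape (batch, height, width, channels)
image_a = np.random.RandomState(0).rand(1, 64, 64, 3)
image_b = image_a.copy()  # a faithful offloaded run should match to ~1e-3

slice_a = image_a[0, -3:, -3:, -1].flatten()  # bottom-right 3x3, last channel
slice_b = image_b[0, -3:, -3:, -1].flatten()
assert np.abs(slice_a - slice_b).max() < 1e-3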