Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
[Core] feat: MultiControlNet support for SDXL ControlNet pipeline (#4597)
* core: add multicontrolnet support to sdxl controlnet
* modify checks.
* fix: original_size determination
* add: tests for multi controlnet sdxl.
* remove unnecessary prints.
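For orientation, a minimal usage sketch of the behavior this PR enables: passing several ControlNets and one conditioning image per ControlNet to the SDXL pipeline. The checkpoint names and the placeholder conditioning images below are illustrative and not part of this commit:

import torch
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline

# Placeholder conditioning maps; in practice these would be e.g. a Canny edge map and a depth map.
canny_image = Image.new("RGB", (1024, 1024))
depth_image = Image.new("RGB", (1024, 1024))

# A list of ControlNets is accepted and wrapped into a MultiControlNetModel by the pipeline.
controlnets = [
    ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16),
    ControlNetModel.from_pretrained("diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16),
]
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnets, torch_dtype=torch.float16
).to("cuda")

result = pipe(
    prompt="a futuristic city at dusk",
    image=[canny_image, depth_image],            # one conditioning image per ControlNet
    controlnet_conditioning_scale=[0.5, 0.7],    # one scale per ControlNet (a single float is broadcast)
    num_inference_steps=30,
).images[0]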
@@ -39,6 +39,7 @@ class MultiControlNetModel(ModelMixin):
         class_labels: Optional[torch.Tensor] = None,
         timestep_cond: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         guess_mode: bool = False,
         return_dict: bool = True,
@@ -53,6 +54,7 @@ class MultiControlNetModel(ModelMixin):
                 class_labels=class_labels,
                 timestep_cond=timestep_cond,
                 attention_mask=attention_mask,
+                added_cond_kwargs=added_cond_kwargs,
                 cross_attention_kwargs=cross_attention_kwargs,
                 guess_mode=guess_mode,
                 return_dict=return_dict,
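To make the two hunks above easier to follow, here is a simplified, paraphrased sketch of how a MultiControlNet-style wrapper threads the new `added_cond_kwargs` through every ControlNet and sums the per-net residuals. Names follow the diff; this is not the verbatim implementation:

def multi_controlnet_forward(nets, sample, timestep, encoder_hidden_states, images, scales, added_cond_kwargs=None):
    down_block_res_samples, mid_block_res_sample = None, None
    for net, image, scale in zip(nets, images, scales):
        down_samples, mid_sample = net(
            sample,
            timestep,
            encoder_hidden_states=encoder_hidden_states,
            controlnet_cond=image,
            conditioning_scale=scale,
            added_cond_kwargs=added_cond_kwargs,  # SDXL's pooled-text/time embeddings, now forwarded per the diff
            return_dict=False,
        )
        if down_block_res_samples is None:
            # the first ControlNet initializes the residuals
            down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
        else:
            # residuals from subsequent ControlNets are summed element-wise
            down_block_res_samples = [prev + cur for prev, cur in zip(down_block_res_samples, down_samples)]
            mid_block_res_sample = mid_block_res_sample + mid_sample
    return down_block_res_samples, mid_block_res_sample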
@@ -149,7 +149,7 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
         tokenizer: CLIPTokenizer,
         tokenizer_2: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        controlnet: ControlNetModel,
+        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
         scheduler: KarrasDiffusionSchedulers,
         force_zeros_for_empty_prompt: bool = True,
         add_watermarker: Optional[bool] = None,
@@ -157,7 +157,7 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
         super().__init__()

         if isinstance(controlnet, (list, tuple)):
-            raise ValueError("MultiControlNet is not yet supported.")
+            controlnet = MultiControlNetModel(controlnet)

         self.register_modules(
             vae=vae,
@@ -530,6 +530,15 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
                 "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
             )

+        # `prompt` needs more sophisticated handling when there are multiple
+        # conditionings.
+        if isinstance(self.controlnet, MultiControlNetModel):
+            if isinstance(prompt, list):
+                logger.warning(
+                    f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
+                    " prompts. The conditionings will be fixed across the prompts."
+                )
+
         # Check `image`
         is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
             self.controlnet, torch._dynamo.eval_frame.OptimizedModule
@@ -540,6 +549,25 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
             and isinstance(self.controlnet._orig_mod, ControlNetModel)
         ):
             self.check_image(image, prompt, prompt_embeds)
+        elif (
+            isinstance(self.controlnet, MultiControlNetModel)
+            or is_compiled
+            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+        ):
+            if not isinstance(image, list):
+                raise TypeError("For multiple controlnets: `image` must be type `list`")
+
+            # When `image` is a nested list:
+            # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
+            elif any(isinstance(i, list) for i in image):
+                raise ValueError("A single batch of multiple conditionings are supported at the moment.")
+            elif len(image) != len(self.controlnet.nets):
+                raise ValueError(
+                    f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
+                )
+
+            for image_ in image:
+                self.check_image(image_, prompt, prompt_embeds)
         else:
             assert False

@@ -551,14 +579,41 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
         ):
             if not isinstance(controlnet_conditioning_scale, float):
                 raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
+        elif (
+            isinstance(self.controlnet, MultiControlNetModel)
+            or is_compiled
+            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+        ):
+            if isinstance(controlnet_conditioning_scale, list):
+                if any(isinstance(i, list) for i in controlnet_conditioning_scale):
+                    raise ValueError("A single batch of multiple conditionings are supported at the moment.")
+            elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
+                self.controlnet.nets
+            ):
+                raise ValueError(
+                    "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+                    " the same length as the number of controlnets"
+                )
         else:
             assert False

         if not isinstance(control_guidance_start, (tuple, list)):
             control_guidance_start = [control_guidance_start]

         if not isinstance(control_guidance_end, (tuple, list)):
             control_guidance_end = [control_guidance_end]

         if len(control_guidance_start) != len(control_guidance_end):
             raise ValueError(
                 f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
             )

+        if isinstance(self.controlnet, MultiControlNetModel):
+            if len(control_guidance_start) != len(self.controlnet.nets):
+                raise ValueError(
+                    f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
+                )
+
         for start, end in zip(control_guidance_start, control_guidance_end):
             if start >= end:
                 raise ValueError(
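A small, self-contained illustration of the input shapes these checks accept for a hypothetical two-ControlNet setup (values are arbitrary):

num_nets = 2  # len(self.controlnet.nets) in the checks above

image = ["canny_map", "depth_map"]                      # must be a list with one entry per ControlNet
controlnet_conditioning_scale = [0.5, 0.7]              # list matching num_nets; a bare float is also accepted
control_guidance_start, control_guidance_end = [0.0, 0.1], [0.8, 0.9]

assert isinstance(image, list) and len(image) == num_nets
assert len(controlnet_conditioning_scale) == num_nets
assert len(control_guidance_start) == len(control_guidance_end) == num_nets
assert all(start < end <= 1.0 for start, end in zip(control_guidance_start, control_guidance_end))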
@@ -569,6 +624,7 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
             if end > 1.0:
                 raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")

+    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
     def check_image(self, image, prompt, prompt_embeds):
         image_is_pil = isinstance(image, PIL.Image.Image)
         image_is_tensor = isinstance(image, torch.Tensor)
@@ -606,6 +662,7 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
                 f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
             )

+    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
     def prepare_image(
         self,
         image,
@@ -888,6 +945,9 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0

+        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
+
         global_pool_conditions = (
             controlnet.config.global_pool_conditions
             if isinstance(controlnet, ControlNetModel)
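The broadcast added here simply expands a single float into one scale per ControlNet; a tiny illustration (numbers arbitrary):

controlnet_conditioning_scale = 0.8
num_nets = 2  # len(controlnet.nets)
controlnet_conditioning_scale = [controlnet_conditioning_scale] * num_nets
assert controlnet_conditioning_scale == [0.8, 0.8]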
@@ -933,6 +993,26 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
                 guess_mode=guess_mode,
             )
             height, width = image.shape[-2:]
+        elif isinstance(controlnet, MultiControlNetModel):
+            images = []
+
+            for image_ in image:
+                image_ = self.prepare_image(
+                    image=image_,
+                    width=width,
+                    height=height,
+                    batch_size=batch_size * num_images_per_prompt,
+                    num_images_per_prompt=num_images_per_prompt,
+                    device=device,
+                    dtype=controlnet.dtype,
+                    do_classifier_free_guidance=do_classifier_free_guidance,
+                    guess_mode=guess_mode,
+                )
+
+                images.append(image_)
+
+            image = images
+            height, width = image[0].shape[-2:]
         else:
             assert False

@@ -963,12 +1043,15 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
                     1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                     for s, e in zip(control_guidance_start, control_guidance_end)
                 ]
-            controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)
-
-        original_size = original_size or image.shape[-2:]
-        target_size = target_size or (height, width)
+            controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
+
+        # 7.2 Prepare added time ids & embeddings
+        if isinstance(image, list):
+            original_size = original_size or image[0].shape[-2:]
+        else:
+            original_size = original_size or image.shape[-2:]
+        target_size = target_size or (height, width)

         add_text_embeds = pooled_prompt_embeds
         add_time_ids = self._get_add_time_ids(
             original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
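For reference, a self-contained sketch of how `controlnet_keep` ends up looking for a multi-ControlNet run, mirroring the list comprehension in the hunk above (step count and guidance windows are arbitrary):

num_steps = 10
timesteps = list(range(num_steps))
control_guidance_start, control_guidance_end = [0.0, 0.3], [0.5, 1.0]

controlnet_keep = []
for i in range(len(timesteps)):
    keeps = [
        1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
        for s, e in zip(control_guidance_start, control_guidance_end)
    ]
    # with a MultiControlNetModel the whole per-net list is kept, not just keeps[0]
    controlnet_keep.append(keeps)

# keeps[j] is 1.0 while step i lies inside ControlNet j's [start, end] window, else 0.0
assert controlnet_keep[0] == [1.0, 0.0] and controlnet_keep[-1] == [0.0, 1.0]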
@@ -26,6 +26,7 @@ from diffusers import (
     StableDiffusionXLControlNetPipeline,
     UNet2DConditionModel,
 )
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
 from diffusers.utils import randn_tensor, torch_device
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
@@ -46,7 +47,7 @@ from ..test_pipelines_common import (
 enable_full_determinism()


-class ControlNetPipelineSDXLFastTests(
+class StableDiffusionXLControlNetPipelineFastTests(
     PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
 ):
     pipeline_class = StableDiffusionXLControlNetPipeline
@@ -297,3 +298,383 @@ class ControlNetPipelineSDXLFastTests(

         # make sure that it's equal
         assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
+
+
+class StableDiffusionXLMultiControlNetPipelineFastTests(
+    PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
+):
+    pipeline_class = StableDiffusionXLControlNetPipeline
+    params = TEXT_TO_IMAGE_PARAMS
+    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+    image_params = frozenset([])  # TO_DO: add image_params once refactored VaeImageProcessor.preprocess
+
+    def get_dummy_components(self):
+        torch.manual_seed(0)
+        unet = UNet2DConditionModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            sample_size=32,
+            in_channels=4,
+            out_channels=4,
+            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+            # SD2-specific config below
+            attention_head_dim=(2, 4),
+            use_linear_projection=True,
+            addition_embed_type="text_time",
+            addition_time_embed_dim=8,
+            transformer_layers_per_block=(1, 2),
+            projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
+            cross_attention_dim=64,
+        )
+        torch.manual_seed(0)
+
+        def init_weights(m):
+            if isinstance(m, torch.nn.Conv2d):
+                torch.nn.init.normal(m.weight)
+                m.bias.data.fill_(1.0)
+
+        controlnet1 = ControlNetModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            in_channels=4,
+            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+            conditioning_embedding_out_channels=(16, 32),
+            # SD2-specific config below
+            attention_head_dim=(2, 4),
+            use_linear_projection=True,
+            addition_embed_type="text_time",
+            addition_time_embed_dim=8,
+            transformer_layers_per_block=(1, 2),
+            projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
+            cross_attention_dim=64,
+        )
+        controlnet1.controlnet_down_blocks.apply(init_weights)
+
+        torch.manual_seed(0)
+        controlnet2 = ControlNetModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            in_channels=4,
+            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+            conditioning_embedding_out_channels=(16, 32),
+            # SD2-specific config below
+            attention_head_dim=(2, 4),
+            use_linear_projection=True,
+            addition_embed_type="text_time",
+            addition_time_embed_dim=8,
+            transformer_layers_per_block=(1, 2),
+            projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
+            cross_attention_dim=64,
+        )
+        controlnet2.controlnet_down_blocks.apply(init_weights)
+
+        torch.manual_seed(0)
+        scheduler = EulerDiscreteScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            steps_offset=1,
+            beta_schedule="scaled_linear",
+            timestep_spacing="leading",
+        )
+        torch.manual_seed(0)
+        vae = AutoencoderKL(
+            block_out_channels=[32, 64],
+            in_channels=3,
+            out_channels=3,
+            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+            latent_channels=4,
+        )
+        torch.manual_seed(0)
+        text_encoder_config = CLIPTextConfig(
+            bos_token_id=0,
+            eos_token_id=2,
+            hidden_size=32,
+            intermediate_size=37,
+            layer_norm_eps=1e-05,
+            num_attention_heads=4,
+            num_hidden_layers=5,
+            pad_token_id=1,
+            vocab_size=1000,
+            # SD2-specific config below
+            hidden_act="gelu",
+            projection_dim=32,
+        )
+        text_encoder = CLIPTextModel(text_encoder_config)
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
+        tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        controlnet = MultiControlNetModel([controlnet1, controlnet2])
+
+        components = {
+            "unet": unet,
+            "controlnet": controlnet,
+            "scheduler": scheduler,
+            "vae": vae,
+            "text_encoder": text_encoder,
+            "tokenizer": tokenizer,
+            "text_encoder_2": text_encoder_2,
+            "tokenizer_2": tokenizer_2,
+        }
+        return components
+
+    def get_dummy_inputs(self, device, seed=0):
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator(device=device).manual_seed(seed)
+
+        controlnet_embedder_scale_factor = 2
+
+        images = [
+            randn_tensor(
+                (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
+                generator=generator,
+                device=torch.device(device),
+            ),
+            randn_tensor(
+                (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
+                generator=generator,
+                device=torch.device(device),
+            ),
+        ]
+
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "generator": generator,
+            "num_inference_steps": 2,
+            "guidance_scale": 6.0,
+            "output_type": "numpy",
+            "image": images,
+        }
+
+        return inputs
+
+    def test_control_guidance_switch(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+
+        scale = 10.0
+        steps = 4
+
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["num_inference_steps"] = steps
+        inputs["controlnet_conditioning_scale"] = scale
+        output_1 = pipe(**inputs)[0]
+
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["num_inference_steps"] = steps
+        inputs["controlnet_conditioning_scale"] = scale
+        output_2 = pipe(**inputs, control_guidance_start=0.1, control_guidance_end=0.2)[0]
+
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["num_inference_steps"] = steps
+        inputs["controlnet_conditioning_scale"] = scale
+        output_3 = pipe(**inputs, control_guidance_start=[0.1, 0.3], control_guidance_end=[0.2, 0.7])[0]
+
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["num_inference_steps"] = steps
+        inputs["controlnet_conditioning_scale"] = scale
+        output_4 = pipe(**inputs, control_guidance_start=0.4, control_guidance_end=[0.5, 0.8])[0]
+
+        # make sure that all outputs are different
+        assert np.sum(np.abs(output_1 - output_2)) > 1e-3
+        assert np.sum(np.abs(output_1 - output_3)) > 1e-3
+        assert np.sum(np.abs(output_1 - output_4)) > 1e-3
+
+    def test_attention_slicing_forward_pass(self):
+        return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
+
+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_xformers_attention_forwardGenerator_pass(self):
+        self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3)
+
+    def test_inference_batch_single_identical(self):
+        self._test_inference_batch_single_identical(expected_max_diff=2e-3)
+
+
+class StableDiffusionXLMultiControlNetOneModelPipelineFastTests(
+    PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
+):
+    pipeline_class = StableDiffusionXLControlNetPipeline
+    params = TEXT_TO_IMAGE_PARAMS
+    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+    image_params = frozenset([])  # TO_DO: add image_params once refactored VaeImageProcessor.preprocess
+
+    def get_dummy_components(self):
+        torch.manual_seed(0)
+        unet = UNet2DConditionModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            sample_size=32,
+            in_channels=4,
+            out_channels=4,
+            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+            # SD2-specific config below
+            attention_head_dim=(2, 4),
+            use_linear_projection=True,
+            addition_embed_type="text_time",
+            addition_time_embed_dim=8,
+            transformer_layers_per_block=(1, 2),
+            projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
+            cross_attention_dim=64,
+        )
+        torch.manual_seed(0)
+
+        def init_weights(m):
+            if isinstance(m, torch.nn.Conv2d):
+                torch.nn.init.normal(m.weight)
+                m.bias.data.fill_(1.0)
+
+        controlnet = ControlNetModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            in_channels=4,
+            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+            conditioning_embedding_out_channels=(16, 32),
+            # SD2-specific config below
+            attention_head_dim=(2, 4),
+            use_linear_projection=True,
+            addition_embed_type="text_time",
+            addition_time_embed_dim=8,
+            transformer_layers_per_block=(1, 2),
+            projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
+            cross_attention_dim=64,
+        )
+        controlnet.controlnet_down_blocks.apply(init_weights)
+
+        torch.manual_seed(0)
+        scheduler = EulerDiscreteScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            steps_offset=1,
+            beta_schedule="scaled_linear",
+            timestep_spacing="leading",
+        )
+        torch.manual_seed(0)
+        vae = AutoencoderKL(
+            block_out_channels=[32, 64],
+            in_channels=3,
+            out_channels=3,
+            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+            latent_channels=4,
+        )
+        torch.manual_seed(0)
+        text_encoder_config = CLIPTextConfig(
+            bos_token_id=0,
+            eos_token_id=2,
+            hidden_size=32,
+            intermediate_size=37,
+            layer_norm_eps=1e-05,
+            num_attention_heads=4,
+            num_hidden_layers=5,
+            pad_token_id=1,
+            vocab_size=1000,
+            # SD2-specific config below
+            hidden_act="gelu",
+            projection_dim=32,
+        )
+        text_encoder = CLIPTextModel(text_encoder_config)
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
+        tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        controlnet = MultiControlNetModel([controlnet])
+
+        components = {
+            "unet": unet,
+            "controlnet": controlnet,
+            "scheduler": scheduler,
+            "vae": vae,
+            "text_encoder": text_encoder,
+            "tokenizer": tokenizer,
+            "text_encoder_2": text_encoder_2,
+            "tokenizer_2": tokenizer_2,
+        }
+        return components
+
+    def get_dummy_inputs(self, device, seed=0):
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator(device=device).manual_seed(seed)
+
+        controlnet_embedder_scale_factor = 2
+        images = [
+            randn_tensor(
+                (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
+                generator=generator,
+                device=torch.device(device),
+            ),
+        ]
+
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "generator": generator,
+            "num_inference_steps": 2,
+            "guidance_scale": 6.0,
+            "output_type": "numpy",
+            "image": images,
+        }
+
+        return inputs
+
+    def test_control_guidance_switch(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+
+        scale = 10.0
+        steps = 4
+
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["num_inference_steps"] = steps
+        inputs["controlnet_conditioning_scale"] = scale
+        output_1 = pipe(**inputs)[0]
+
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["num_inference_steps"] = steps
+        inputs["controlnet_conditioning_scale"] = scale
+        output_2 = pipe(**inputs, control_guidance_start=0.1, control_guidance_end=0.2)[0]
+
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["num_inference_steps"] = steps
+        inputs["controlnet_conditioning_scale"] = scale
+        output_3 = pipe(
+            **inputs,
+            control_guidance_start=[0.1],
+            control_guidance_end=[0.2],
+        )[0]
+
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["num_inference_steps"] = steps
+        inputs["controlnet_conditioning_scale"] = scale
+        output_4 = pipe(**inputs, control_guidance_start=0.4, control_guidance_end=[0.5])[0]
+
+        # make sure that all outputs are different
+        assert np.sum(np.abs(output_1 - output_2)) > 1e-3
+        assert np.sum(np.abs(output_1 - output_3)) > 1e-3
+        assert np.sum(np.abs(output_1 - output_4)) > 1e-3
+
+    def test_attention_slicing_forward_pass(self):
+        return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
+
+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_xformers_attention_forwardGenerator_pass(self):
+        self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3)
+
+    def test_inference_batch_single_identical(self):
+        self._test_inference_batch_single_identical(expected_max_diff=2e-3)