mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix bug in ControlNetPipelines with MultiControlNetModel of length 1 (#4032)
* Fix bug in ControlNetPipelines with MultiControlNetModel of length 1 * Add tests for varying number of ControlNet models * Fix missing indexing for control_guidance_start and control_guidance_end * Fix code quality * Separate test for MultiControlNet with one model * Revert formatting of earlier test
This commit is contained in:
@@ -914,7 +914,7 @@ class StableDiffusionControlNetPipeline(
|
||||
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
||||
for s, e in zip(control_guidance_start, control_guidance_end)
|
||||
]
|
||||
controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)
|
||||
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
|
||||
@@ -1007,7 +1007,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
||||
for s, e in zip(control_guidance_start, control_guidance_end)
|
||||
]
|
||||
controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)
|
||||
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
|
||||
@@ -1242,7 +1242,7 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
||||
for s, e in zip(control_guidance_start, control_guidance_end)
|
||||
]
|
||||
controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)
|
||||
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
|
||||
@@ -398,6 +398,179 @@ class StableDiffusionMultiControlNetPipelineFastTests(
|
||||
pass
|
||||
|
||||
|
||||
class StableDiffusionMultiControlNetOneModelPipelineFastTests(
|
||||
PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
|
||||
):
|
||||
pipeline_class = StableDiffusionControlNetPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = frozenset([]) # TO_DO: add image_params once refactored VaeImageProcessor.preprocess
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
unet = UNet2DConditionModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
|
||||
def init_weights(m):
|
||||
if isinstance(m, torch.nn.Conv2d):
|
||||
torch.nn.init.normal(m.weight)
|
||||
m.bias.data.fill_(1.0)
|
||||
|
||||
controlnet = ControlNetModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
in_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
conditioning_embedding_out_channels=(16, 32),
|
||||
)
|
||||
controlnet.controlnet_down_blocks.apply(init_weights)
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = DDIMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
)
|
||||
text_encoder = CLIPTextModel(text_encoder_config)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
controlnet = MultiControlNetModel([controlnet])
|
||||
|
||||
components = {
|
||||
"unet": unet,
|
||||
"controlnet": controlnet,
|
||||
"scheduler": scheduler,
|
||||
"vae": vae,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"safety_checker": None,
|
||||
"feature_extractor": None,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
controlnet_embedder_scale_factor = 2
|
||||
|
||||
images = [
|
||||
randn_tensor(
|
||||
(1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
|
||||
generator=generator,
|
||||
device=torch.device(device),
|
||||
),
|
||||
]
|
||||
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"output_type": "numpy",
|
||||
"image": images,
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
def test_control_guidance_switch(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(torch_device)
|
||||
|
||||
scale = 10.0
|
||||
steps = 4
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["num_inference_steps"] = steps
|
||||
inputs["controlnet_conditioning_scale"] = scale
|
||||
output_1 = pipe(**inputs)[0]
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["num_inference_steps"] = steps
|
||||
inputs["controlnet_conditioning_scale"] = scale
|
||||
output_2 = pipe(**inputs, control_guidance_start=0.1, control_guidance_end=0.2)[0]
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["num_inference_steps"] = steps
|
||||
inputs["controlnet_conditioning_scale"] = scale
|
||||
output_3 = pipe(
|
||||
**inputs,
|
||||
control_guidance_start=[0.1],
|
||||
control_guidance_end=[0.2],
|
||||
)[0]
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["num_inference_steps"] = steps
|
||||
inputs["controlnet_conditioning_scale"] = scale
|
||||
output_4 = pipe(**inputs, control_guidance_start=0.4, control_guidance_end=[0.5])[0]
|
||||
|
||||
# make sure that all outputs are different
|
||||
assert np.sum(np.abs(output_1 - output_2)) > 1e-3
|
||||
assert np.sum(np.abs(output_1 - output_3)) > 1e-3
|
||||
assert np.sum(np.abs(output_1 - output_4)) > 1e-3
|
||||
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_attention_forwardGenerator_pass(self):
|
||||
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3)
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
|
||||
|
||||
def test_save_pretrained_raise_not_implemented_exception(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
try:
|
||||
# save_pretrained is not implemented for Multi-ControlNet
|
||||
pipe.save_pretrained(tmpdir)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class ControlNetPipelineSlowTests(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user