1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

add inferring_controlnet_cond_batch

This commit is contained in:
Takuma Mori
2023-05-04 14:45:48 +09:00
parent 364d59d13b
commit abe8d6311d

View File

@@ -658,8 +658,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
num_images_per_prompt,
device,
dtype,
do_classifier_free_guidance=False,
guess_mode=False,
inferring_controlnet_cond_batch=False,
):
if not isinstance(image, torch.Tensor):
if isinstance(image, PIL.Image.Image):
@@ -696,7 +695,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
image = image.to(device=device, dtype=dtype)
if do_classifier_free_guidance and not guess_mode:
if not inferring_controlnet_cond_batch:
image = torch.cat([image] * 2)
return image
@@ -898,7 +897,16 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)
# 3. Encode input prompt
# 3. Determination of whether to infer ControlNet using only for the conditional batch.
global_pool_conditions = False
if isinstance(self.controlnet, ControlNetModel):
global_pool_conditions = self.controlnet.config.global_pool_conditions
else:
... # TODO: Implement for MultiControlNetModel
inferring_controlnet_cond_batch = (guess_mode or global_pool_conditions) and do_classifier_free_guidance
# 4. Encode input prompt
prompt_embeds = self._encode_prompt(
prompt,
device,
@@ -909,7 +917,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
negative_prompt_embeds=negative_prompt_embeds,
)
# 4. Prepare image
# 5. Prepare image
if isinstance(self.controlnet, ControlNetModel):
image = self.prepare_image(
image=image,
@@ -919,8 +927,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=self.controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
guess_mode=guess_mode,
inferring_controlnet_cond_batch=inferring_controlnet_cond_batch,
)
elif isinstance(self.controlnet, MultiControlNetModel):
images = []
@@ -934,8 +941,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=self.controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
guess_mode=guess_mode,
inferring_controlnet_cond_batch=inferring_controlnet_cond_batch,
)
images.append(image_)
@@ -944,11 +950,11 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
else:
assert False
# 5. Prepare timesteps
# 6. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 6. Prepare latent variables
# 7. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
@@ -961,10 +967,10 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
latents,
)
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 8. Denoising loop
# 9. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
@@ -973,8 +979,8 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# controlnet(s) inference
if guess_mode and do_classifier_free_guidance:
# Infer ControlNet only for the conditional batch.
if inferring_controlnet_cond_batch:
# Inferring ControlNet only for the conditional batch.
controlnet_latent_model_input = latents
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
else:
@@ -991,7 +997,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
return_dict=False,
)
if guess_mode and do_classifier_free_guidance:
if inferring_controlnet_cond_batch:
# Infered ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.