mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
add inferring_controlnet_cond_batch
This commit is contained in:
@@ -658,8 +658,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
|
||||
num_images_per_prompt,
|
||||
device,
|
||||
dtype,
|
||||
do_classifier_free_guidance=False,
|
||||
guess_mode=False,
|
||||
inferring_controlnet_cond_batch=False,
|
||||
):
|
||||
if not isinstance(image, torch.Tensor):
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
@@ -696,7 +695,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
|
||||
if do_classifier_free_guidance and not guess_mode:
|
||||
if not inferring_controlnet_cond_batch:
|
||||
image = torch.cat([image] * 2)
|
||||
|
||||
return image
|
||||
@@ -898,7 +897,16 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
|
||||
if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
|
||||
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)
|
||||
|
||||
# 3. Encode input prompt
|
||||
# 3. Determine whether to run ControlNet inference only on the conditional batch.
|
||||
global_pool_conditions = False
|
||||
if isinstance(self.controlnet, ControlNetModel):
|
||||
global_pool_conditions = self.controlnet.config.global_pool_conditions
|
||||
else:
|
||||
... # TODO: Implement for MultiControlNetModel
|
||||
|
||||
inferring_controlnet_cond_batch = (guess_mode or global_pool_conditions) and do_classifier_free_guidance
|
||||
|
||||
# 4. Encode input prompt
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
@@ -909,7 +917,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
)
|
||||
|
||||
# 4. Prepare image
|
||||
# 5. Prepare image
|
||||
if isinstance(self.controlnet, ControlNetModel):
|
||||
image = self.prepare_image(
|
||||
image=image,
|
||||
@@ -919,8 +927,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=self.controlnet.dtype,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
guess_mode=guess_mode,
|
||||
inferring_controlnet_cond_batch=inferring_controlnet_cond_batch,
|
||||
)
|
||||
elif isinstance(self.controlnet, MultiControlNetModel):
|
||||
images = []
|
||||
@@ -934,8 +941,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=self.controlnet.dtype,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
guess_mode=guess_mode,
|
||||
inferring_controlnet_cond_batch=inferring_controlnet_cond_batch,
|
||||
)
|
||||
|
||||
images.append(image_)
|
||||
@@ -944,11 +950,11 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
|
||||
else:
|
||||
assert False
|
||||
|
||||
# 5. Prepare timesteps
|
||||
# 6. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 6. Prepare latent variables
|
||||
# 7. Prepare latent variables
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
@@ -961,10 +967,10 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
|
||||
latents,
|
||||
)
|
||||
|
||||
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 8. Denoising loop
|
||||
# 9. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
@@ -973,8 +979,8 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# controlnet(s) inference
|
||||
if guess_mode and do_classifier_free_guidance:
|
||||
# Infer ControlNet only for the conditional batch.
|
||||
if inferring_controlnet_cond_batch:
|
||||
# Inferring ControlNet only for the conditional batch.
|
||||
controlnet_latent_model_input = latents
|
||||
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
|
||||
else:
|
||||
@@ -991,7 +997,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
if guess_mode and do_classifier_free_guidance:
|
||||
if inferring_controlnet_cond_batch:
|
||||
# Inferred ControlNet only for the conditional batch.
|
||||
# To apply the output of ControlNet to both the unconditional and conditional batches,
|
||||
# add 0 to the unconditional batch to keep it unchanged.
|
||||
|
||||
Reference in New Issue
Block a user