diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py index 3bd7f82d7e..46c229e825 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py @@ -658,8 +658,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade num_images_per_prompt, device, dtype, - do_classifier_free_guidance=False, - guess_mode=False, + inferring_controlnet_cond_batch=False, ): if not isinstance(image, torch.Tensor): if isinstance(image, PIL.Image.Image): @@ -696,7 +695,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade image = image.to(device=device, dtype=dtype) - if do_classifier_free_guidance and not guess_mode: + if not inferring_controlnet_cond_batch: image = torch.cat([image] * 2) return image @@ -898,7 +897,16 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets) - # 3. Encode input prompt + # 3. Determination of whether to infer ControlNet using only for the conditional batch. + global_pool_conditions = False + if isinstance(self.controlnet, ControlNetModel): + global_pool_conditions = self.controlnet.config.global_pool_conditions + else: + ... # TODO: Implement for MultiControlNetModel + + inferring_controlnet_cond_batch = (guess_mode or global_pool_conditions) and do_classifier_free_guidance + + # 4. Encode input prompt prompt_embeds = self._encode_prompt( prompt, device, @@ -909,7 +917,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade negative_prompt_embeds=negative_prompt_embeds, ) - # 4. Prepare image + # 5. Prepare image if isinstance(self.controlnet, ControlNetModel): image = self.prepare_image( image=image, @@ -919,8 +927,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade num_images_per_prompt=num_images_per_prompt, device=device, dtype=self.controlnet.dtype, - do_classifier_free_guidance=do_classifier_free_guidance, - guess_mode=guess_mode, + inferring_controlnet_cond_batch=inferring_controlnet_cond_batch, ) elif isinstance(self.controlnet, MultiControlNetModel): images = [] @@ -934,8 +941,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade num_images_per_prompt=num_images_per_prompt, device=device, dtype=self.controlnet.dtype, - do_classifier_free_guidance=do_classifier_free_guidance, - guess_mode=guess_mode, + inferring_controlnet_cond_batch=inferring_controlnet_cond_batch, ) images.append(image_) @@ -944,11 +950,11 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade else: assert False - # 5. Prepare timesteps + # 6. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps - # 6. Prepare latent variables + # 7. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, @@ -961,10 +967,10 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade latents, ) - # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 8. Denoising loop + # 9. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -973,8 +979,8 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # controlnet(s) inference - if guess_mode and do_classifier_free_guidance: - # Infer ControlNet only for the conditional batch. + if inferring_controlnet_cond_batch: + # Inferring ControlNet only for the conditional batch. controlnet_latent_model_input = latents controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] else: @@ -991,7 +997,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade return_dict=False, ) - if guess_mode and do_classifier_free_guidance: + if inferring_controlnet_cond_batch: # Infered ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged.