diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py
index 4048d70484..b3508307d4 100644
--- a/src/diffusers/guiders/classifier_free_guidance.py
+++ b/src/diffusers/guiders/classifier_free_guidance.py
@@ -92,6 +92,10 @@ class ClassifierFreeGuidance(BaseGuidance):
 
         return pred
 
+    @property
+    def is_conditional(self) -> bool:
+        return self._num_outputs_prepared == 0
+
     @property
     def num_conditions(self) -> int:
         num_conditions = 1
@@ -100,6 +104,8 @@ class ClassifierFreeGuidance(BaseGuidance):
         return num_conditions
 
     def _is_cfg_enabled(self) -> bool:
+        if not self._enabled:
+            return False
         skip_start_step = int(self._start * self._num_inference_steps)
         skip_stop_step = int(self._stop * self._num_inference_steps)
         is_within_range = skip_start_step <= self._step < skip_stop_step
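Note: the `_enabled` flag checked here is introduced on `BaseGuidance` in the next file, together with `force_disable()`/`force_enable()`. A minimal standalone sketch of the resulting gating, assuming only what the added lines show (`ToyCFG` is a hypothetical stand-in, not a diffusers class):

```python
# Hypothetical stand-in that mirrors the gating added above; not diffusers API.
class ToyCFG:
    def __init__(self, start: float = 0.0, stop: float = 1.0):
        self._start, self._stop = start, stop
        self._enabled = True               # set in BaseGuidance.__init__ below
        self._step = 0
        self._num_inference_steps = 0

    def force_disable(self):
        self._enabled = False

    def set_state(self, step: int, num_inference_steps: int):
        self._step = step
        self._num_inference_steps = num_inference_steps

    def _is_cfg_enabled(self) -> bool:
        if not self._enabled:              # the new short-circuit
            return False
        skip_start_step = int(self._start * self._num_inference_steps)
        skip_stop_step = int(self._stop * self._num_inference_steps)
        return skip_start_step <= self._step < skip_stop_step

guider = ToyCFG(start=0.0, stop=0.5)
guider.set_state(step=10, num_inference_steps=50)
assert guider._is_cfg_enabled()            # step 10 lies in [0, 25)
guider.force_disable()
assert not guider._is_cfg_enabled()        # disabled regardless of the window
```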
diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py
index ecde7334b2..690afae891 100644
--- a/src/diffusers/guiders/guider_utils.py
+++ b/src/diffusers/guiders/guider_utils.py
@@ -39,6 +39,7 @@ class BaseGuidance:
         self._timestep: torch.LongTensor = None
         self._preds: Dict[str, torch.Tensor] = {}
         self._num_outputs_prepared: int = 0
+        self._enabled = True
 
         if not (0.0 <= start < 1.0):
             raise ValueError(
@@ -54,6 +55,12 @@ class BaseGuidance:
                 "`_input_predictions` must be a list of required prediction names for the guidance technique."
             )
 
+    def force_disable(self):
+        self._enabled = False
+
+    def force_enable(self):
+        self._enabled = True
+
     def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
         self._step = step
         self._num_inference_steps = num_inference_steps
@@ -62,10 +69,10 @@ class BaseGuidance:
         self._num_outputs_prepared = 0
 
     def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
-        raise NotImplementedError("GuidanceMixin::prepare_inputs must be implemented in subclasses.")
+        raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")
 
     def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None:
-        raise NotImplementedError("GuidanceMixin::prepare_outputs must be implemented in subclasses.")
+        raise NotImplementedError("BaseGuidance::prepare_outputs must be implemented in subclasses.")
 
     def __call__(self, **kwargs) -> Any:
         if len(kwargs) != self.num_conditions:
@@ -75,11 +82,19 @@ class BaseGuidance:
         return self.forward(**kwargs)
 
     def forward(self, *args, **kwargs) -> Any:
-        raise NotImplementedError("GuidanceMixin::forward must be implemented in subclasses.")
+        raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.")
 
+    @property
+    def is_conditional(self) -> bool:
+        raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.")
+
+    @property
+    def is_unconditional(self) -> bool:
+        return not self.is_conditional
+
     @property
     def num_conditions(self) -> int:
-        raise NotImplementedError("GuidanceMixin::num_conditions must be implemented in subclasses.")
+        raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.")
 
     @property
     def outputs(self) -> Dict[str, torch.Tensor]:
@@ -114,7 +129,7 @@ def _default_prepare_inputs(denoiser: torch.nn.Module, num_conditions: int, *arg
     """
     Prepares the inputs for the denoiser by ensuring that the conditional and unconditional inputs are correctly
     prepared based on required number of conditions. This function is used in the `prepare_inputs` method of the
-    `GuidanceMixin` class.
+    `BaseGuidance` class.
 
     Either tensors or tuples/lists of tensors can be provided. If a tuple/list is provided, it should contain two elements:
     - The first element is the conditional input.
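The docstring above fixes the stale `GuidanceMixin` reference and describes the pairing convention: a plain tensor is reused for every condition, while a `(conditional, unconditional)` tuple is split across conditions. An illustrative re-implementation of that convention (`toy_prepare_inputs` is hypothetical; the real `_default_prepare_inputs` may differ in details):

```python
import torch

# Hypothetical helper mirroring the documented convention; not the real
# `_default_prepare_inputs` from diffusers.
def toy_prepare_inputs(num_conditions: int, *args):
    per_arg_lists = []
    for arg in args:
        if isinstance(arg, (tuple, list)):
            cond, uncond = arg                            # (conditional, unconditional)
            per_arg_lists.append([cond, uncond][:num_conditions])
        else:
            per_arg_lists.append([arg] * num_conditions)  # reused for every pass
    return tuple(per_arg_lists)

latents = torch.randn(1, 4, 8, 8)
cond_embeds, uncond_embeds = torch.randn(1, 77, 2048), torch.randn(1, 77, 2048)

latents_list, embeds_list = toy_prepare_inputs(2, latents, (cond_embeds, uncond_embeds))
assert latents_list[0] is latents and latents_list[1] is latents
assert embeds_list[0] is cond_embeds and embeds_list[1] is uncond_embeds
```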
diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py
index e20a700fee..92ae7f8518 100644
--- a/src/diffusers/guiders/skip_layer_guidance.py
+++ b/src/diffusers/guiders/skip_layer_guidance.py
@@ -189,6 +189,10 @@ class SkipLayerGuidance(BaseGuidance):
             pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
 
         return pred
+
+    @property
+    def is_conditional(self) -> bool:
+        return self._num_outputs_prepared == 0 or self._num_outputs_prepared == 2
 
     @property
     def num_conditions(self) -> int:
@@ -200,6 +204,8 @@ class SkipLayerGuidance(BaseGuidance):
         return num_conditions
 
     def _is_cfg_enabled(self) -> bool:
+        if not self._enabled:
+            return False
         skip_start_step = int(self._start * self._num_inference_steps)
         skip_stop_step = int(self._stop * self._num_inference_steps)
         is_within_range = skip_start_step <= self._step < skip_stop_step
@@ -211,6 +217,8 @@ class SkipLayerGuidance(BaseGuidance):
         return is_within_range and not is_close
 
     def _is_slg_enabled(self) -> bool:
+        if not self._enabled:
+            return False
         skip_start_step = int(self._start * self._num_inference_steps)
         skip_stop_step = int(self._stop * self._num_inference_steps)
         is_within_range = skip_start_step < self._step < skip_stop_step
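`SkipLayerGuidance` prepares up to three predictions per step, so `is_conditional` is true for the first and third denoiser passes. A sketch of the mapping; the prediction names are assumptions (the diff only shows the index logic), and `_num_outputs_prepared` counts predictions already stored when the check runs:

```python
# The names below are illustrative; only the index test comes from the diff.
predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
for num_outputs_prepared, name in enumerate(predictions):
    is_conditional = num_outputs_prepared == 0 or num_outputs_prepared == 2
    print(f"{name}: {'conditional' if is_conditional else 'unconditional'} pass")
```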
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py
index 5f125605a2..5df57c6c16 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py
@@ -2263,21 +2263,9 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
             )
 
             for batch_index, (
-                latents_i,
-                prompt_embeds_i,
-                add_time_ids_i,
-                pooled_prompt_embeds_i,
-                mask_i,
-                masked_image_latents_i,
-                ip_adapter_embeds_i,
+                latents_i, prompt_embeds_i, add_time_ids_i, pooled_prompt_embeds_i, mask_i, masked_image_latents_i, ip_adapter_embeds_i,
             ) in enumerate(zip(
-                latents,
-                prompt_embeds,
-                add_time_ids,
-                pooled_prompt_embeds,
-                mask,
-                masked_image_latents,
-                ip_adapter_embeds
+                latents, prompt_embeds, add_time_ids, pooled_prompt_embeds, mask, masked_image_latents, ip_adapter_embeds
             )):
                 latents_i = pipeline.scheduler.scale_model_input(latents_i, t)
 
@@ -2285,6 +2273,7 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
                 if data.num_channels_unet == 9:
                     latents_i = torch.cat([latents_i, mask_i, masked_image_latents_i], dim=1)
 
+                # Prepare additional conditionings
                 data.added_cond_kwargs = {
                     "text_embeds": pooled_prompt_embeds_i,
                     "time_ids": add_time_ids_i,
@@ -2292,7 +2281,7 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
                 if ip_adapter_embeds_i is not None:
                     data.added_cond_kwargs["image_embeds"] = ip_adapter_embeds_i
 
-                # predict the noise residual
+                # Predict the noise residual
                 data.noise_pred = pipeline.unet(
                     latents_i,
                     t,
@@ -2347,7 +2336,6 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
             ComponentSpec("unet", UNet2DConditionModel),
             ComponentSpec("controlnet", ControlNetModel),
             ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)),
-            ComponentSpec("controlnet_guider", GuiderType, obj=ClassifierFreeGuidance()),
         ]
 
     @property
@@ -2363,8 +2351,6 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
             InputParam("controlnet_conditioning_scale", default=1.0),
             InputParam("guess_mode", default=False),
             InputParam("num_images_per_prompt", default=1),
-            InputParam("guidance_scale", default=5.0),
-            InputParam("guidance_rescale", default=0.0),
             InputParam("cross_attention_kwargs"),
             InputParam("generator"),
             InputParam("eta", default=0.0),
@@ -2515,8 +2501,8 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
             image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32)
         else:
             image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+
         image_batch_size = image.shape[0]
-
         if image_batch_size == 1:
             repeat_by = batch_size
         else:
@@ -2524,9 +2510,7 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
             repeat_by = num_images_per_prompt
 
         image = image.repeat_interleave(repeat_by, dim=0)
-
         image = image.to(device=device, dtype=dtype)
-
         return image
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components
@@ -2557,9 +2541,7 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
         data.num_channels_unet = pipeline.unet.config.in_channels
 
         # (1) prepare controlnet inputs
-
         data.device = pipeline._execution_device
-
         data.height, data.width = data.latents.shape[-2:]
         data.height = data.height * pipeline.vae_scale_factor
         data.width = data.width * pipeline.vae_scale_factor
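Both denoise blocks now iterate over guider-prepared per-condition lists and call the denoiser once per condition, rather than concatenating conditional and unconditional inputs into one doubled batch. A schematic of the pattern (`fake_unet` and the tensor shapes are placeholders, not diffusers API):

```python
import torch

# Placeholder standing in for the pipeline.unet(...) call.
def fake_unet(latents: torch.Tensor, embeds: torch.Tensor) -> torch.Tensor:
    return latents - 0.1 * embeds.mean()

latents_list = [torch.randn(1, 4, 8, 8)] * 2                        # reused each pass
embeds_list = [torch.randn(1, 77, 2048), torch.randn(1, 77, 2048)]  # cond, uncond

preds = [fake_unet(l, e) for l, e in zip(latents_list, embeds_list)]
# The guider later combines these per-condition predictions into one noise estimate.
```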
@@ -2642,59 +2624,12 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
             data.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
 
         # (2) Prepare conditional inputs for unet using the guider
-        # adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale
         data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False
-        data.guider_kwargs = data.guider_kwargs or {}
-        data.guider_kwargs = {
-            **data.guider_kwargs,
-            "disable_guidance": data.disable_guidance,
-            "guidance_scale": data.guidance_scale,
-            "guidance_rescale": data.guidance_rescale,
-            "batch_size": data.batch_size * data.num_images_per_prompt,
-        }
-        pipeline.guider.set_guider(pipeline, data.guider_kwargs)
-        data.prompt_embeds = pipeline.guider.prepare_input(
-            data.prompt_embeds,
-            data.negative_prompt_embeds,
-        )
-        data.add_time_ids = pipeline.guider.prepare_input(
-            data.add_time_ids,
-            data.negative_add_time_ids,
-        )
-        data.pooled_prompt_embeds = pipeline.guider.prepare_input(
-            data.pooled_prompt_embeds,
-            data.negative_pooled_prompt_embeds,
-        )
-        if data.num_channels_unet == 9:
-            data.mask = pipeline.guider.prepare_input(data.mask, data.mask)
-            data.masked_image_latents = pipeline.guider.prepare_input(data.masked_image_latents, data.masked_image_latents)
-
-        data.added_cond_kwargs = {
-            "text_embeds": data.pooled_prompt_embeds,
-            "time_ids": data.add_time_ids,
-        }
-
-        if data.ip_adapter_embeds is not None:
-            data.ip_adapter_embeds = pipeline.guider.prepare_input(data.ip_adapter_embeds, data.negative_ip_adapter_embeds)
-            data.added_cond_kwargs["image_embeds"] = data.ip_adapter_embeds
+        if data.disable_guidance:
+            pipeline.guider.force_disable()
 
         # (3) Prepare conditional inputs for controlnet using the guider
         data.controlnet_disable_guidance = True if data.disable_guidance or data.guess_mode else False
-        data.controlnet_guider_kwargs = data.guider_kwargs or {}
-        data.controlnet_guider_kwargs = {
-            **data.controlnet_guider_kwargs,
-            "disable_guidance": data.controlnet_disable_guidance,
-            "guidance_scale": data.guidance_scale,
-            "guidance_rescale": data.guidance_rescale,
-            "batch_size": data.batch_size * data.num_images_per_prompt,
-        }
-        pipeline.controlnet_guider.set_guider(pipeline, data.controlnet_guider_kwargs)
-        data.controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(data.prompt_embeds)
-        data.controlnet_added_cond_kwargs = {
-            "text_embeds": pipeline.controlnet_guider.prepare_input(data.pooled_prompt_embeds),
-            "time_ids": pipeline.controlnet_guider.prepare_input(data.add_time_ids),
-        }
-        data.control_image = pipeline.controlnet_guider.prepare_input(data.control_image, data.control_image)
 
         # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta)
@@ -2703,11 +2638,26 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
         # (5) Denoise loop
         with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar:
             for i, t in enumerate(data.timesteps):
-                # prepare latents for unet using the guider
-                data.latent_model_input = pipeline.guider.prepare_input(data.latents, data.latents)
+                pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t)
 
-                # prepare latents for controlnet using the guider
-                data.control_model_input = pipeline.controlnet_guider.prepare_input(data.latents, data.latents)
+                (
+                    latents,
+                    prompt_embeds,
+                    add_time_ids,
+                    pooled_prompt_embeds,
+                    mask,
+                    masked_image_latents,
+                    ip_adapter_embeds,
+                ) = pipeline.guider.prepare_inputs(
+                    pipeline.unet,
+                    data.latents,
+                    (data.prompt_embeds, data.negative_prompt_embeds),
+                    (data.add_time_ids, data.negative_add_time_ids),
+                    (data.pooled_prompt_embeds, data.negative_pooled_prompt_embeds),
+                    data.mask,
+                    data.masked_image_latents,
+                    (data.ip_adapter_embeds, data.negative_ip_adapter_embeds),
+                )
 
                 if isinstance(data.controlnet_keep[i], list):
                     data.cond_scale = [c * s for c, s in zip(data.controlnet_conditioning_scale, data.controlnet_keep[i])]
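The per-step contract replacing `set_guider`/`prepare_input` is: `set_state`, then `prepare_inputs`, then one denoiser pass per condition registered via `prepare_outputs`, and finally calling the guider on its collected `outputs`. The diff does not show `ClassifierFreeGuidance.forward`, but for plain CFG the final combination is the standard formula; a sketch with stand-in tensors:

```python
import torch

guidance_scale = 5.0                       # echoes the removed InputParam default
pred_cond = torch.randn(1, 4, 128, 128)    # prediction from the conditional pass
pred_uncond = torch.randn(1, 4, 128, 128)  # prediction from the unconditional pass

# Standard classifier-free guidance combination (sketch).
noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
```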
@@ -2717,51 +2667,71 @@
                     data.controlnet_cond_scale = data.controlnet_cond_scale[0]
                 data.cond_scale = data.controlnet_cond_scale * data.controlnet_keep[i]
 
-                data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet(
-                    pipeline.scheduler.scale_model_input(data.control_model_input, t),
-                    t,
-                    encoder_hidden_states=data.controlnet_prompt_embeds,
-                    controlnet_cond=data.control_image,
-                    conditioning_scale=data.cond_scale,
-                    guess_mode=data.guess_mode,
-                    added_cond_kwargs=data.controlnet_added_cond_kwargs,
-                    return_dict=False,
-                )
+                for batch_index, (
+                    latents_i, prompt_embeds_i, add_time_ids_i, pooled_prompt_embeds_i, mask_i, masked_image_latents_i, ip_adapter_embeds_i
+                ) in enumerate(zip(
+                    latents, prompt_embeds, add_time_ids, pooled_prompt_embeds, mask, masked_image_latents, ip_adapter_embeds
+                )):
+                    latents_i = pipeline.scheduler.scale_model_input(latents_i, t)
+
+                    # Prepare additional conditionings
+                    data.added_cond_kwargs = {
+                        "text_embeds": pooled_prompt_embeds_i,
+                        "time_ids": add_time_ids_i,
+                    }
+                    if ip_adapter_embeds_i is not None:
+                        data.added_cond_kwargs["image_embeds"] = ip_adapter_embeds_i
+
+                    # Prepare controlnet additional conditionings
+                    data.controlnet_added_cond_kwargs = {
+                        "text_embeds": pooled_prompt_embeds_i,
+                        "time_ids": add_time_ids_i,
+                    }
 
-                # when we apply guidance for unet, but not for controlnet:
-                # add 0 to the unconditional batch
-                data.down_block_res_samples = pipeline.guider.prepare_input(
-                    data.down_block_res_samples, [torch.zeros_like(d) for d in data.down_block_res_samples]
-                )
-                data.mid_block_res_sample = pipeline.guider.prepare_input(
-                    data.mid_block_res_sample, torch.zeros_like(data.mid_block_res_sample)
-                )
+                    if pipeline.guider.is_conditional or not data.guess_mode:
+                        data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet(
+                            latents_i,
+                            t,
+                            encoder_hidden_states=prompt_embeds_i,
+                            controlnet_cond=data.control_image,
+                            conditioning_scale=data.cond_scale,
+                            guess_mode=data.guess_mode,
+                            added_cond_kwargs=data.controlnet_added_cond_kwargs,
+                            return_dict=False,
+                        )
+                    elif pipeline.guider.is_unconditional and data.guess_mode:
+                        data.down_block_res_samples = [torch.zeros_like(d) for d in data.down_block_res_samples]
+                        data.mid_block_res_sample = torch.zeros_like(data.mid_block_res_sample)
 
-                data.latent_model_input = pipeline.scheduler.scale_model_input(data.latent_model_input, t)
-                if data.num_channels_unet == 9:
-                    data.latent_model_input = torch.cat([data.latent_model_input, data.mask, data.masked_image_latents], dim=1)
+                    # Prepare for inpainting (after the controlnet call, so the
+                    # controlnet still receives 4-channel latents)
+                    if data.num_channels_unet == 9:
+                        latents_i = torch.cat([latents_i, mask_i, masked_image_latents_i], dim=1)
 
-                data.noise_pred = pipeline.unet(
-                    data.latent_model_input,
-                    t,
-                    encoder_hidden_states=data.prompt_embeds,
-                    timestep_cond=data.timestep_cond,
-                    cross_attention_kwargs=data.cross_attention_kwargs,
-                    added_cond_kwargs=data.added_cond_kwargs,
-                    down_block_additional_residuals=data.down_block_res_samples,
-                    mid_block_additional_residual=data.mid_block_res_sample,
-                    return_dict=False,
-                )[0]
-                # perform guidance
-                data.noise_pred = pipeline.guider.apply_guidance(data.noise_pred, timestep=t, latents=data.latents)
-                # compute the previous noisy sample x_t -> x_t-1
+                    data.noise_pred = pipeline.unet(
+                        latents_i,
+                        t,
+                        encoder_hidden_states=prompt_embeds_i,
+                        timestep_cond=data.timestep_cond,
+                        cross_attention_kwargs=data.cross_attention_kwargs,
+                        added_cond_kwargs=data.added_cond_kwargs,
+                        down_block_additional_residuals=data.down_block_res_samples,
+                        mid_block_additional_residual=data.mid_block_res_sample,
+                        return_dict=False,
+                    )[0]
+                    data.noise_pred = pipeline.guider.prepare_outputs(pipeline.unet, data.noise_pred)
+
+                # Perform guidance
+                outputs = pipeline.guider.outputs
+                data.noise_pred = pipeline.guider(**outputs)
+
+                # Perform scheduler step using the predicted output
                 data.latents_dtype = data.latents.dtype
                 data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0]
+
                 if data.latents.dtype != data.latents_dtype:
                     if torch.backends.mps.is_available():
                         # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                         data.latents = data.latents.to(data.latents_dtype)
-
                 if data.num_channels_unet == 4 and data.mask is not None and data.image_latents is not None:
                     data.init_latents_proper = data.image_latents
@@ -2775,9 +2748,6 @@
                 if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
                     progress_bar.update()
-
-        pipeline.guider.reset_guider(pipeline)
-        pipeline.controlnet_guider.reset_guider(pipeline)
 
         self.add_block_state(state, data)
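The `guess_mode` branches above reproduce the old "add 0 to the unconditional batch" trick without guider bookkeeping: the ControlNet runs only on the conditional pass, and the unconditional pass reuses zeroed residuals. A self-contained sketch of that control flow (tensor shapes are arbitrary stand-ins):

```python
import torch

guess_mode = True
down_res, mid_res = None, None

for is_conditional in (True, False):           # conditional pass runs first
    if is_conditional or not guess_mode:
        # Stand-in for the pipeline.controlnet(...) call.
        down_res = [torch.randn(1, 320, 64, 64)]
        mid_res = torch.randn(1, 1280, 32, 32)
    else:
        # Unconditional pass under guess_mode: zeroed residuals.
        down_res = [torch.zeros_like(d) for d in down_res]
        mid_res = torch.zeros_like(mid_res)
```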