From 2af7baa040e9a07405e05e5bd4abecf15f595b9f Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Mon, 26 Jan 2026 14:18:29 +0100 Subject: [PATCH 1/2] Remove `*pooled_*` mentions from Chroma inpaint (#13026) Remove `*pooled_*` mentions from Chroma as it has just one TE. --- .../pipelines/chroma/pipeline_chroma_inpainting.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma_inpainting.py b/src/diffusers/pipelines/chroma/pipeline_chroma_inpainting.py index 019c144152..3ea1ece36c 100644 --- a/src/diffusers/pipelines/chroma/pipeline_chroma_inpainting.py +++ b/src/diffusers/pipelines/chroma/pipeline_chroma_inpainting.py @@ -482,8 +482,6 @@ class ChromaInpaintPipeline( negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, - pooled_prompt_embeds=None, - negative_pooled_prompt_embeds=None, callback_on_step_end_tensor_inputs=None, padding_mask_crop=None, max_sequence_length=None, @@ -531,15 +529,6 @@ class ChromaInpaintPipeline( f" {negative_prompt_embeds.shape}." ) - if prompt_embeds is not None and pooled_prompt_embeds is None: - raise ValueError( - "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." - ) - if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: - raise ValueError( - "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." - ) - if prompt_embeds is not None and prompt_attention_mask is None: raise ValueError("Cannot provide `prompt_embeds` without also providing `prompt_attention_mask") @@ -793,13 +782,11 @@ class ChromaInpaintPipeline( generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, ip_adapter_image: Optional[PipelineImageInput] = None, ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, negative_ip_adapter_image: Optional[PipelineImageInput] = None, negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, - negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, joint_attention_kwargs: Optional[Dict[str, Any]] = None, From 956bdcc3ea4897eaeb6c828b8433bdcae71e9f0f Mon Sep 17 00:00:00 2001 From: David El Malih Date: Mon, 26 Jan 2026 18:41:48 +0100 Subject: [PATCH 2/2] Flag Flax schedulers as deprecated (#13031) flag flax schedulers as deprecated --- .../scheduling_dpmsolver_multistep_flax.py | 38 ++++++++++++++++--- .../scheduling_euler_discrete_flax.py | 33 ++++++++++++++-- .../schedulers/scheduling_karras_ve_flax.py | 10 ++++- .../scheduling_lms_discrete_flax.py | 33 ++++++++++++++-- .../schedulers/scheduling_pndm_flax.py | 30 ++++++++++++--- .../schedulers/scheduling_sde_ve_flax.py | 21 ++++++++-- 6 files changed, 141 insertions(+), 24 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py index 71b9960bf2..66398073b2 100644 --- 
a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py @@ -22,6 +22,7 @@ import jax import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import logging from .scheduling_utils_flax import ( CommonSchedulerState, FlaxKarrasDiffusionSchedulers, @@ -31,6 +32,9 @@ from .scheduling_utils_flax import ( ) +logger = logging.get_logger(__name__) + + @flax.struct.dataclass class DPMSolverMultistepSchedulerState: common: CommonSchedulerState @@ -171,6 +175,10 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): timestep_spacing: str = "linspace", dtype: jnp.dtype = jnp.float32, ): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) self.dtype = dtype def create_state(self, common: Optional[CommonSchedulerState] = None) -> DPMSolverMultistepSchedulerState: @@ -203,7 +211,10 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): ) def set_timesteps( - self, state: DPMSolverMultistepSchedulerState, num_inference_steps: int, shape: Tuple + self, + state: DPMSolverMultistepSchedulerState, + num_inference_steps: int, + shape: Tuple, ) -> DPMSolverMultistepSchedulerState: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -301,10 +312,13 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): if self.config.thresholding: # Dynamic thresholding in https://huggingface.co/papers/2205.11487 dynamic_max_val = jnp.percentile( - jnp.abs(x0_pred), self.config.dynamic_thresholding_ratio, axis=tuple(range(1, x0_pred.ndim)) + jnp.abs(x0_pred), + self.config.dynamic_thresholding_ratio, + axis=tuple(range(1, x0_pred.ndim)), ) dynamic_max_val = jnp.maximum( - dynamic_max_val, self.config.sample_max_value * jnp.ones_like(dynamic_max_val) + dynamic_max_val, + self.config.sample_max_value * jnp.ones_like(dynamic_max_val), ) x0_pred = jnp.clip(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val return x0_pred @@ -385,7 +399,11 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): """ t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] m0, m1 = model_output_list[-1], model_output_list[-2] - lambda_t, lambda_s0, lambda_s1 = state.lambda_t[t], state.lambda_t[s0], state.lambda_t[s1] + lambda_t, lambda_s0, lambda_s1 = ( + state.lambda_t[t], + state.lambda_t[s0], + state.lambda_t[s1], + ) alpha_t, alpha_s0 = state.alpha_t[t], state.alpha_t[s0] sigma_t, sigma_s0 = state.sigma_t[t], state.sigma_t[s0] h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 @@ -443,7 +461,12 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): Returns: `jnp.ndarray`: the sample tensor at the previous timestep. 
""" - t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] + t, s0, s1, s2 = ( + prev_timestep, + timestep_list[-1], + timestep_list[-2], + timestep_list[-3], + ) m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( state.lambda_t[t], @@ -615,7 +638,10 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): return FlaxDPMSolverMultistepSchedulerOutput(prev_sample=prev_sample, state=state) def scale_model_input( - self, state: DPMSolverMultistepSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None + self, + state: DPMSolverMultistepSchedulerState, + sample: jnp.ndarray, + timestep: Optional[int] = None, ) -> jnp.ndarray: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the diff --git a/src/diffusers/schedulers/scheduling_euler_discrete_flax.py b/src/diffusers/schedulers/scheduling_euler_discrete_flax.py index 09341c909d..2bb6bf3558 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete_flax.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete_flax.py @@ -19,6 +19,7 @@ import flax import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import logging from .scheduling_utils_flax import ( CommonSchedulerState, FlaxKarrasDiffusionSchedulers, @@ -28,6 +29,9 @@ from .scheduling_utils_flax import ( ) +logger = logging.get_logger(__name__) + + @flax.struct.dataclass class EulerDiscreteSchedulerState: common: CommonSchedulerState @@ -40,9 +44,18 @@ class EulerDiscreteSchedulerState: @classmethod def create( - cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray + cls, + common: CommonSchedulerState, + init_noise_sigma: jnp.ndarray, + timesteps: jnp.ndarray, + sigmas: jnp.ndarray, ): - return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas) + return cls( + common=common, + init_noise_sigma=init_noise_sigma, + timesteps=timesteps, + sigmas=sigmas, + ) @dataclass @@ -99,6 +112,10 @@ class FlaxEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): timestep_spacing: str = "linspace", dtype: jnp.dtype = jnp.float32, ): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) self.dtype = dtype def create_state(self, common: Optional[CommonSchedulerState] = None) -> EulerDiscreteSchedulerState: @@ -146,7 +163,10 @@ class FlaxEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): return sample def set_timesteps( - self, state: EulerDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = () + self, + state: EulerDiscreteSchedulerState, + num_inference_steps: int, + shape: Tuple = (), ) -> EulerDiscreteSchedulerState: """ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. 
@@ -159,7 +179,12 @@ class FlaxEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): """ if self.config.timestep_spacing == "linspace": - timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype) + timesteps = jnp.linspace( + self.config.num_train_timesteps - 1, + 0, + num_inference_steps, + dtype=self.dtype, + ) elif self.config.timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // num_inference_steps timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) diff --git a/src/diffusers/schedulers/scheduling_karras_ve_flax.py b/src/diffusers/schedulers/scheduling_karras_ve_flax.py index bacfbd6100..3f43a5fa99 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve_flax.py +++ b/src/diffusers/schedulers/scheduling_karras_ve_flax.py @@ -22,10 +22,13 @@ import jax.numpy as jnp from jax import random from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput +from ..utils import BaseOutput, logging from .scheduling_utils_flax import FlaxSchedulerMixin +logger = logging.get_logger(__name__) + + @flax.struct.dataclass class KarrasVeSchedulerState: # setable values @@ -102,7 +105,10 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin): s_min: float = 0.05, s_max: float = 50, ): - pass + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) def create_state(self): return KarrasVeSchedulerState.create() diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py index 3fd4dc8a5d..4edb091348 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py @@ -20,6 +20,7 @@ import jax.numpy as jnp from scipy import integrate from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import logging from .scheduling_utils_flax import ( CommonSchedulerState, FlaxKarrasDiffusionSchedulers, @@ -29,6 +30,9 @@ from .scheduling_utils_flax import ( ) +logger = logging.get_logger(__name__) + + @flax.struct.dataclass class LMSDiscreteSchedulerState: common: CommonSchedulerState @@ -44,9 +48,18 @@ class LMSDiscreteSchedulerState: @classmethod def create( - cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray + cls, + common: CommonSchedulerState, + init_noise_sigma: jnp.ndarray, + timesteps: jnp.ndarray, + sigmas: jnp.ndarray, ): - return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas) + return cls( + common=common, + init_noise_sigma=init_noise_sigma, + timesteps=timesteps, + sigmas=sigmas, + ) @dataclass @@ -101,6 +114,10 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): prediction_type: str = "epsilon", dtype: jnp.dtype = jnp.float32, ): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." 
+ ) self.dtype = dtype def create_state(self, common: Optional[CommonSchedulerState] = None) -> LMSDiscreteSchedulerState: @@ -165,7 +182,10 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): return integrated_coeff def set_timesteps( - self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = () + self, + state: LMSDiscreteSchedulerState, + num_inference_steps: int, + shape: Tuple = (), ) -> LMSDiscreteSchedulerState: """ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -177,7 +197,12 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): the number of diffusion steps used when generating samples with a pre-trained model. """ - timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype) + timesteps = jnp.linspace( + self.config.num_train_timesteps - 1, + 0, + num_inference_steps, + dtype=self.dtype, + ) low_idx = jnp.floor(timesteps).astype(jnp.int32) high_idx = jnp.ceil(timesteps).astype(jnp.int32) diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 44bafccd55..bbef4649ec 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -22,6 +22,7 @@ import jax import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import logging from .scheduling_utils_flax import ( CommonSchedulerState, FlaxKarrasDiffusionSchedulers, @@ -31,6 +32,9 @@ from .scheduling_utils_flax import ( ) +logger = logging.get_logger(__name__) + + @flax.struct.dataclass class PNDMSchedulerState: common: CommonSchedulerState @@ -131,6 +135,10 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin): prediction_type: str = "epsilon", dtype: jnp.dtype = jnp.float32, ): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) self.dtype = dtype # For now we only support F-PNDM, i.e. 
the runge-kutta method @@ -190,7 +198,10 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin): else: prk_timesteps = _timesteps[-self.pndm_order :].repeat(2) + jnp.tile( - jnp.array([0, self.config.num_train_timesteps // num_inference_steps // 2], dtype=jnp.int32), + jnp.array( + [0, self.config.num_train_timesteps // num_inference_steps // 2], + dtype=jnp.int32, + ), self.pndm_order, ) @@ -218,7 +229,10 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin): ) def scale_model_input( - self, state: PNDMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None + self, + state: PNDMSchedulerState, + sample: jnp.ndarray, + timestep: Optional[int] = None, ) -> jnp.ndarray: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the @@ -320,7 +334,9 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin): ) diff_to_prev = jnp.where( - state.counter % 2, 0, self.config.num_train_timesteps // state.num_inference_steps // 2 + state.counter % 2, + 0, + self.config.num_train_timesteps // state.num_inference_steps // 2, ) prev_timestep = timestep - diff_to_prev timestep = state.prk_timesteps[state.counter // 4 * 4] @@ -401,7 +417,9 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin): prev_timestep = jnp.where(state.counter == 1, timestep, prev_timestep) timestep = jnp.where( - state.counter == 1, timestep + self.config.num_train_timesteps // state.num_inference_steps, timestep + state.counter == 1, + timestep + self.config.num_train_timesteps // state.num_inference_steps, + timestep, ) # Reference: @@ -466,7 +484,9 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin): # prev_sample -> x_(t−δ) alpha_prod_t = state.common.alphas_cumprod[timestep] alpha_prod_t_prev = jnp.where( - prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod + prev_timestep >= 0, + state.common.alphas_cumprod[prev_timestep], + state.final_alpha_cumprod, ) beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev diff --git a/src/diffusers/schedulers/scheduling_sde_ve_flax.py b/src/diffusers/schedulers/scheduling_sde_ve_flax.py index 09cd081462..f4fe6d8f6b 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve_flax.py +++ b/src/diffusers/schedulers/scheduling_sde_ve_flax.py @@ -23,7 +23,15 @@ import jax.numpy as jnp from jax import random from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left +from ..utils import logging +from .scheduling_utils_flax import ( + FlaxSchedulerMixin, + FlaxSchedulerOutput, + broadcast_to_shape_from_left, +) + + +logger = logging.get_logger(__name__) @flax.struct.dataclass @@ -95,7 +103,10 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin): sampling_eps: float = 1e-5, correct_steps: int = 1, ): - pass + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." 
+ ) def create_state(self): state = ScoreSdeVeSchedulerState.create() @@ -108,7 +119,11 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin): ) def set_timesteps( - self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple = (), sampling_eps: float = None + self, + state: ScoreSdeVeSchedulerState, + num_inference_steps: int, + shape: Tuple = (), + sampling_eps: float = None, ) -> ScoreSdeVeSchedulerState: """ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
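
Taken together, the two patches are caller-visible in only two small ways. The sketch below is a minimal illustration, not part of the patch series itself: it assumes a diffusers build with these patches applied and with the `jax`/`flax` extras installed, so that `FlaxPNDMScheduler` resolves to the real class rather than a dummy placeholder.

    # Illustrative sketch only, under the assumptions stated above.

    # PATCH 1/2: `ChromaInpaintPipeline` no longer accepts `pooled_prompt_embeds`
    # or `negative_pooled_prompt_embeds`, since Chroma uses a single text encoder.
    # Callers who precompute embeddings pass only `prompt_embeds` together with
    # `prompt_attention_mask` (and their negative counterparts).

    # PATCH 2/2: constructing any Flax scheduler now logs a deprecation warning.
    from diffusers import FlaxPNDMScheduler

    scheduler = FlaxPNDMScheduler()
    # Logged once at construction time:
    #   "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We
    #    recommend migrating to PyTorch classes or pinning your version of Diffusers."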