Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
Merge branch 'main' into transformers-v5-pr
@@ -482,8 +482,6 @@ class ChromaInpaintPipeline(
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
-        pooled_prompt_embeds=None,
-        negative_pooled_prompt_embeds=None,
         callback_on_step_end_tensor_inputs=None,
         padding_mask_crop=None,
         max_sequence_length=None,
@@ -531,15 +529,6 @@ class ChromaInpaintPipeline(
                 f" {negative_prompt_embeds.shape}."
             )

-        if prompt_embeds is not None and pooled_prompt_embeds is None:
-            raise ValueError(
-                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
-            )
-        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
-            raise ValueError(
-                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
-            )
-
         if prompt_embeds is not None and prompt_attention_mask is None:
             raise ValueError("Cannot provide `prompt_embeds` without also providing `prompt_attention_mask")

@@ -793,13 +782,11 @@ class ChromaInpaintPipeline(
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
-        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         ip_adapter_image: Optional[PipelineImageInput] = None,
         ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
         negative_ip_adapter_image: Optional[PipelineImageInput] = None,
         negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,

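A minimal sketch of the validation contract kept above: with the pooled-embedding arguments and their checks removed, precomputed prompt embeddings only need to be accompanied by their attention mask. The helper name and tensor shapes below are illustrative assumptions, not part of the pipeline's API.

    import torch

    def validate_prompt_inputs(prompt_embeds=None, prompt_attention_mask=None):
        # Mirrors the surviving check in ChromaInpaintPipeline.check_inputs.
        if prompt_embeds is not None and prompt_attention_mask is None:
            raise ValueError("Cannot provide `prompt_embeds` without also providing `prompt_attention_mask`")

    # Example: passes only when both tensors are supplied together.
    embeds = torch.zeros(1, 512, 4096)            # illustrative shape
    mask = torch.ones(1, 512, dtype=torch.bool)   # illustrative shape
    validate_prompt_inputs(prompt_embeds=embeds, prompt_attention_mask=mask)
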
@@ -22,6 +22,7 @@ import jax
 import jax.numpy as jnp

 from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import logging
 from .scheduling_utils_flax import (
     CommonSchedulerState,
     FlaxKarrasDiffusionSchedulers,
@@ -31,6 +32,9 @@ from .scheduling_utils_flax import (
 )


+logger = logging.get_logger(__name__)
+
+
 @flax.struct.dataclass
 class DPMSolverMultistepSchedulerState:
     common: CommonSchedulerState
@@ -171,6 +175,10 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
         timestep_spacing: str = "linspace",
         dtype: jnp.dtype = jnp.float32,
     ):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
         self.dtype = dtype

     def create_state(self, common: Optional[CommonSchedulerState] = None) -> DPMSolverMultistepSchedulerState:

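The same four-line warning is added to every Flax scheduler touched by this commit, so merely constructing one now logs the deprecation notice. A hedged usage sketch, assuming the Flax extras (jax, flax) are installed; the set_verbosity_error() call uses the diffusers logging utilities and is only needed if the notice should be silenced.

    from diffusers import FlaxDPMSolverMultistepScheduler
    from diffusers.utils import logging

    scheduler = FlaxDPMSolverMultistepScheduler()  # logs "Flax classes are deprecated ..."
    state = scheduler.create_state()               # Flax schedulers keep their state in a separate pytree

    logging.set_verbosity_error()                  # optional: hide warnings emitted via the diffusers logger
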
@@ -203,7 +211,10 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
         )

     def set_timesteps(
-        self, state: DPMSolverMultistepSchedulerState, num_inference_steps: int, shape: Tuple
+        self,
+        state: DPMSolverMultistepSchedulerState,
+        num_inference_steps: int,
+        shape: Tuple,
     ) -> DPMSolverMultistepSchedulerState:
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -301,10 +312,13 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
             if self.config.thresholding:
                 # Dynamic thresholding in https://huggingface.co/papers/2205.11487
                 dynamic_max_val = jnp.percentile(
-                    jnp.abs(x0_pred), self.config.dynamic_thresholding_ratio, axis=tuple(range(1, x0_pred.ndim))
+                    jnp.abs(x0_pred),
+                    self.config.dynamic_thresholding_ratio,
+                    axis=tuple(range(1, x0_pred.ndim)),
                 )
                 dynamic_max_val = jnp.maximum(
-                    dynamic_max_val, self.config.sample_max_value * jnp.ones_like(dynamic_max_val)
+                    dynamic_max_val,
+                    self.config.sample_max_value * jnp.ones_like(dynamic_max_val),
                 )
                 x0_pred = jnp.clip(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
             return x0_pred

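The reflowed block above implements dynamic thresholding (https://huggingface.co/papers/2205.11487): clamp each predicted sample to its own high percentile of absolute values, then rescale back into range. A self-contained sketch of the same idea, with `ratio` and `max_value` standing in for the scheduler's dynamic_thresholding_ratio and sample_max_value config entries; the explicit broadcast reshape is an assumption of this sketch, not a claim about the scheduler's internals.

    import jax.numpy as jnp

    def dynamic_threshold(x0_pred: jnp.ndarray, ratio: float = 99.5, max_value: float = 1.0) -> jnp.ndarray:
        # Per-sample percentile of |x0| over every non-batch axis.
        s = jnp.percentile(jnp.abs(x0_pred), ratio, axis=tuple(range(1, x0_pred.ndim)))
        # Never threshold below max_value, so well-behaved samples are left untouched.
        s = jnp.maximum(s, max_value * jnp.ones_like(s))
        # Reshape to (batch, 1, 1, ...) so the clamp broadcasts over the sample dimensions.
        s = s.reshape(s.shape + (1,) * (x0_pred.ndim - 1))
        return jnp.clip(x0_pred, -s, s) / s

    x0 = jnp.ones((2, 4, 8, 8)) * 3.0
    print(dynamic_threshold(x0).max())  # 1.0 after clamping and rescaling
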
@@ -385,7 +399,11 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
         """
         t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
         m0, m1 = model_output_list[-1], model_output_list[-2]
-        lambda_t, lambda_s0, lambda_s1 = state.lambda_t[t], state.lambda_t[s0], state.lambda_t[s1]
+        lambda_t, lambda_s0, lambda_s1 = (
+            state.lambda_t[t],
+            state.lambda_t[s0],
+            state.lambda_t[s1],
+        )
         alpha_t, alpha_s0 = state.alpha_t[t], state.alpha_t[s0]
         sigma_t, sigma_s0 = state.sigma_t[t], state.sigma_t[s0]
         h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1

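For context on the quantities unpacked above: in DPM-Solver, lambda_t is the half log-SNR, lambda_t = log(alpha_t) - log(sigma_t), so the differences h and h_0 are step sizes in log-SNR space. A small numeric sketch with made-up alphas_cumprod values, assuming the usual variance-preserving parameterization alpha_t = sqrt(alphas_cumprod_t) and sigma_t = sqrt(1 - alphas_cumprod_t):

    import jax.numpy as jnp

    alphas_cumprod = jnp.array([0.95, 0.80, 0.50])   # illustrative values only
    alpha_t = jnp.sqrt(alphas_cumprod)
    sigma_t = jnp.sqrt(1.0 - alphas_cumprod)
    lambda_t = jnp.log(alpha_t) - jnp.log(sigma_t)   # half log-SNR per entry

    # Step from the noisier entry (index 2) to the cleaner one (index 1),
    # playing the role of h = lambda_t - lambda_s0 in the multistep update.
    h = lambda_t[1] - lambda_t[2]
    print(lambda_t, h)
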
@@ -443,7 +461,12 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
         Returns:
             `jnp.ndarray`: the sample tensor at the previous timestep.
         """
-        t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
+        t, s0, s1, s2 = (
+            prev_timestep,
+            timestep_list[-1],
+            timestep_list[-2],
+            timestep_list[-3],
+        )
         m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
         lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
             state.lambda_t[t],
@@ -615,7 +638,10 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
         return FlaxDPMSolverMultistepSchedulerOutput(prev_sample=prev_sample, state=state)

     def scale_model_input(
-        self, state: DPMSolverMultistepSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
+        self,
+        state: DPMSolverMultistepSchedulerState,
+        sample: jnp.ndarray,
+        timestep: Optional[int] = None,
     ) -> jnp.ndarray:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the

@@ -19,6 +19,7 @@ import flax
 import jax.numpy as jnp

 from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import logging
 from .scheduling_utils_flax import (
     CommonSchedulerState,
     FlaxKarrasDiffusionSchedulers,
@@ -28,6 +29,9 @@ from .scheduling_utils_flax import (
 )


+logger = logging.get_logger(__name__)
+
+
 @flax.struct.dataclass
 class EulerDiscreteSchedulerState:
     common: CommonSchedulerState
@@ -40,9 +44,18 @@ class EulerDiscreteSchedulerState:

     @classmethod
     def create(
-        cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray
+        cls,
+        common: CommonSchedulerState,
+        init_noise_sigma: jnp.ndarray,
+        timesteps: jnp.ndarray,
+        sigmas: jnp.ndarray,
     ):
-        return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas)
+        return cls(
+            common=common,
+            init_noise_sigma=init_noise_sigma,
+            timesteps=timesteps,
+            sigmas=sigmas,
+        )


 @dataclass
@@ -99,6 +112,10 @@ class FlaxEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
         timestep_spacing: str = "linspace",
         dtype: jnp.dtype = jnp.float32,
     ):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
         self.dtype = dtype

     def create_state(self, common: Optional[CommonSchedulerState] = None) -> EulerDiscreteSchedulerState:
@@ -146,7 +163,10 @@ class FlaxEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
         return sample

     def set_timesteps(
-        self, state: EulerDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = ()
+        self,
+        state: EulerDiscreteSchedulerState,
+        num_inference_steps: int,
+        shape: Tuple = (),
     ) -> EulerDiscreteSchedulerState:
         """
         Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -159,7 +179,12 @@ class FlaxEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
         """

         if self.config.timestep_spacing == "linspace":
-            timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype)
+            timesteps = jnp.linspace(
+                self.config.num_train_timesteps - 1,
+                0,
+                num_inference_steps,
+                dtype=self.dtype,
+            )
         elif self.config.timestep_spacing == "leading":
             step_ratio = self.config.num_train_timesteps // num_inference_steps
             timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)

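The two branches above produce noticeably different grids. A small sketch contrasting them with illustrative values (num_train_timesteps=1000, num_inference_steps=10), mirroring the expressions in the hunk:

    import jax.numpy as jnp

    num_train_timesteps, num_inference_steps = 1000, 10

    # "linspace": evenly spaced floats from 999 down to 0, endpoints included.
    linspace_ts = jnp.linspace(num_train_timesteps - 1, 0, num_inference_steps, dtype=jnp.float32)

    # "leading": integer multiples of the step ratio, reversed, so the grid never reaches 999.
    step_ratio = num_train_timesteps // num_inference_steps
    leading_ts = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)

    print(linspace_ts)  # [999. 888. 777. ... 0.]
    print(leading_ts)   # [900. 800. 700. ... 0.]
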
@@ -22,10 +22,13 @@ import jax.numpy as jnp
 from jax import random

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..utils import BaseOutput, logging
 from .scheduling_utils_flax import FlaxSchedulerMixin


+logger = logging.get_logger(__name__)
+
+
 @flax.struct.dataclass
 class KarrasVeSchedulerState:
     # setable values
@@ -102,7 +105,10 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
         s_min: float = 0.05,
         s_max: float = 50,
     ):
-        pass
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )

     def create_state(self):
         return KarrasVeSchedulerState.create()

@@ -20,6 +20,7 @@ import jax.numpy as jnp
 from scipy import integrate

 from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import logging
 from .scheduling_utils_flax import (
     CommonSchedulerState,
     FlaxKarrasDiffusionSchedulers,
@@ -29,6 +30,9 @@ from .scheduling_utils_flax import (
 )


+logger = logging.get_logger(__name__)
+
+
 @flax.struct.dataclass
 class LMSDiscreteSchedulerState:
     common: CommonSchedulerState
@@ -44,9 +48,18 @@ class LMSDiscreteSchedulerState:

     @classmethod
     def create(
-        cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray
+        cls,
+        common: CommonSchedulerState,
+        init_noise_sigma: jnp.ndarray,
+        timesteps: jnp.ndarray,
+        sigmas: jnp.ndarray,
     ):
-        return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas)
+        return cls(
+            common=common,
+            init_noise_sigma=init_noise_sigma,
+            timesteps=timesteps,
+            sigmas=sigmas,
+        )


 @dataclass
@@ -101,6 +114,10 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
         prediction_type: str = "epsilon",
         dtype: jnp.dtype = jnp.float32,
     ):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
         self.dtype = dtype

     def create_state(self, common: Optional[CommonSchedulerState] = None) -> LMSDiscreteSchedulerState:
@@ -165,7 +182,10 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
         return integrated_coeff

     def set_timesteps(
-        self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = ()
+        self,
+        state: LMSDiscreteSchedulerState,
+        num_inference_steps: int,
+        shape: Tuple = (),
     ) -> LMSDiscreteSchedulerState:
         """
         Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -177,7 +197,12 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
                 the number of diffusion steps used when generating samples with a pre-trained model.
         """

-        timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype)
+        timesteps = jnp.linspace(
+            self.config.num_train_timesteps - 1,
+            0,
+            num_inference_steps,
+            dtype=self.dtype,
+        )

         low_idx = jnp.floor(timesteps).astype(jnp.int32)
         high_idx = jnp.ceil(timesteps).astype(jnp.int32)

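low_idx and high_idx bracket each fractional timestep; a typical use (hedged, since the surrounding lines are not part of this hunk) is to linearly interpolate the training sigma schedule at those fractional positions, roughly as sketched below with placeholder values:

    import jax.numpy as jnp

    sigmas_train = jnp.linspace(0.1, 10.0, 1000)               # placeholder training sigmas
    timesteps = jnp.linspace(999, 0, 15, dtype=jnp.float32)    # fractional timesteps from set_timesteps

    low_idx = jnp.floor(timesteps).astype(jnp.int32)
    high_idx = jnp.ceil(timesteps).astype(jnp.int32)
    frac = jnp.mod(timesteps, 1.0)

    # Linear interpolation between the bracketing sigma values.
    sigmas = (1 - frac) * sigmas_train[low_idx] + frac * sigmas_train[high_idx]
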
@@ -22,6 +22,7 @@ import jax
 import jax.numpy as jnp

 from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import logging
 from .scheduling_utils_flax import (
     CommonSchedulerState,
     FlaxKarrasDiffusionSchedulers,
@@ -31,6 +32,9 @@ from .scheduling_utils_flax import (
 )


+logger = logging.get_logger(__name__)
+
+
 @flax.struct.dataclass
 class PNDMSchedulerState:
     common: CommonSchedulerState
@@ -131,6 +135,10 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
         prediction_type: str = "epsilon",
         dtype: jnp.dtype = jnp.float32,
     ):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
         self.dtype = dtype

         # For now we only support F-PNDM, i.e. the runge-kutta method
@@ -190,7 +198,10 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):

         else:
             prk_timesteps = _timesteps[-self.pndm_order :].repeat(2) + jnp.tile(
-                jnp.array([0, self.config.num_train_timesteps // num_inference_steps // 2], dtype=jnp.int32),
+                jnp.array(
+                    [0, self.config.num_train_timesteps // num_inference_steps // 2],
+                    dtype=jnp.int32,
+                ),
                 self.pndm_order,
             )

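The reflowed expression above interleaves each of the last pndm_order timesteps with a half-step offset to build the Runge-Kutta warm-up schedule. A worked sketch with illustrative numbers (num_train_timesteps=1000, num_inference_steps=50, pndm_order=4); reconstructing _timesteps as multiples of the step ratio is an assumption of this sketch:

    import jax.numpy as jnp

    num_train_timesteps, num_inference_steps, pndm_order = 1000, 50, 4
    _timesteps = jnp.arange(0, num_inference_steps) * (num_train_timesteps // num_inference_steps)

    # Repeat each of the last four timesteps twice, then add [0, half_step] pairs.
    prk_timesteps = _timesteps[-pndm_order:].repeat(2) + jnp.tile(
        jnp.array([0, num_train_timesteps // num_inference_steps // 2], dtype=jnp.int32),
        pndm_order,
    )
    print(prk_timesteps)  # [920 930 940 950 960 970 980 990]
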
@@ -218,7 +229,10 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
         )

     def scale_model_input(
-        self, state: PNDMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
+        self,
+        state: PNDMSchedulerState,
+        sample: jnp.ndarray,
+        timestep: Optional[int] = None,
     ) -> jnp.ndarray:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
@@ -320,7 +334,9 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
            )

        diff_to_prev = jnp.where(
-           state.counter % 2, 0, self.config.num_train_timesteps // state.num_inference_steps // 2
+           state.counter % 2,
+           0,
+           self.config.num_train_timesteps // state.num_inference_steps // 2,
        )
        prev_timestep = timestep - diff_to_prev
        timestep = state.prk_timesteps[state.counter // 4 * 4]
@@ -401,7 +417,9 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):

         prev_timestep = jnp.where(state.counter == 1, timestep, prev_timestep)
         timestep = jnp.where(
-            state.counter == 1, timestep + self.config.num_train_timesteps // state.num_inference_steps, timestep
+            state.counter == 1,
+            timestep + self.config.num_train_timesteps // state.num_inference_steps,
+            timestep,
         )

         # Reference:
@@ -466,7 +484,9 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
         # prev_sample -> x_(t−δ)
         alpha_prod_t = state.common.alphas_cumprod[timestep]
         alpha_prod_t_prev = jnp.where(
-            prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod
+            prev_timestep >= 0,
+            state.common.alphas_cumprod[prev_timestep],
+            state.final_alpha_cumprod,
         )
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev

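A side note on why these selections use jnp.where rather than a Python if: the Flax step functions are meant to stay traceable under jax.jit, where branching on a traced value is not allowed, so both operands are computed and the result is selected element-wise. A minimal sketch of the pattern (the function name and example values are illustrative):

    import jax
    import jax.numpy as jnp

    @jax.jit
    def previous_alpha_prod(prev_timestep, alphas_cumprod, final_alpha_cumprod):
        # Both operands are evaluated; jnp.where picks one without Python control flow,
        # mirroring the alpha_prod_t_prev selection in the hunk above.
        return jnp.where(
            prev_timestep >= 0,
            alphas_cumprod[prev_timestep],
            final_alpha_cumprod,
        )

    alphas_cumprod = jnp.linspace(0.999, 0.01, 1000)
    print(previous_alpha_prod(jnp.array(10), alphas_cumprod, jnp.array(1.0)))
    print(previous_alpha_prod(jnp.array(-5), alphas_cumprod, jnp.array(1.0)))  # falls back to final_alpha_cumprod
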
@@ -23,7 +23,15 @@ import jax.numpy as jnp
 from jax import random

 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
+from ..utils import logging
+from .scheduling_utils_flax import (
+    FlaxSchedulerMixin,
+    FlaxSchedulerOutput,
+    broadcast_to_shape_from_left,
+)


+logger = logging.get_logger(__name__)
+
+
 @flax.struct.dataclass
@@ -95,7 +103,10 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
         sampling_eps: float = 1e-5,
         correct_steps: int = 1,
     ):
-        pass
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )

     def create_state(self):
         state = ScoreSdeVeSchedulerState.create()
@@ -108,7 +119,11 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
         )

     def set_timesteps(
-        self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple = (), sampling_eps: float = None
+        self,
+        state: ScoreSdeVeSchedulerState,
+        num_inference_steps: int,
+        shape: Tuple = (),
+        sampling_eps: float = None,
     ) -> ScoreSdeVeSchedulerState:
         """
         Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.