mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Merge branch 'main' into transformers-v5-pr
This commit is contained in:
@@ -366,7 +366,12 @@ class ResnetBlock2D(nn.Module):
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
input_tensor = self.conv_shortcut(input_tensor.contiguous())
|
||||
# Only use contiguous() during training to avoid DDP gradient stride mismatch warning.
|
||||
# In inference mode (eval or no_grad), skip contiguous() for better performance, especially on CPU.
|
||||
# Issue: https://github.com/huggingface/diffusers/issues/12975
|
||||
if self.training:
|
||||
input_tensor = input_tensor.contiguous()
|
||||
input_tensor = self.conv_shortcut(input_tensor)
|
||||
|
||||
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import AttentionModuleMixin, FeedForward
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
|
||||
@@ -400,6 +400,7 @@ class LongCatImageTransformer2DModel(
|
||||
PeftAdapterMixin,
|
||||
FromOriginalModelMixin,
|
||||
CacheMixin,
|
||||
AttentionMixin,
|
||||
):
|
||||
"""
|
||||
The Transformer model introduced in Longcat-Image.
|
||||
|
||||
@@ -101,7 +101,7 @@ def betas_for_alpha_bar(
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
|
||||
def rescale_zero_terminal_snr(betas):
|
||||
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
|
||||
|
||||
@@ -266,7 +266,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return sample
|
||||
|
||||
def _get_variance(self, timestep, prev_timestep=None):
|
||||
def _get_variance(self, timestep: int, prev_timestep: Optional[int] = None) -> torch.Tensor:
|
||||
if prev_timestep is None:
|
||||
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
|
||||
|
||||
@@ -279,7 +279,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return variance
|
||||
|
||||
def _batch_get_variance(self, t, prev_t):
|
||||
def _batch_get_variance(self, t: torch.Tensor, prev_t: torch.Tensor) -> torch.Tensor:
|
||||
alpha_prod_t = self.alphas_cumprod[t]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[torch.clip(prev_t, min=0)]
|
||||
alpha_prod_t_prev[prev_t < 0] = torch.tensor(1.0)
|
||||
@@ -335,7 +335,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
return sample
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.set_timesteps
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None) -> None:
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
@@ -392,7 +392,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample: torch.Tensor,
|
||||
eta: float = 0.0,
|
||||
use_clipped_model_output: bool = False,
|
||||
generator=None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
variance_noise: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[DDIMParallelSchedulerOutput, Tuple]:
|
||||
@@ -406,11 +406,13 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample (`torch.Tensor`):
|
||||
current instance of sample being created by diffusion process.
|
||||
eta (`float`): weight of noise for added noise in diffusion step.
|
||||
use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
|
||||
predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
|
||||
`self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
|
||||
coincide with the one provided as input and `use_clipped_model_output` will have not effect.
|
||||
generator: random number generator.
|
||||
use_clipped_model_output (`bool`, defaults to `False`):
|
||||
If `True`, compute "corrected" `model_output` from the clipped predicted original sample. This
|
||||
correction is necessary because the predicted original sample is clipped to [-1, 1] when
|
||||
`self.config.clip_sample` is `True`. If no clipping occurred, the "corrected" `model_output` matches
|
||||
the input and `use_clipped_model_output` has no effect.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
Random number generator.
|
||||
variance_noise (`torch.Tensor`): instead of generating noise for the variance using `generator`, we
|
||||
can directly provide the noise for the variance itself. This is useful for methods such as
|
||||
CycleDiffusion. (https://huggingface.co/papers/2210.05559)
|
||||
@@ -496,7 +498,10 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
if variance_noise is None:
|
||||
variance_noise = randn_tensor(
|
||||
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
|
||||
model_output.shape,
|
||||
generator=generator,
|
||||
device=model_output.device,
|
||||
dtype=model_output.dtype,
|
||||
)
|
||||
variance = std_dev_t * variance_noise
|
||||
|
||||
@@ -513,7 +518,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
def batch_step_no_noise(
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
timesteps: List[int],
|
||||
timesteps: torch.Tensor,
|
||||
sample: torch.Tensor,
|
||||
eta: float = 0.0,
|
||||
use_clipped_model_output: bool = False,
|
||||
@@ -528,7 +533,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
Args:
|
||||
model_output (`torch.Tensor`): direct output from learned diffusion model.
|
||||
timesteps (`List[int]`):
|
||||
timesteps (`torch.Tensor`):
|
||||
current discrete timesteps in the diffusion chain. This is now a list of integers.
|
||||
sample (`torch.Tensor`):
|
||||
current instance of sample being created by diffusion process.
|
||||
@@ -696,5 +701,5 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
|
||||
return velocity
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
@@ -22,6 +22,7 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import logging
|
||||
from .scheduling_utils_flax import (
|
||||
CommonSchedulerState,
|
||||
FlaxKarrasDiffusionSchedulers,
|
||||
@@ -32,6 +33,9 @@ from .scheduling_utils_flax import (
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class DDPMSchedulerState:
|
||||
common: CommonSchedulerState
|
||||
@@ -42,7 +46,12 @@ class DDPMSchedulerState:
|
||||
num_inference_steps: Optional[int] = None
|
||||
|
||||
@classmethod
|
||||
def create(cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray):
|
||||
def create(
|
||||
cls,
|
||||
common: CommonSchedulerState,
|
||||
init_noise_sigma: jnp.ndarray,
|
||||
timesteps: jnp.ndarray,
|
||||
):
|
||||
return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps)
|
||||
|
||||
|
||||
@@ -105,6 +114,10 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
prediction_type: str = "epsilon",
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
):
|
||||
logger.warning(
|
||||
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
|
||||
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
|
||||
)
|
||||
self.dtype = dtype
|
||||
|
||||
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDPMSchedulerState:
|
||||
@@ -123,7 +136,10 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
def scale_model_input(
|
||||
self, state: DDPMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
|
||||
self,
|
||||
state: DDPMSchedulerState,
|
||||
sample: jnp.ndarray,
|
||||
timestep: Optional[int] = None,
|
||||
) -> jnp.ndarray:
|
||||
"""
|
||||
Args:
|
||||
|
||||
@@ -248,6 +248,9 @@ class KandinskyV22InpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCas
|
||||
def test_float16_inference(self):
|
||||
super().test_float16_inference(expected_max_diff=5e-1)
|
||||
|
||||
def test_save_load_dduf(self):
|
||||
super().test_save_load_dduf(atol=1e-3, rtol=1e-3)
|
||||
|
||||
@is_flaky()
|
||||
def test_model_cpu_offload_forward_pass(self):
|
||||
super().test_inference_batch_single_identical(expected_max_diff=8e-4)
|
||||
|
||||
@@ -192,6 +192,9 @@ class Kandinsky3Img2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
|
||||
def test_inference_batch_single_identical(self):
|
||||
super().test_inference_batch_single_identical(expected_max_diff=1e-2)
|
||||
|
||||
def test_save_load_dduf(self):
|
||||
super().test_save_load_dduf(atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_accelerator
|
||||
|
||||
Reference in New Issue
Block a user