diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index c0b4ad4005..a98cb49114 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -366,7 +366,12 @@ class ResnetBlock2D(nn.Module):
         hidden_states = self.conv2(hidden_states)
 
         if self.conv_shortcut is not None:
-            input_tensor = self.conv_shortcut(input_tensor.contiguous())
+            # Only use contiguous() during training to avoid DDP gradient stride mismatch warning.
+            # In inference mode (eval or no_grad), skip contiguous() for better performance, especially on CPU.
+            # Issue: https://github.com/huggingface/diffusers/issues/12975
+            if self.training:
+                input_tensor = input_tensor.contiguous()
+            input_tensor = self.conv_shortcut(input_tensor)
 
         output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
 
diff --git a/src/diffusers/models/transformers/transformer_longcat_image.py b/src/diffusers/models/transformers/transformer_longcat_image.py
index 74685607a8..3d38da1dfc 100644
--- a/src/diffusers/models/transformers/transformer_longcat_image.py
+++ b/src/diffusers/models/transformers/transformer_longcat_image.py
@@ -23,7 +23,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import AttentionModuleMixin, FeedForward
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
 from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
@@ -400,6 +400,7 @@ class LongCatImageTransformer2DModel(
     PeftAdapterMixin,
     FromOriginalModelMixin,
     CacheMixin,
+    AttentionMixin,
 ):
     """
     The Transformer model introduced in Longcat-Image.
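Note on the ResnetBlock2D change above: DDP's gradient stride mismatch warning is triggered during backward, so the extra contiguous() copy only matters while training; eval/no_grad paths can feed the shortcut conv directly. A minimal sketch of the pattern (toy module, illustrative names only, not the actual ResnetBlock2D):

import torch
import torch.nn as nn

class ShortcutBlock(nn.Module):
    """Toy block showing the training-only contiguous() pattern."""

    def __init__(self):
        super().__init__()
        self.conv_shortcut = nn.Conv2d(4, 4, kernel_size=1)

    def forward(self, input_tensor):
        # Copy to contiguous memory only when gradients (and DDP buckets) are in play.
        if self.training:
            input_tensor = input_tensor.contiguous()
        return self.conv_shortcut(input_tensor)

x = torch.randn(1, 4, 8, 8).to(memory_format=torch.channels_last)
block = ShortcutBlock().eval()
with torch.no_grad():
    out = block(x)  # inference path: no extra contiguous() copy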
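Note on the AttentionMixin addition above: mixing AttentionMixin into LongCatImageTransformer2DModel gives it the shared attention-processor plumbing. A rough sketch, assuming the mixin exposes the attn_processors property and set_attn_processor() as on other diffusers models (import path and checkpoint elided):

from diffusers import LongCatImageTransformer2DModel

model = LongCatImageTransformer2DModel.from_pretrained("...")  # checkpoint path elided
procs = model.attn_processors       # {module_name: attention processor}
model.set_attn_processor(procs)     # swap or restore processors model-wide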
diff --git a/src/diffusers/schedulers/scheduling_ddim_parallel.py b/src/diffusers/schedulers/scheduling_ddim_parallel.py
index d5660471b9..76f0636fbf 100644
--- a/src/diffusers/schedulers/scheduling_ddim_parallel.py
+++ b/src/diffusers/schedulers/scheduling_ddim_parallel.py
@@ -101,7 +101,7 @@ def betas_for_alpha_bar(
 
 
 # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
-def rescale_zero_terminal_snr(betas):
+def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
     """
     Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
 
@@ -266,7 +266,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
         """
         return sample
 
-    def _get_variance(self, timestep, prev_timestep=None):
+    def _get_variance(self, timestep: int, prev_timestep: Optional[int] = None) -> torch.Tensor:
         if prev_timestep is None:
             prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
 
@@ -279,7 +279,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
 
         return variance
 
-    def _batch_get_variance(self, t, prev_t):
+    def _batch_get_variance(self, t: torch.Tensor, prev_t: torch.Tensor) -> torch.Tensor:
         alpha_prod_t = self.alphas_cumprod[t]
         alpha_prod_t_prev = self.alphas_cumprod[torch.clip(prev_t, min=0)]
         alpha_prod_t_prev[prev_t < 0] = torch.tensor(1.0)
@@ -335,7 +335,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
         return sample
 
     # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.set_timesteps
-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None) -> None:
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
 
@@ -392,7 +392,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
         sample: torch.Tensor,
         eta: float = 0.0,
         use_clipped_model_output: bool = False,
-        generator=None,
+        generator: Optional[torch.Generator] = None,
         variance_noise: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ) -> Union[DDIMParallelSchedulerOutput, Tuple]:
@@ -406,11 +406,13 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
             sample (`torch.Tensor`):
                 current instance of sample being created by diffusion process.
             eta (`float`): weight of noise for added noise in diffusion step.
-            use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
-                predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
-                `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
-                coincide with the one provided as input and `use_clipped_model_output` will have not effect.
-            generator: random number generator.
+            use_clipped_model_output (`bool`, defaults to `False`):
+                If `True`, compute "corrected" `model_output` from the clipped predicted original sample. This
+                correction is necessary because the predicted original sample is clipped to [-1, 1] when
+                `self.config.clip_sample` is `True`. If no clipping occurred, the "corrected" `model_output` matches
+                the input and `use_clipped_model_output` has no effect.
+            generator (`torch.Generator`, *optional*):
+                Random number generator.
             variance_noise (`torch.Tensor`): instead of generating noise for the variance using `generator`, we
                 can directly provide the noise for the variance itself. This is useful for methods such as
                 CycleDiffusion. (https://huggingface.co/papers/2210.05559)
@@ -496,7 +498,10 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
 
             if variance_noise is None:
                 variance_noise = randn_tensor(
-                    model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
+                    model_output.shape,
+                    generator=generator,
+                    device=model_output.device,
+                    dtype=model_output.dtype,
                 )
             variance = std_dev_t * variance_noise
 
@@ -513,7 +518,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
     def batch_step_no_noise(
         self,
         model_output: torch.Tensor,
-        timesteps: List[int],
+        timesteps: torch.Tensor,
         sample: torch.Tensor,
         eta: float = 0.0,
         use_clipped_model_output: bool = False,
@@ -528,7 +533,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
 
         Args:
             model_output (`torch.Tensor`): direct output from learned diffusion model.
-            timesteps (`List[int]`):
-                current discrete timesteps in the diffusion chain. This is now a list of integers.
+            timesteps (`torch.Tensor`):
+                current discrete timesteps in the diffusion chain, given as a tensor of integers.
             sample (`torch.Tensor`):
                 current instance of sample being created by diffusion process.
@@ -696,5 +701,5 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
         velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
         return velocity
 
-    def __len__(self):
+    def __len__(self) -> int:
         return self.config.num_train_timesteps
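Note on the scheduling_ddim_parallel typing above: the `timesteps: torch.Tensor` annotation on batch_step_no_noise reflects how the parallel (ParaDiGMS-style) sampling loop calls it, with a whole batch of discrete timesteps at once. A rough, untested usage sketch (shapes are illustrative):

import torch
from diffusers import DDIMParallelScheduler

scheduler = DDIMParallelScheduler()
scheduler.set_timesteps(50)

t = scheduler.timesteps[:4]               # a torch.Tensor of timesteps, not a List[int]
sample = torch.randn(4, 3, 32, 32)        # one latent per timestep in the batch
model_output = torch.randn(4, 3, 32, 32)

# One call advances every (sample, timestep) pair; eta/noise are handled separately.
prev_samples = scheduler.batch_step_no_noise(model_output, t, sample)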
diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py
index a3264f54f5..e02b7ea0c0 100644
--- a/src/diffusers/schedulers/scheduling_ddpm_flax.py
+++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py
@@ -22,6 +22,7 @@ import jax
 import jax.numpy as jnp
 
 from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import logging
 from .scheduling_utils_flax import (
     CommonSchedulerState,
     FlaxKarrasDiffusionSchedulers,
@@ -32,6 +33,9 @@ from .scheduling_utils_flax import (
 )
 
 
+logger = logging.get_logger(__name__)
+
+
 @flax.struct.dataclass
 class DDPMSchedulerState:
     common: CommonSchedulerState
@@ -42,7 +46,12 @@ class DDPMSchedulerState:
     num_inference_steps: Optional[int] = None
 
     @classmethod
-    def create(cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray):
+    def create(
+        cls,
+        common: CommonSchedulerState,
+        init_noise_sigma: jnp.ndarray,
+        timesteps: jnp.ndarray,
+    ):
         return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps)
 
 
@@ -105,6 +114,10 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
         prediction_type: str = "epsilon",
         dtype: jnp.dtype = jnp.float32,
     ):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
         self.dtype = dtype
 
     def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDPMSchedulerState:
@@ -123,7 +136,10 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
         )
 
     def scale_model_input(
-        self, state: DDPMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
+        self,
+        state: DDPMSchedulerState,
+        sample: jnp.ndarray,
+        timestep: Optional[int] = None,
     ) -> jnp.ndarray:
         """
         Args:
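Note on the scheduling_ddpm_flax change above: the deprecation warning is emitted from __init__, so every construction of the Flax scheduler now logs it once. A quick sketch of what callers will see (assuming jax and flax are installed):

from diffusers import FlaxDDPMScheduler

scheduler = FlaxDDPMScheduler()
# logs: "Flax classes are deprecated and will be removed in Diffusers v1.0.0. ..."
state = scheduler.create_state()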
+        )
         self.dtype = dtype
 
     def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDPMSchedulerState:
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py b/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py
index 8a693e9c2d..df1dd2d987 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py
@@ -248,6 +248,9 @@ class KandinskyV22InpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCas
     def test_float16_inference(self):
         super().test_float16_inference(expected_max_diff=5e-1)
 
+    def test_save_load_dduf(self):
+        super().test_save_load_dduf(atol=1e-3, rtol=1e-3)
+
     @is_flaky()
     def test_model_cpu_offload_forward_pass(self):
         super().test_inference_batch_single_identical(expected_max_diff=8e-4)
diff --git a/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py b/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py
index 4aafa082e9..a63ff9eaa8 100644
--- a/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py
+++ b/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py
@@ -192,6 +192,9 @@ class Kandinsky3Img2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
     def test_inference_batch_single_identical(self):
         super().test_inference_batch_single_identical(expected_max_diff=1e-2)
 
+    def test_save_load_dduf(self):
+        super().test_save_load_dduf(atol=1e-3, rtol=1e-3)
+
 
 @slow
 @require_torch_accelerator
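Note on the test overrides above: both Kandinsky suites relax the DDUF save/load comparison in the same way. The pattern, sketched generically (hypothetical test class name; that PipelineTesterMixin.test_save_load_dduf accepts atol/rtol is implied by the overrides themselves):

import unittest

from ..test_pipelines_common import PipelineTesterMixin  # layout used inside tests/pipelines/*


class SomePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    def test_save_load_dduf(self):
        # Loosen tolerances where DDUF round-trips are numerically noisier.
        super().test_save_load_dduf(atol=1e-3, rtol=1e-3)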