From cbc2ec8f44449cbc888256499d71bb6d7196aaa2 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 28 Aug 2024 14:48:12 +0530 Subject: [PATCH 01/12] AnimateDiff prompt travel (#9231) * update * implement prompt interpolation * make style * resnet memory optimizations * more memory optimizations; todo: refactor * update * update animatediff controlnet with latest changes * refactor chunked inference changes * remove print statements * undo memory optimization changes * update docstrings * fix tests * fix pia tests * apply suggestions from review * add tests * update comment --- src/diffusers/models/attention.py | 25 ++- src/diffusers/models/controlnet_sparsectrl.py | 1 - .../models/unets/unet_motion_model.py | 3 +- .../animatediff/pipeline_animatediff.py | 66 +++++-- .../pipeline_animatediff_controlnet.py | 64 +++++-- .../animatediff/pipeline_animatediff_sdxl.py | 2 + .../pipeline_animatediff_sparsectrl.py | 2 + .../pipeline_animatediff_video2video.py | 180 +++++++++++------- src/diffusers/pipelines/free_noise_utils.py | 178 ++++++++++++++++- .../pag/pipeline_pag_sd_animatediff.py | 6 +- src/diffusers/pipelines/pia/pipeline_pia.py | 2 + tests/models/unets/test_models_unet_motion.py | 2 +- .../pipelines/animatediff/test_animatediff.py | 23 +++ .../test_animatediff_controlnet.py | 21 ++ .../test_animatediff_video2video.py | 25 +++ 15 files changed, 475 insertions(+), 125 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index e6858d842c..7766442f71 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -972,15 +972,32 @@ class FreeNoiseTransformerBlock(nn.Module): return frame_indices def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]: - if weighting_scheme == "pyramid": + if weighting_scheme == "flat": + weights = [1.0] * num_frames + + elif weighting_scheme == "pyramid": if num_frames % 2 == 0: # num_frames = 4 => [1, 2, 2, 1] - weights = list(range(1, num_frames // 2 + 1)) + mid = num_frames // 2 + weights = list(range(1, mid + 1)) weights = weights + weights[::-1] else: # num_frames = 5 => [1, 2, 3, 2, 1] - weights = list(range(1, num_frames // 2 + 1)) - weights = weights + [num_frames // 2 + 1] + weights[::-1] + mid = (num_frames + 1) // 2 + weights = list(range(1, mid)) + weights = weights + [mid] + weights[::-1] + + elif weighting_scheme == "delayed_reverse_sawtooth": + if num_frames % 2 == 0: + # num_frames = 4 => [0.01, 2, 2, 1] + mid = num_frames // 2 + weights = [0.01] * (mid - 1) + [mid] + weights = weights + list(range(mid, 0, -1)) + else: + # num_frames = 5 => [0.01, 0.01, 3, 2, 1] + mid = (num_frames + 1) // 2 + weights = [0.01] * mid + weights = weights + list(range(mid, 0, -1)) else: raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}") diff --git a/src/diffusers/models/controlnet_sparsectrl.py b/src/diffusers/models/controlnet_sparsectrl.py index e91551c709..fa37e1f9e3 100644 --- a/src/diffusers/models/controlnet_sparsectrl.py +++ b/src/diffusers/models/controlnet_sparsectrl.py @@ -691,7 +691,6 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): emb = self.time_embedding(t_emb, timestep_cond) emb = emb.repeat_interleave(sample_num_frames, dim=0) - encoder_hidden_states = encoder_hidden_states.repeat_interleave(sample_num_frames, dim=0) # 2. pre-process batch_size, channels, num_frames, height, width = sample.shape diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 73c9c70c4a..89cdb76741 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -116,7 +116,7 @@ class AnimateDiffTransformer3D(nn.Module): self.in_channels = in_channels - self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) self.proj_in = nn.Linear(in_channels, inner_dim) # 3. Define transformers blocks @@ -2178,7 +2178,6 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft emb = emb if aug_emb is None else emb + aug_emb emb = emb.repeat_interleave(repeats=num_frames, dim=0) - encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0) if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": if "image_embeds" not in added_cond_kwargs: diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index a1f0374e31..cb6f50f43c 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -432,7 +432,6 @@ class AnimateDiffPipeline( extra_step_kwargs["generator"] = generator return extra_step_kwargs - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs def check_inputs( self, prompt, @@ -470,8 +469,8 @@ class AnimateDiffPipeline( raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt is not None and not isinstance(prompt, (str, list, dict)): + raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)=}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( @@ -557,11 +556,15 @@ class AnimateDiffPipeline( def num_timesteps(self): return self._num_timesteps + @property + def interrupt(self): + return self._interrupt + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - prompt: Union[str, List[str]] = None, + prompt: Optional[Union[str, List[str]]] = None, num_frames: Optional[int] = 16, height: Optional[int] = None, width: Optional[int] = None, @@ -701,9 +704,10 @@ class AnimateDiffPipeline( self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): + if prompt is not None and isinstance(prompt, (str, dict)): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) @@ -716,22 +720,39 @@ class AnimateDiffPipeline( text_encoder_lora_scale = ( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) - prompt_embeds, negative_prompt_embeds = self.encode_prompt( - prompt, - device, - num_videos_per_prompt, - self.do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - lora_scale=text_encoder_lora_scale, - clip_skip=self.clip_skip, - ) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - if self.do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + if self.free_noise_enabled: + prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise( + prompt=prompt, + num_frames=num_frames, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + else: + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( @@ -783,6 +804,9 @@ class AnimateDiffPipeline( # 8. Denoising loop with self.progress_bar(total=self._num_timesteps) as progress_bar: for i, t in enumerate(timesteps): + if self.interrupt: + continue + # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py index 6e8b0e3e5f..5357d6d5b8 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py @@ -505,8 +505,8 @@ class AnimateDiffControlNetPipeline( raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt is not None and not isinstance(prompt, (str, list, dict)): + raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( @@ -699,6 +699,10 @@ class AnimateDiffControlNetPipeline( def num_timesteps(self): return self._num_timesteps + @property + def interrupt(self): + return self._interrupt + @torch.no_grad() def __call__( self, @@ -858,9 +862,10 @@ class AnimateDiffControlNetPipeline( self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): + if prompt is not None and isinstance(prompt, (str, dict)): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) @@ -883,22 +888,39 @@ class AnimateDiffControlNetPipeline( text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) - prompt_embeds, negative_prompt_embeds = self.encode_prompt( - prompt, - device, - num_videos_per_prompt, - self.do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - lora_scale=text_encoder_lora_scale, - clip_skip=self.clip_skip, - ) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - if self.do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + if self.free_noise_enabled: + prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise( + prompt=prompt, + num_frames=num_frames, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + else: + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( @@ -990,6 +1012,9 @@ class AnimateDiffControlNetPipeline( # 8. Denoising loop with self.progress_bar(total=self._num_timesteps) as progress_bar: for i, t in enumerate(timesteps): + if self.interrupt: + continue + # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -1002,7 +1027,6 @@ class AnimateDiffControlNetPipeline( else: control_model_input = latent_model_input controlnet_prompt_embeds = prompt_embeds - controlnet_prompt_embeds = controlnet_prompt_embeds.repeat_interleave(num_frames, dim=0) if isinstance(controlnet_keep[i], list): cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py index a466823475..e531c91c16 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py @@ -1143,6 +1143,8 @@ class AnimateDiffSDXLPipeline( add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) + prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_videos_per_prompt, 1) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py index e9e0d518c8..8b037cdc34 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py @@ -878,6 +878,8 @@ class AnimateDiffSparseControlNetPipeline( if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) + # 4. Prepare IP-Adapter embeddings if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index 70a4201ca0..1ebe2b9b60 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -246,7 +246,6 @@ class AnimateDiffVideoToVideoPipeline( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt def encode_prompt( self, prompt, @@ -299,7 +298,7 @@ class AnimateDiffVideoToVideoPipeline( else: scale_lora_layers(self.text_encoder, lora_scale) - if prompt is not None and isinstance(prompt, str): + if prompt is not None and isinstance(prompt, (str, dict)): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) @@ -582,8 +581,8 @@ class AnimateDiffVideoToVideoPipeline( raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt is not None and not isinstance(prompt, (str, list, dict)): + raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( @@ -628,23 +627,20 @@ class AnimateDiffVideoToVideoPipeline( def prepare_latents( self, - video, - height, - width, - num_channels_latents, - batch_size, - timestep, - dtype, - device, - generator, - latents=None, + video: Optional[torch.Tensor] = None, + height: int = 64, + width: int = 64, + num_channels_latents: int = 4, + batch_size: int = 1, + timestep: Optional[int] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, decode_chunk_size: int = 16, - ): - if latents is None: - num_frames = video.shape[1] - else: - num_frames = latents.shape[2] - + add_noise: bool = False, + ) -> torch.Tensor: + num_frames = video.shape[1] if latents is None else latents.shape[2] shape = ( batch_size, num_channels_latents, @@ -708,8 +704,13 @@ class AnimateDiffVideoToVideoPipeline( if shape != latents.shape: # [B, C, F, H, W] raise ValueError(f"`latents` expected to have {shape=}, but found {latents.shape=}") + latents = latents.to(device, dtype=dtype) + if add_noise: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.add_noise(latents, noise, timestep) + return latents @property @@ -735,6 +736,10 @@ class AnimateDiffVideoToVideoPipeline( def num_timesteps(self): return self._num_timesteps + @property + def interrupt(self): + return self._interrupt + @torch.no_grad() def __call__( self, @@ -743,6 +748,7 @@ class AnimateDiffVideoToVideoPipeline( height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, + enforce_inference_steps: bool = False, timesteps: Optional[List[int]] = None, sigmas: Optional[List[float]] = None, guidance_scale: float = 7.5, @@ -874,9 +880,10 @@ class AnimateDiffVideoToVideoPipeline( self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): + if prompt is not None and isinstance(prompt, (str, dict)): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) @@ -884,29 +891,85 @@ class AnimateDiffVideoToVideoPipeline( batch_size = prompt_embeds.shape[0] device = self._execution_device + dtype = self.dtype - # 3. Encode input prompt + # 3. Prepare timesteps + if not enforce_inference_steps: + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + else: + denoising_inference_steps = int(num_inference_steps / strength) + timesteps, denoising_inference_steps = retrieve_timesteps( + self.scheduler, denoising_inference_steps, device, timesteps, sigmas + ) + timesteps = timesteps[-num_inference_steps:] + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + + # 4. Prepare latent variables + if latents is None: + video = self.video_processor.preprocess_video(video, height=height, width=width) + # Move the number of frames before the number of channels. + video = video.permute(0, 2, 1, 3, 4) + video = video.to(device=device, dtype=dtype) + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + video=video, + height=height, + width=width, + num_channels_latents=num_channels_latents, + batch_size=batch_size * num_videos_per_prompt, + timestep=latent_timestep, + dtype=dtype, + device=device, + generator=generator, + latents=latents, + decode_chunk_size=decode_chunk_size, + add_noise=enforce_inference_steps, + ) + + # 5. Encode input prompt text_encoder_lora_scale = ( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) - prompt_embeds, negative_prompt_embeds = self.encode_prompt( - prompt, - device, - num_videos_per_prompt, - self.do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - lora_scale=text_encoder_lora_scale, - clip_skip=self.clip_skip, - ) + num_frames = latents.shape[2] + if self.free_noise_enabled: + prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise( + prompt=prompt, + num_frames=num_frames, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + else: + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - if self.do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) + + # 6. Prepare IP-Adapter embeddings if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, @@ -916,38 +979,10 @@ class AnimateDiffVideoToVideoPipeline( self.do_classifier_free_guidance, ) - # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas - ) - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device) - latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) - - # 5. Prepare latent variables - if latents is None: - video = self.video_processor.preprocess_video(video, height=height, width=width) - # Move the number of frames before the number of channels. - video = video.permute(0, 2, 1, 3, 4) - video = video.to(device=device, dtype=prompt_embeds.dtype) - num_channels_latents = self.unet.config.in_channels - latents = self.prepare_latents( - video=video, - height=height, - width=width, - num_channels_latents=num_channels_latents, - batch_size=batch_size * num_videos_per_prompt, - timestep=latent_timestep, - dtype=prompt_embeds.dtype, - device=device, - generator=generator, - latents=latents, - decode_chunk_size=decode_chunk_size, - ) - - # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 7. Add image embeds for IP-Adapter + # 8. Add image embeds for IP-Adapter added_cond_kwargs = ( {"image_embeds": image_embeds} if ip_adapter_image is not None or ip_adapter_image_embeds is not None @@ -967,9 +1002,12 @@ class AnimateDiffVideoToVideoPipeline( self._num_timesteps = len(timesteps) num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - # 8. Denoising loop + # 9. Denoising loop with self.progress_bar(total=self._num_timesteps) as progress_bar: for i, t in enumerate(timesteps): + if self.interrupt: + continue + # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -1005,14 +1043,14 @@ class AnimateDiffVideoToVideoPipeline( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - # 9. Post-processing + # 10. Post-processing if output_type == "latent": video = latents else: video_tensor = self.decode_latents(latents, decode_chunk_size) video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) - # 10. Offload all models + # 11. Offload all models self.maybe_free_model_hooks() if not return_dict: diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index 1ee3b6d0a9..f2763f1c33 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union +from typing import Callable, Dict, Optional, Union import torch @@ -22,6 +22,7 @@ from ..models.unets.unet_motion_model import ( DownBlockMotion, UpBlockMotion, ) +from ..pipelines.pipeline_utils import DiffusionPipeline from ..utils import logging from ..utils.torch_utils import randn_tensor @@ -98,6 +99,142 @@ class AnimateDiffFreeNoiseMixin: free_noise_transfomer_block.state_dict(), strict=True ) + def _check_inputs_free_noise( + self, + prompt, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + num_frames, + ) -> None: + if not isinstance(prompt, (str, dict)): + raise ValueError(f"Expected `prompt` to have type `str` or `dict` but found {type(prompt)=}") + + if negative_prompt is not None: + if not isinstance(negative_prompt, (str, dict)): + raise ValueError( + f"Expected `negative_prompt` to have type `str` or `dict` but found {type(negative_prompt)=}" + ) + + if prompt_embeds is not None or negative_prompt_embeds is not None: + raise ValueError("`prompt_embeds` and `negative_prompt_embeds` is not supported in FreeNoise yet.") + + frame_indices = [isinstance(x, int) for x in prompt.keys()] + frame_prompts = [isinstance(x, str) for x in prompt.values()] + min_frame = min(list(prompt.keys())) + max_frame = max(list(prompt.keys())) + + if not all(frame_indices): + raise ValueError("Expected integer keys in `prompt` dict for FreeNoise.") + if not all(frame_prompts): + raise ValueError("Expected str values in `prompt` dict for FreeNoise.") + if min_frame != 0: + raise ValueError("The minimum frame index in `prompt` dict must be 0 as a starting prompt is necessary.") + if max_frame >= num_frames: + raise ValueError( + f"The maximum frame index in `prompt` dict must be lesser than {num_frames=} and follow 0-based indexing." + ) + + def _encode_prompt_free_noise( + self, + prompt: Union[str, Dict[int, str]], + num_frames: int, + device: torch.device, + num_videos_per_prompt: int, + do_classifier_free_guidance: bool, + negative_prompt: Optional[Union[str, Dict[int, str]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ) -> torch.Tensor: + if negative_prompt is None: + negative_prompt = "" + + # Ensure that we have a dictionary of prompts + if isinstance(prompt, str): + prompt = {0: prompt} + if isinstance(negative_prompt, str): + negative_prompt = {0: negative_prompt} + + self._check_inputs_free_noise(prompt, negative_prompt, prompt_embeds, negative_prompt_embeds, num_frames) + + # Sort the prompts based on frame indices + prompt = dict(sorted(prompt.items())) + negative_prompt = dict(sorted(negative_prompt.items())) + + # Ensure that we have a prompt for the last frame index + prompt[num_frames - 1] = prompt[list(prompt.keys())[-1]] + negative_prompt[num_frames - 1] = negative_prompt[list(negative_prompt.keys())[-1]] + + frame_indices = list(prompt.keys()) + frame_prompts = list(prompt.values()) + frame_negative_indices = list(negative_prompt.keys()) + frame_negative_prompts = list(negative_prompt.values()) + + # Generate and interpolate positive prompts + prompt_embeds, _ = self.encode_prompt( + prompt=frame_prompts, + device=device, + num_images_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=False, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + lora_scale=lora_scale, + clip_skip=clip_skip, + ) + + shape = (num_frames, *prompt_embeds.shape[1:]) + prompt_interpolation_embeds = prompt_embeds.new_zeros(shape) + + for i in range(len(frame_indices) - 1): + start_frame = frame_indices[i] + end_frame = frame_indices[i + 1] + start_tensor = prompt_embeds[i].unsqueeze(0) + end_tensor = prompt_embeds[i + 1].unsqueeze(0) + + prompt_interpolation_embeds[start_frame : end_frame + 1] = self._free_noise_prompt_interpolation_callback( + start_frame, end_frame, start_tensor, end_tensor + ) + + # Generate and interpolate negative prompts + negative_prompt_embeds = None + negative_prompt_interpolation_embeds = None + + if do_classifier_free_guidance: + _, negative_prompt_embeds = self.encode_prompt( + prompt=[""] * len(frame_negative_prompts), + device=device, + num_images_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=True, + negative_prompt=frame_negative_prompts, + prompt_embeds=None, + negative_prompt_embeds=None, + lora_scale=lora_scale, + clip_skip=clip_skip, + ) + + negative_prompt_interpolation_embeds = negative_prompt_embeds.new_zeros(shape) + + for i in range(len(frame_negative_indices) - 1): + start_frame = frame_negative_indices[i] + end_frame = frame_negative_indices[i + 1] + start_tensor = negative_prompt_embeds[i].unsqueeze(0) + end_tensor = negative_prompt_embeds[i + 1].unsqueeze(0) + + negative_prompt_interpolation_embeds[ + start_frame : end_frame + 1 + ] = self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor) + + prompt_embeds = prompt_interpolation_embeds + negative_prompt_embeds = negative_prompt_interpolation_embeds + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds, negative_prompt_embeds + def _prepare_latents_free_noise( self, batch_size: int, @@ -172,12 +309,29 @@ class AnimateDiffFreeNoiseMixin: latents = latents[:, :, :num_frames] return latents + def _lerp( + self, start_index: int, end_index: int, start_tensor: torch.Tensor, end_tensor: torch.Tensor + ) -> torch.Tensor: + num_indices = end_index - start_index + 1 + interpolated_tensors = [] + + for i in range(num_indices): + alpha = i / (num_indices - 1) + interpolated_tensor = (1 - alpha) * start_tensor + alpha * end_tensor + interpolated_tensors.append(interpolated_tensor) + + interpolated_tensors = torch.cat(interpolated_tensors) + return interpolated_tensors + def enable_free_noise( self, context_length: Optional[int] = 16, context_stride: int = 4, weighting_scheme: str = "pyramid", noise_type: str = "shuffle_context", + prompt_interpolation_callback: Optional[ + Callable[[DiffusionPipeline, int, int, torch.Tensor, torch.Tensor], torch.Tensor] + ] = None, ) -> None: r""" Enable long video generation using FreeNoise. @@ -195,13 +349,27 @@ class AnimateDiffFreeNoiseMixin: weighting_scheme (`str`, defaults to `pyramid`): Weighting scheme for averaging latents after accumulation in FreeNoise blocks. The following weighting schemes are supported currently: + - "flat" + Performs weighting averaging with a flat weight pattern: [1, 1, 1, 1, 1]. - "pyramid" - Peforms weighted averaging with a pyramid like weight pattern: [1, 2, 3, 2, 1]. + Performs weighted averaging with a pyramid like weight pattern: [1, 2, 3, 2, 1]. + - "delayed_reverse_sawtooth" + Performs weighted averaging with low weights for earlier frames and high-to-low weights for + later frames: [0.01, 0.01, 3, 2, 1]. noise_type (`str`, defaults to "shuffle_context"): - TODO + Must be one of ["shuffle_context", "repeat_context", "random"]. + - "shuffle_context" + Shuffles a fixed batch of `context_length` latents to create a final latent of size + `num_frames`. This is usually the best setting for most generation scenarious. However, there + might be visible repetition noticeable in the kinds of motion/animation generated. + - "repeated_context" + Repeats a fixed batch of `context_length` latents to create a final latent of size + `num_frames`. + - "random" + The final latents are random without any repetition. """ - allowed_weighting_scheme = ["pyramid"] + allowed_weighting_scheme = ["flat", "pyramid", "delayed_reverse_sawtooth"] allowed_noise_type = ["shuffle_context", "repeat_context", "random"] if context_length > self.motion_adapter.config.motion_max_seq_length: @@ -219,6 +387,7 @@ class AnimateDiffFreeNoiseMixin: self._free_noise_context_stride = context_stride self._free_noise_weighting_scheme = weighting_scheme self._free_noise_noise_type = noise_type + self._free_noise_prompt_interpolation_callback = prompt_interpolation_callback or self._lerp if hasattr(self.unet.mid_block, "motion_modules"): blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks] @@ -229,6 +398,7 @@ class AnimateDiffFreeNoiseMixin: self._enable_free_noise_in_block(block) def disable_free_noise(self) -> None: + r"""Disable the FreeNoise sampling mechanism.""" self._free_noise_context_length = None if hasattr(self.unet.mid_block, "motion_modules"): diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py index 73c53b3658..1e81fa3a15 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py @@ -734,6 +734,8 @@ class AnimateDiffPAGPipeline( elif self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, @@ -805,7 +807,9 @@ class AnimateDiffPAGPipeline( with self.progress_bar(total=self._num_timesteps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + latent_model_input = torch.cat( + [latents] * (prompt_embeds.shape[0] // num_frames // latents.shape[0]) + ) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index f0e8cfb03d..b7dfcd39ed 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -824,6 +824,8 @@ class PIAPipeline( if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, diff --git a/tests/models/unets/test_models_unet_motion.py b/tests/models/unets/test_models_unet_motion.py index 53833d6a07..ee05f0d938 100644 --- a/tests/models/unets/test_models_unet_motion.py +++ b/tests/models/unets/test_models_unet_motion.py @@ -51,7 +51,7 @@ class UNetMotionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase) noise = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device) time_step = torch.tensor([10]).to(torch_device) - encoder_hidden_states = floats_tensor((batch_size, 4, 16)).to(torch_device) + encoder_hidden_states = floats_tensor((batch_size * num_frames, 4, 16)).to(torch_device) return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py index 1354ac9ff1..618a5cff99 100644 --- a/tests/pipelines/animatediff/test_animatediff.py +++ b/tests/pipelines/animatediff/test_animatediff.py @@ -460,6 +460,29 @@ class AnimateDiffPipelineFastTests( "Disabling of FreeNoise should lead to results similar to the default pipeline results", ) + def test_free_noise_multi_prompt(self): + components = self.get_dummy_components() + pipe: AnimateDiffPipeline = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + context_length = 8 + context_stride = 4 + pipe.enable_free_noise(context_length, context_stride) + + # Make sure that pipeline works when prompt indices are within num_frames bounds + inputs = self.get_dummy_inputs(torch_device) + inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf"} + inputs["num_frames"] = 16 + pipe(**inputs).frames[0] + + with self.assertRaises(ValueError): + # Ensure that prompt indices are within bounds + inputs = self.get_dummy_inputs(torch_device) + inputs["num_frames"] = 16 + inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf", 42: "Error on a leaf"} + pipe(**inputs).frames[0] + @unittest.skipIf( torch_device != "cuda" or not is_xformers_available(), reason="XFormers attention is only available with CUDA and `xformers` installed", diff --git a/tests/pipelines/animatediff/test_animatediff_controlnet.py b/tests/pipelines/animatediff/test_animatediff_controlnet.py index 3035fc1e3c..c0ad223c6c 100644 --- a/tests/pipelines/animatediff/test_animatediff_controlnet.py +++ b/tests/pipelines/animatediff/test_animatediff_controlnet.py @@ -476,6 +476,27 @@ class AnimateDiffControlNetPipelineFastTests( "Disabling of FreeNoise should lead to results similar to the default pipeline results", ) + def test_free_noise_multi_prompt(self): + components = self.get_dummy_components() + pipe: AnimateDiffControlNetPipeline = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + context_length = 8 + context_stride = 4 + pipe.enable_free_noise(context_length, context_stride) + + # Make sure that pipeline works when prompt indices are within num_frames bounds + inputs = self.get_dummy_inputs(torch_device, num_frames=16) + inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf"} + pipe(**inputs).frames[0] + + with self.assertRaises(ValueError): + # Ensure that prompt indices are within bounds + inputs = self.get_dummy_inputs(torch_device, num_frames=16) + inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf", 42: "Error on a leaf"} + pipe(**inputs).frames[0] + def test_vae_slicing(self, video_count=2): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() diff --git a/tests/pipelines/animatediff/test_animatediff_video2video.py b/tests/pipelines/animatediff/test_animatediff_video2video.py index cd33bf0891..c49790e0f2 100644 --- a/tests/pipelines/animatediff/test_animatediff_video2video.py +++ b/tests/pipelines/animatediff/test_animatediff_video2video.py @@ -491,3 +491,28 @@ class AnimateDiffVideoToVideoPipelineFastTests( 1e-4, "Disabling of FreeNoise should lead to results similar to the default pipeline results", ) + + def test_free_noise_multi_prompt(self): + components = self.get_dummy_components() + pipe: AnimateDiffVideoToVideoPipeline = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + context_length = 8 + context_stride = 4 + pipe.enable_free_noise(context_length, context_stride) + + # Make sure that pipeline works when prompt indices are within num_frames bounds + inputs = self.get_dummy_inputs(torch_device, num_frames=16) + inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf"} + inputs["num_inference_steps"] = 2 + inputs["strength"] = 0.5 + pipe(**inputs).frames[0] + + with self.assertRaises(ValueError): + # Ensure that prompt indices are within bounds + inputs = self.get_dummy_inputs(torch_device, num_frames=16) + inputs["num_inference_steps"] = 2 + inputs["strength"] = 0.5 + inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf", 42: "Error on a leaf"} + pipe(**inputs).frames[0] From 089cf798eb199ddc0d396c7bbb0172fecf2845e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?apolin=C3=A1rio?= Date: Wed, 28 Aug 2024 12:39:45 -0500 Subject: [PATCH 02/12] Change default for `guidance_scale`in FLUX (#9305) To match the original code, 7.0 is too high --- src/diffusers/pipelines/flux/pipeline_flux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 3b6c7982ff..bb214885da 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -536,7 +536,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): width: Optional[int] = None, num_inference_steps: int = 28, timesteps: List[int] = None, - guidance_scale: float = 7.0, + guidance_scale: float = 3.5, num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, From 2a3fbc2cc269aa0c0d5cfdfaa3564d769d92b882 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 29 Aug 2024 07:41:46 +0530 Subject: [PATCH 03/12] [LoRA] support kohya and xlabs loras for flux. (#9295) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * support kohya lora in flux. * format * support xlabs * diffusion_model prefix. * Apply suggestions from code review Co-authored-by: apolinário * empty commit. Co-authored-by: Leommm-byte --------- Co-authored-by: apolinário Co-authored-by: Leommm-byte --- .../loaders/lora_conversion_utils.py | 293 ++++++++++++++++++ src/diffusers/loaders/lora_pipeline.py | 21 +- 2 files changed, 313 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index df1d351ca1..4b54269479 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -14,6 +14,8 @@ import re +import torch + from ..utils import is_peft_version, logging @@ -326,3 +328,294 @@ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha): prefix = "text_encoder_2." new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha" return {new_name: alpha} + + +# The utilities under `_convert_kohya_flux_lora_to_diffusers()` +# are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py +# All credits go to `kohya-ss`. +def _convert_kohya_flux_lora_to_diffusers(state_dict): + def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key): + if sds_key + ".lora_down.weight" not in sds_sd: + return + down_weight = sds_sd.pop(sds_key + ".lora_down.weight") + + # scale weight by alpha and dim + rank = down_weight.shape[0] + alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar + scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here + + # calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2 + scale_down = scale + scale_up = 1.0 + while scale_down * 2 < scale_up: + scale_down *= 2 + scale_up /= 2 + + ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down + ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up + + def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None): + if sds_key + ".lora_down.weight" not in sds_sd: + return + down_weight = sds_sd.pop(sds_key + ".lora_down.weight") + up_weight = sds_sd.pop(sds_key + ".lora_up.weight") + sd_lora_rank = down_weight.shape[0] + + # scale weight by alpha and dim + alpha = sds_sd.pop(sds_key + ".alpha") + scale = alpha / sd_lora_rank + + # calculate scale_down and scale_up + scale_down = scale + scale_up = 1.0 + while scale_down * 2 < scale_up: + scale_down *= 2 + scale_up /= 2 + + down_weight = down_weight * scale_down + up_weight = up_weight * scale_up + + # calculate dims if not provided + num_splits = len(ait_keys) + if dims is None: + dims = [up_weight.shape[0] // num_splits] * num_splits + else: + assert sum(dims) == up_weight.shape[0] + + # check upweight is sparse or not + is_sparse = False + if sd_lora_rank % num_splits == 0: + ait_rank = sd_lora_rank // num_splits + is_sparse = True + i = 0 + for j in range(len(dims)): + for k in range(len(dims)): + if j == k: + continue + is_sparse = is_sparse and torch.all( + up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0 + ) + i += dims[j] + if is_sparse: + logger.info(f"weight is sparse: {sds_key}") + + # make ai-toolkit weight + ait_down_keys = [k + ".lora_A.weight" for k in ait_keys] + ait_up_keys = [k + ".lora_B.weight" for k in ait_keys] + if not is_sparse: + # down_weight is copied to each split + ait_sd.update({k: down_weight for k in ait_down_keys}) + + # up_weight is split to each split + ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416 + else: + # down_weight is chunked to each split + ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) # noqa: C416 + + # up_weight is sparse: only non-zero values are copied to each split + i = 0 + for j in range(len(dims)): + ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous() + i += dims[j] + + def _convert_sd_scripts_to_ai_toolkit(sds_sd): + ait_sd = {} + for i in range(19): + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_img_attn_proj", + f"transformer.transformer_blocks.{i}.attn.to_out.0", + ) + _convert_to_ai_toolkit_cat( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_img_attn_qkv", + [ + f"transformer.transformer_blocks.{i}.attn.to_q", + f"transformer.transformer_blocks.{i}.attn.to_k", + f"transformer.transformer_blocks.{i}.attn.to_v", + ], + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_img_mlp_0", + f"transformer.transformer_blocks.{i}.ff.net.0.proj", + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_img_mlp_2", + f"transformer.transformer_blocks.{i}.ff.net.2", + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_img_mod_lin", + f"transformer.transformer_blocks.{i}.norm1.linear", + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_attn_proj", + f"transformer.transformer_blocks.{i}.attn.to_add_out", + ) + _convert_to_ai_toolkit_cat( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_attn_qkv", + [ + f"transformer.transformer_blocks.{i}.attn.add_q_proj", + f"transformer.transformer_blocks.{i}.attn.add_k_proj", + f"transformer.transformer_blocks.{i}.attn.add_v_proj", + ], + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_mlp_0", + f"transformer.transformer_blocks.{i}.ff_context.net.0.proj", + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_mlp_2", + f"transformer.transformer_blocks.{i}.ff_context.net.2", + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_mod_lin", + f"transformer.transformer_blocks.{i}.norm1_context.linear", + ) + + for i in range(38): + _convert_to_ai_toolkit_cat( + sds_sd, + ait_sd, + f"lora_unet_single_blocks_{i}_linear1", + [ + f"transformer.single_transformer_blocks.{i}.attn.to_q", + f"transformer.single_transformer_blocks.{i}.attn.to_k", + f"transformer.single_transformer_blocks.{i}.attn.to_v", + f"transformer.single_transformer_blocks.{i}.proj_mlp", + ], + dims=[3072, 3072, 3072, 12288], + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_single_blocks_{i}_linear2", + f"transformer.single_transformer_blocks.{i}.proj_out", + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_single_blocks_{i}_modulation_lin", + f"transformer.single_transformer_blocks.{i}.norm.linear", + ) + + if len(sds_sd) > 0: + logger.warning(f"Unsuppored keys for ai-toolkit: {sds_sd.keys()}") + + return ait_sd + + return _convert_sd_scripts_to_ai_toolkit(state_dict) + + +# Adapted from https://gist.github.com/Leommm-byte/6b331a1e9bd53271210b26543a7065d6 +# Some utilities were reused from +# https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py +def _convert_xlabs_flux_lora_to_diffusers(old_state_dict): + new_state_dict = {} + orig_keys = list(old_state_dict.keys()) + + def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None): + down_weight = sds_sd.pop(sds_key) + up_weight = sds_sd.pop(sds_key.replace(".down.weight", ".up.weight")) + + # calculate dims if not provided + num_splits = len(ait_keys) + if dims is None: + dims = [up_weight.shape[0] // num_splits] * num_splits + else: + assert sum(dims) == up_weight.shape[0] + + # make ai-toolkit weight + ait_down_keys = [k + ".lora_A.weight" for k in ait_keys] + ait_up_keys = [k + ".lora_B.weight" for k in ait_keys] + + # down_weight is copied to each split + ait_sd.update({k: down_weight for k in ait_down_keys}) + + # up_weight is split to each split + ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416 + + for old_key in orig_keys: + # Handle double_blocks + if old_key.startswith(("diffusion_model.double_blocks", "double_blocks")): + block_num = re.search(r"double_blocks\.(\d+)", old_key).group(1) + new_key = f"transformer.transformer_blocks.{block_num}" + + if "processor.proj_lora1" in old_key: + new_key += ".attn.to_out.0" + elif "processor.proj_lora2" in old_key: + new_key += ".attn.to_add_out" + elif "processor.qkv_lora1" in old_key and "up" not in old_key: + handle_qkv( + old_state_dict, + new_state_dict, + old_key, + [ + f"transformer.transformer_blocks.{block_num}.attn.add_q_proj", + f"transformer.transformer_blocks.{block_num}.attn.add_k_proj", + f"transformer.transformer_blocks.{block_num}.attn.add_v_proj", + ], + ) + # continue + elif "processor.qkv_lora2" in old_key and "up" not in old_key: + handle_qkv( + old_state_dict, + new_state_dict, + old_key, + [ + f"transformer.transformer_blocks.{block_num}.attn.to_q", + f"transformer.transformer_blocks.{block_num}.attn.to_k", + f"transformer.transformer_blocks.{block_num}.attn.to_v", + ], + ) + # continue + + if "down" in old_key: + new_key += ".lora_A.weight" + elif "up" in old_key: + new_key += ".lora_B.weight" + + # Handle single_blocks + elif old_key.startswith("diffusion_model.single_blocks", "single_blocks"): + block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1) + new_key = f"transformer.single_transformer_blocks.{block_num}" + + if "proj_lora1" in old_key or "proj_lora2" in old_key: + new_key += ".proj_out" + elif "qkv_lora1" in old_key or "qkv_lora2" in old_key: + new_key += ".norm.linear" + + if "down" in old_key: + new_key += ".lora_A.weight" + elif "up" in old_key: + new_key += ".lora_B.weight" + + else: + # Handle other potential key patterns here + new_key = old_key + + # Since we already handle qkv above. + if "qkv" not in old_key: + new_state_dict[new_key] = old_state_dict.pop(old_key) + + if len(old_state_dict) > 0: + raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.") + + return new_state_dict diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index cefe66bc8c..7d644d6841 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -31,7 +31,12 @@ from ..utils import ( scale_lora_layers, ) from .lora_base import LoraBaseMixin -from .lora_conversion_utils import _convert_non_diffusers_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers +from .lora_conversion_utils import ( + _convert_kohya_flux_lora_to_diffusers, + _convert_non_diffusers_lora_to_diffusers, + _convert_xlabs_flux_lora_to_diffusers, + _maybe_map_sgm_blocks_to_diffusers, +) if is_transformers_available(): @@ -1583,6 +1588,20 @@ class FluxLoraLoaderMixin(LoraBaseMixin): allow_pickle=allow_pickle, ) + # TODO (sayakpaul): to a follow-up to clean and try to unify the conditions. + + is_kohya = any(".lora_down.weight" in k for k in state_dict) + if is_kohya: + state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict) + # Kohya already takes care of scaling the LoRA parameters with alpha. + return (state_dict, None) if return_alphas else state_dict + + is_xlabs = any("processor" in k for k in state_dict) + if is_xlabs: + state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict) + # xlabs doesn't use `alpha`. + return (state_dict, None) if return_alphas else state_dict + # For state dicts like # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA keys = list(state_dict.keys()) From 40c13fe5b41872d99dde75bf32f5c31626cf0043 Mon Sep 17 00:00:00 2001 From: Anand Kumar <63339285+AnandK27@users.noreply.github.com> Date: Thu, 29 Aug 2024 01:53:36 -0700 Subject: [PATCH 04/12] [train_custom_diffusion.py] Fix the LR schedulers when `num_train_epochs` is passed in a distributed training env (#9308) * Update train_custom_diffusion.py to fix the LR schedulers for `num_train_epochs` * Fix saving text embeddings during safe serialization * Fixed formatting --- .../train_custom_diffusion.py | 28 +++++++++++++------ 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py index 8dddcd0ca7..e498ca98b1 100644 --- a/examples/custom_diffusion/train_custom_diffusion.py +++ b/examples/custom_diffusion/train_custom_diffusion.py @@ -314,11 +314,12 @@ def save_new_embed(text_encoder, modifier_token_id, accelerator, args, output_di for x, y in zip(modifier_token_id, args.modifier_token): learned_embeds_dict = {} learned_embeds_dict[y] = learned_embeds[x] - filename = f"{output_dir}/{y}.bin" if safe_serialization: + filename = f"{output_dir}/{y}.safetensors" safetensors.torch.save_file(learned_embeds_dict, filename, metadata={"format": "pt"}) else: + filename = f"{output_dir}/{y}.bin" torch.save(learned_embeds_dict, filename) @@ -1040,17 +1041,22 @@ def main(args): ) # Scheduler and math around the number of training steps. - overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation. + num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes if args.max_train_steps is None: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - overrode_max_train_steps = True + len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes) + num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps) + num_training_steps_for_scheduler = ( + args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes + ) + else: + num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes lr_scheduler = get_scheduler( args.lr_scheduler, optimizer=optimizer, - num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, - num_training_steps=args.max_train_steps * accelerator.num_processes, + num_warmup_steps=num_warmup_steps_for_scheduler, + num_training_steps=num_training_steps_for_scheduler, ) # Prepare everything with our `accelerator`. @@ -1065,8 +1071,14 @@ def main(args): # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if overrode_max_train_steps: + if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes: + logger.warning( + f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match " + f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. " + f"This inconsistency may result in the learning rate scheduler not functioning properly." + ) # Afterwards we recalculate our number of training epochs args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) From 4f495b06dcbbc3437a598a20718fe74c29308756 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Wed, 28 Aug 2024 23:31:47 -1000 Subject: [PATCH 05/12] rotary embedding refactor 2: update comments, fix dtype for use_real=False (#9312) fix notes and dtype --- src/diffusers/models/embeddings.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index dcb9528cb1..1f29622bdf 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -514,7 +514,7 @@ def get_1d_rotary_pos_embed( linear_factor=1.0, ntk_factor=1.0, repeat_interleave_real=True, - freqs_dtype=torch.float32, # torch.float32 (hunyuan, stable audio), torch.float64 (flux) + freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux) ): """ Precompute the frequency tensor for complex exponentials (cis) with given dimensions. @@ -551,15 +551,18 @@ def get_1d_rotary_pos_embed( t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S] freqs = torch.outer(t, freqs) # type: ignore # [S, D/2] if use_real and repeat_interleave_real: + # flux, hunyuan-dit, cogvideox freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D] freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D] return freqs_cos, freqs_sin elif use_real: + # stable audio freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D] freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D] return freqs_cos, freqs_sin else: - freqs_cis = torch.polar(torch.ones_like(freqs), freqs).float() # complex64 # [S, D/2] + # lumina + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] return freqs_cis @@ -590,11 +593,11 @@ def apply_rotary_emb( cos, sin = cos.to(x.device), sin.to(x.device) if use_real_unbind_dim == -1: - # Use for example in Lumina + # Used for flux, cogvideox, hunyuan-dit x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) elif use_real_unbind_dim == -2: - # Use for example in Stable Audio + # Used for Stable Audio x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] x_rotated = torch.cat([-x_imag, x_real], dim=-1) else: @@ -604,6 +607,7 @@ def apply_rotary_emb( return out else: + # used for lumina x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) freqs_cis = freqs_cis.unsqueeze(2) x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) From 61d96c3ae756e114d2c88089d6e5c11b18501fe8 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Thu, 29 Aug 2024 09:37:15 -1000 Subject: [PATCH 06/12] refactor rotary embedding 3: so it is not on cpu (#9307) change get_1d_rotary to accept pos as torch tensors --- src/diffusers/models/embeddings.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 1f29622bdf..5e9863ab0d 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -545,11 +545,14 @@ def get_1d_rotary_pos_embed( assert dim % 2 == 0 if isinstance(pos, int): - pos = np.arange(pos) + pos = torch.arange(pos) + if isinstance(pos, np.ndarray): + pos = torch.from_numpy(pos) # type: ignore # [S] + theta = theta * ntk_factor freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor # [D/2] - t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S] - freqs = torch.outer(t, freqs) # type: ignore # [S, D/2] + freqs = freqs.to(pos.device) + freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2] if use_real and repeat_interleave_real: # flux, hunyuan-dit, cogvideox freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D] @@ -626,7 +629,7 @@ class FluxPosEmbed(nn.Module): n_axes = ids.shape[-1] cos_out = [] sin_out = [] - pos = ids.squeeze().float().cpu().numpy() + pos = ids.squeeze().float() is_mps = ids.device.type == "mps" freqs_dtype = torch.float32 if is_mps else torch.float64 for i in range(n_axes): From 1d4d71875b495a5322f85bef2e99ac2bfa802e7a Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 30 Aug 2024 10:23:50 +0530 Subject: [PATCH 07/12] [CI] Update Hub Token on nightly tests (#9318) update --- .github/workflows/nightly_tests.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml index ae5f36e5bd..be3381babd 100644 --- a/.github/workflows/nightly_tests.yml +++ b/.github/workflows/nightly_tests.yml @@ -79,7 +79,7 @@ jobs: python utils/print_env.py - name: Pipeline CUDA Test env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} + HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }} # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms CUBLAS_WORKSPACE_CONFIG: :16:8 run: | @@ -139,7 +139,7 @@ jobs: - name: Run nightly PyTorch CUDA tests for non-pipeline modules if: ${{ matrix.module != 'examples'}} env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} + HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }} # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms CUBLAS_WORKSPACE_CONFIG: :16:8 run: | @@ -152,7 +152,7 @@ jobs: - name: Run nightly example tests with Torch if: ${{ matrix.module == 'examples' }} env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} + HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }} # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms CUBLAS_WORKSPACE_CONFIG: :16:8 run: | @@ -209,7 +209,7 @@ jobs: - name: Run nightly Flax TPU tests env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} + HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }} run: | python -m pytest -n 0 \ -s -v -k "Flax" \ @@ -264,7 +264,7 @@ jobs: - name: Run Nightly ONNXRuntime CUDA tests env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} + HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }} run: | python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \ -s -v -k "Onnx" \ From e417d028115e72b953a73e39d9687aa70ba3e37e Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 30 Aug 2024 13:53:25 +0530 Subject: [PATCH 08/12] [docs] Add a note on torchao/quanto benchmarks for CogVideoX and memory-efficient inference (#9296) * add a note on torchao/quanto benchmarks and memory-efficient inference * apply suggestions from review * update * Update docs/source/en/api/pipelines/cogvideox.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/api/pipelines/cogvideox.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * add note on enable sequential cpu offload --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/api/pipelines/cogvideox.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md index c7340eff40..4254246fee 100644 --- a/docs/source/en/api/pipelines/cogvideox.md +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -77,10 +77,21 @@ CogVideoX-2b requires about 19 GB of GPU memory to decode 49 frames (6 seconds o - `pipe.enable_model_cpu_offload()`: - Without enabling cpu offloading, memory usage is `33 GB` - With enabling cpu offloading, memory usage is `19 GB` +- `pipe.enable_sequential_cpu_offload()`: + - Similar to `enable_model_cpu_offload` but can significantly reduce memory usage at the cost of slow inference + - When enabled, memory usage is under `4 GB` - `pipe.vae.enable_tiling()`: - With enabling cpu offloading and tiling, memory usage is `11 GB` - `pipe.vae.enable_slicing()` +### Quantized inference + +[torchao](https://github.com/pytorch/ao) and [optimum-quanto](https://github.com/huggingface/optimum-quanto/) can be used to quantize the text encoder, transformer and VAE modules to lower the memory requirements. This makes it possible to run the model on a free-tier T4 Colab or lower VRAM GPUs! + +It is also worth noting that torchao quantization is fully compatible with [torch.compile](/optimization/torch2.0#torchcompile), which allows for much faster inference speed. Additionally, models can be serialized and stored in a quantized datatype to save disk space with torchao. Find examples and benchmarks in the gists below. +- [torchao](https://gist.github.com/a-r-r-o-w/4d9732d17412888c885480c6521a9897) +- [quanto](https://gist.github.com/a-r-r-o-w/31be62828b00a9292821b85c1017effa) + ## CogVideoXPipeline [[autodoc]] CogVideoXPipeline From d8a16635f47ac455abd61879bcc6be32dfeaa561 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Fri, 30 Aug 2024 08:51:21 -1000 Subject: [PATCH 09/12] update runway repo for single_file (#9323) update to a place holder --- src/diffusers/loaders/single_file_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index f13fcf2387..d620c15e83 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -91,11 +91,11 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = { "xl_inpaint": {"pretrained_model_name_or_path": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"}, "playground-v2-5": {"pretrained_model_name_or_path": "playgroundai/playground-v2.5-1024px-aesthetic"}, "upscale": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-x4-upscaler"}, - "inpainting": {"pretrained_model_name_or_path": "runwayml/stable-diffusion-inpainting"}, + "inpainting": {"pretrained_model_name_or_path": "Lykon/dreamshaper-8-inpainting"}, "inpainting_v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-inpainting"}, "controlnet": {"pretrained_model_name_or_path": "lllyasviel/control_v11p_sd15_canny"}, "v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-1"}, - "v1": {"pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5"}, + "v1": {"pretrained_model_name_or_path": "Lykon/dreamshaper-8"}, "stable_cascade_stage_b": {"pretrained_model_name_or_path": "stabilityai/stable-cascade", "subfolder": "decoder"}, "stable_cascade_stage_b_lite": { "pretrained_model_name_or_path": "stabilityai/stable-cascade", From af6c0fb7661faea7ef2dc598cc4b2bf63a943d04 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 2 Sep 2024 15:48:37 +0530 Subject: [PATCH 10/12] [core] CogVideoX memory optimizations in VAE encode (#9340) fake context parallel cache, vae encode tiling (cherry picked from commit bf890bca0e8aed875d6a207f9b826ce894901522) --- .../autoencoders/autoencoder_kl_cogvideox.py | 105 +++++++++++++++++- 1 file changed, 101 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 17fa2bbf40..fe887b7db0 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -999,6 +999,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin): # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different # number of temporal frames. self.num_latent_frames_batch_size = 2 + self.num_sample_frames_batch_size = 8 # We make the minimum height and width of sample for tiling half that of the generally supported self.tile_sample_min_height = sample_height // 2 @@ -1081,6 +1082,29 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin): """ self.use_slicing = False + def _encode(self, x: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + frame_batch_size = self.num_sample_frames_batch_size + enc = [] + for i in range(num_frames // frame_batch_size): + remaining_frames = num_frames % frame_batch_size + start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames) + end_frame = frame_batch_size * (i + 1) + remaining_frames + x_intermediate = x[:, :, start_frame:end_frame] + x_intermediate = self.encoder(x_intermediate) + if self.quant_conv is not None: + x_intermediate = self.quant_conv(x_intermediate) + enc.append(x_intermediate) + + self._clear_fake_context_parallel_cache() + enc = torch.cat(enc, dim=2) + + return enc + @apply_forward_hook def encode( self, x: torch.Tensor, return_dict: bool = True @@ -1094,13 +1118,17 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin): Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. Returns: - The latent representations of the encoded images. If `return_dict` is True, a + The latent representations of the encoded videos. If `return_dict` is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. """ - h = self.encoder(x) - if self.quant_conv is not None: - h = self.quant_conv(h) + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + posterior = DiagonalGaussianDistribution(h) + if not return_dict: return (posterior,) return AutoencoderKLOutput(latent_dist=posterior) @@ -1172,6 +1200,75 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin): ) return b + def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + # For a rough memory estimate, take a look at the `tiled_decode` method. + batch_size, num_channels, num_frames, height, width = x.shape + + overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height)) + overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width)) + blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height) + blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width) + row_limit_height = self.tile_latent_min_height - blend_extent_height + row_limit_width = self.tile_latent_min_width - blend_extent_width + frame_batch_size = self.num_sample_frames_batch_size + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, overlap_height): + row = [] + for j in range(0, width, overlap_width): + time = [] + for k in range(num_frames // frame_batch_size): + remaining_frames = num_frames % frame_batch_size + start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames) + end_frame = frame_batch_size * (k + 1) + remaining_frames + tile = x[ + :, + :, + start_frame:end_frame, + i : i + self.tile_sample_min_height, + j : j + self.tile_sample_min_width, + ] + tile = self.encoder(tile) + if self.quant_conv is not None: + tile = self.quant_conv(tile) + time.append(tile) + self._clear_fake_context_parallel_cache() + row.append(torch.cat(time, dim=2)) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent_width) + result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + enc = torch.cat(result_rows, dim=3) + return enc + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: r""" Decode a batch of images using a tiled decoder. From 0e6a8403f6b4d2a2778c12d1e76588d00a8d8f1a Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 2 Sep 2024 16:54:58 +0530 Subject: [PATCH 11/12] [core] Support VideoToVideo with CogVideoX (#9333) * add vid2vid pipeline for cogvideox * make fix-copies * update docs * fake context parallel cache, vae encode tiling * add test for cog vid2vid * use video link from HF docs repo * add copied from comments; correctly rename test class --- docs/source/en/api/pipelines/cogvideox.md | 8 +- src/diffusers/__init__.py | 2 + src/diffusers/pipelines/__init__.py | 4 +- src/diffusers/pipelines/cogvideo/__init__.py | 2 + .../pipelines/cogvideo/pipeline_cogvideox.py | 19 +- .../pipeline_cogvideox_video2video.py | 812 ++++++++++++++++++ .../pipelines/cogvideo/pipeline_output.py | 20 + .../dummy_torch_and_transformers_objects.py | 15 + .../cogvideox/test_cogvideox_video2video.py | 328 +++++++ 9 files changed, 1190 insertions(+), 20 deletions(-) create mode 100644 src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py create mode 100644 src/diffusers/pipelines/cogvideo/pipeline_output.py create mode 100644 tests/pipelines/cogvideox/test_cogvideox_video2video.py diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md index 4254246fee..41a0fd0220 100644 --- a/docs/source/en/api/pipelines/cogvideox.md +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -98,6 +98,12 @@ It is also worth noting that torchao quantization is fully compatible with [torc - all - __call__ +## CogVideoXVideoToVideoPipeline + +[[autodoc]] CogVideoXVideoToVideoPipeline + - all + - __call__ + ## CogVideoXPipelineOutput -[[autodoc]] pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput +[[autodoc]] pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 44ea224881..bb8ceccb76 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -255,6 +255,7 @@ else: "BlipDiffusionPipeline", "CLIPImageProjection", "CogVideoXPipeline", + "CogVideoXVideoToVideoPipeline", "CycleDiffusionPipeline", "FluxControlNetPipeline", "FluxPipeline", @@ -699,6 +700,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: AuraFlowPipeline, CLIPImageProjection, CogVideoXPipeline, + CogVideoXVideoToVideoPipeline, CycleDiffusionPipeline, FluxControlNetPipeline, FluxPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 63436e9be6..a999e0441d 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -132,7 +132,7 @@ else: "AudioLDM2UNet2DConditionModel", ] _import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"] - _import_structure["cogvideo"] = ["CogVideoXPipeline"] + _import_structure["cogvideo"] = ["CogVideoXPipeline", "CogVideoXVideoToVideoPipeline"] _import_structure["controlnet"].extend( [ "BlipDiffusionControlNetPipeline", @@ -454,7 +454,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ) from .aura_flow import AuraFlowPipeline from .blip_diffusion import BlipDiffusionPipeline - from .cogvideo import CogVideoXPipeline + from .cogvideo import CogVideoXPipeline, CogVideoXVideoToVideoPipeline from .controlnet import ( BlipDiffusionControlNetPipeline, StableDiffusionControlNetImg2ImgPipeline, diff --git a/src/diffusers/pipelines/cogvideo/__init__.py b/src/diffusers/pipelines/cogvideo/__init__.py index d155d3ef51..baf0de3482 100644 --- a/src/diffusers/pipelines/cogvideo/__init__.py +++ b/src/diffusers/pipelines/cogvideo/__init__.py @@ -23,6 +23,7 @@ except OptionalDependencyNotAvailable: _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_cogvideox"] = ["CogVideoXPipeline"] + _import_structure["pipeline_cogvideox_video2video"] = ["CogVideoXVideoToVideoPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -33,6 +34,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_cogvideox import CogVideoXPipeline + from .pipeline_cogvideox_video2video import CogVideoXVideoToVideoPipeline else: import sys diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 11f491e495..3af47c1774 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -15,7 +15,6 @@ import inspect import math -from dataclasses import dataclass from typing import Callable, Dict, List, Optional, Tuple, Union import torch @@ -26,9 +25,10 @@ from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel from ...models.embeddings import get_3d_rotary_pos_embed from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler -from ...utils import BaseOutput, logging, replace_example_docstring +from ...utils import logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor +from .pipeline_output import CogVideoXPipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -136,21 +136,6 @@ def retrieve_timesteps( return timesteps, num_inference_steps -@dataclass -class CogVideoXPipelineOutput(BaseOutput): - r""" - Output class for CogVideo pipelines. - - Args: - frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): - List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing - denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape - `(batch_size, num_frames, channels, height, width)`. - """ - - frames: torch.Tensor - - class CogVideoXPipeline(DiffusionPipeline): r""" Pipeline for text-to-video generation using CogVideoX. diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py new file mode 100644 index 0000000000..16686d1ab7 --- /dev/null +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py @@ -0,0 +1,812 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +from PIL import Image +from transformers import T5EncoderModel, T5Tokenizer + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel +from ...models.embeddings import get_3d_rotary_pos_embed +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from ...utils import ( + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from .pipeline_output import CogVideoXPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import CogVideoXDPMScheduler, CogVideoXVideoToVideoPipeline + >>> from diffusers.utils import export_to_video, load_video + + >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b" + >>> pipe = CogVideoXVideoToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config) + + >>> input_video = load_video( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4" + ... ) + >>> prompt = ( + ... "An astronaut stands triumphantly at the peak of a towering mountain. Panorama of rugged peaks and " + ... "valleys. Very futuristic vibe and animated aesthetic. Highlights of purple and golden colors in " + ... "the scene. The sky is looks like an animated/cartoonish dream of galaxies, nebulae, stars, planets, " + ... "moons, but the remainder of the scene is mostly realistic." + ... ) + + >>> video = pipe( + ... video=input_video, prompt=prompt, strength=0.8, guidance_scale=6, num_inference_steps=50 + ... ).frames[0] + >>> export_to_video(video, "output.mp4", fps=8) + ``` +""" + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class CogVideoXVideoToVideoPipeline(DiffusionPipeline): + r""" + Pipeline for video-to-video generation using CogVideoX. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. CogVideoX uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CogVideoXTransformer3DModel`]): + A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: CogVideoXTransformer3DModel, + scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, + video: Optional[torch.Tensor] = None, + batch_size: int = 1, + num_channels_latents: int = 16, + height: int = 60, + width: int = 90, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + timestep: Optional[torch.Tensor] = None, + ): + num_frames = (video.size(2) - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.size(1) + + shape = ( + batch_size, + num_frames, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + if isinstance(generator, list): + if len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + init_latents = [ + retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size) + ] + else: + init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video] + + init_latents = torch.cat(init_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W] + init_latents = self.vae.config.scaling_factor * init_latents + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.add_noise(init_latents, noise, timestep) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] + latents = 1 / self.vae.config.scaling_factor * latents + + frames = self.vae.decode(latents).sample + return frames + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, timesteps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + strength, + negative_prompt, + callback_on_step_end_tensor_inputs, + video=None, + latents=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if video is not None and latents is not None: + raise ValueError("Only one of `video` or `latents` should be provided") + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.fuse_qkv_projections + def fuse_qkv_projections(self) -> None: + r"""Enables fused QKV projections.""" + self.fusing_transformer = True + self.transformer.fuse_qkv_projections() + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.unfuse_qkv_projections + def unfuse_qkv_projections(self) -> None: + r"""Disable QKV projection fusion if enabled.""" + if not self.fusing_transformer: + logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.") + else: + self.transformer.unfuse_qkv_projections() + self.fusing_transformer = False + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings + def _prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + device: torch.device, + ) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + video: List[Image.Image] = None, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + strength: float = 0.8, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 226, + ) -> Union[CogVideoXPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + video (`List[PIL.Image.Image]`): + The input video to condition the generation on. Must be a list of images/frames of the video. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + strength (`float`, *optional*, defaults to 0.8): + Higher strength leads to more differences between original video and generated video. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `226`): + Maximum sequence length in encoded prompt. Must be consistent with + `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. + + Examples: + + Returns: + [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] or `tuple`: + [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + strength, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + self._num_timesteps = len(timesteps) + + # 5. Prepare latents + if latents is None: + video = self.video_processor.preprocess_video(video, height=height, width=width) + video = video.to(device=device, dtype=prompt_embeds.dtype) + + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + video, + batch_size * num_videos_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + latent_timestep, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Create rotary embeds if required + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CogVideoXPipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_output.py b/src/diffusers/pipelines/cogvideo/pipeline_output.py new file mode 100644 index 0000000000..3de030dd69 --- /dev/null +++ b/src/diffusers/pipelines/cogvideo/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class CogVideoXPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 477beed49f..644a148a8b 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -272,6 +272,21 @@ class CogVideoXPipeline(metaclass=DummyObject): requires_backends(cls, ["torch", "transformers"]) +class CogVideoXVideoToVideoPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class CycleDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/cogvideox/test_cogvideox_video2video.py b/tests/pipelines/cogvideox/test_cogvideox_video2video.py new file mode 100644 index 0000000000..27f0c8441c --- /dev/null +++ b/tests/pipelines/cogvideox/test_cogvideox_video2video.py @@ -0,0 +1,328 @@ +# Copyright 2024 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import AutoTokenizer, T5EncoderModel + +from diffusers import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel, CogVideoXVideoToVideoPipeline, DDIMScheduler +from diffusers.utils.testing_utils import enable_full_determinism, torch_device + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import ( + PipelineTesterMixin, + check_qkv_fusion_matches_attn_procs_length, + check_qkv_fusion_processors_exist, + to_np, +) + + +enable_full_determinism() + + +class CogVideoXVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = CogVideoXVideoToVideoPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"video"}) + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = CogVideoXTransformer3DModel( + # Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings + # But, since we are using tiny-random-t5 here, we need the internal dim of CogVideoXTransformer3DModel + # to be 32. The internal dim is product of num_attention_heads and attention_head_dim + num_attention_heads=4, + attention_head_dim=8, + in_channels=4, + out_channels=4, + time_embed_dim=2, + text_embed_dim=32, # Must match with tiny-random-t5 + num_layers=1, + sample_width=16, # latent width: 2 -> final width: 16 + sample_height=16, # latent height: 2 -> final height: 16 + sample_frames=9, # latent frames: (9 - 1) / 4 + 1 = 3 -> final frames: 9 + patch_size=2, + temporal_compression_ratio=4, + max_text_seq_length=16, + ) + + torch.manual_seed(0) + vae = AutoencoderKLCogVideoX( + in_channels=3, + out_channels=3, + down_block_types=( + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + ), + up_block_types=( + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + ), + block_out_channels=(8, 8, 8, 8), + latent_channels=4, + layers_per_block=1, + norm_num_groups=2, + temporal_compression_ratio=4, + ) + + torch.manual_seed(0) + scheduler = DDIMScheduler() + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed: int = 0, num_frames: int = 8): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + video_height = 16 + video_width = 16 + video = [Image.new("RGB", (video_width, video_height))] * num_frames + + inputs = { + "video": video, + "prompt": "dance monkey", + "negative_prompt": "", + "generator": generator, + "num_inference_steps": 2, + "strength": 0.5, + "guidance_scale": 6.0, + # Cannot reduce because convolution kernel becomes bigger than sample + "height": video_height, + "width": video_width, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (8, 3, 16, 16)) + expected_video = torch.randn(8, 3, 16, 16) + max_diff = np.abs(generated_video - expected_video).max() + self.assertLessEqual(max_diff, 1e10) + + def test_callback_inputs(self): + sig = inspect.signature(self.pipeline_class.__call__) + has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters + has_callback_step_end = "callback_on_step_end" in sig.parameters + + if not (has_callback_tensor_inputs and has_callback_step_end): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + assert tensor_name in callback_kwargs + + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + inputs = self.get_dummy_inputs(torch_device) + + # Test passing in a subset + inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + output = pipe(**inputs)[0] + + # Test passing in a everything + inputs["callback_on_step_end"] = callback_inputs_all + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + + def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): + is_last = i == (pipe.num_timesteps - 1) + if is_last: + callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"]) + return callback_kwargs + + inputs["callback_on_step_end"] = callback_inputs_change_tensor + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + assert output.abs().sum() < 1e10 + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3) + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", + ) + + def test_vae_tiling(self, expected_diff_max: float = 0.2): + # Since VideoToVideo uses both encoder and decoder tiling, there seems to be much more numerical + # difference. We seem to need a higher tolerance here... + # TODO(aryan): Look into this more deeply + expected_diff_max = 0.4 + + generator_device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + # Without tiling + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_without_tiling = pipe(**inputs)[0] + + # With tiling + pipe.vae.enable_tiling( + tile_sample_min_height=96, + tile_sample_min_width=96, + tile_overlap_factor_height=1 / 12, + tile_overlap_factor_width=1 / 12, + ) + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_with_tiling = pipe(**inputs)[0] + + self.assertLess( + (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), + expected_diff_max, + "VAE tiling should not affect the inference results", + ) + + @unittest.skip("xformers attention processor does not exist for CogVideoX") + def test_xformers_attention_forwardGenerator_pass(self): + pass + + def test_fused_qkv_projections(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + frames = pipe(**inputs).frames # [B, F, C, H, W] + original_image_slice = frames[0, -2:, -1, -3:, -3:] + + pipe.fuse_qkv_projections() + assert check_qkv_fusion_processors_exist( + pipe.transformer + ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + assert check_qkv_fusion_matches_attn_procs_length( + pipe.transformer, pipe.transformer.original_attn_processors + ), "Something wrong with the attention processors concerning the fused QKV projections." + + inputs = self.get_dummy_inputs(device) + frames = pipe(**inputs).frames + image_slice_fused = frames[0, -2:, -1, -3:, -3:] + + pipe.transformer.unfuse_qkv_projections() + inputs = self.get_dummy_inputs(device) + frames = pipe(**inputs).frames + image_slice_disabled = frames[0, -2:, -1, -3:, -3:] + + assert np.allclose( + original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3 + ), "Fusion of QKV projections shouldn't affect the outputs." + assert np.allclose( + image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3 + ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." + assert np.allclose( + original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 + ), "Original outputs should match when fused QKV projections are disabled." From 007ad0e2aa792a192a100cf0dcb0e100225fd486 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 2 Sep 2024 17:51:48 +0530 Subject: [PATCH 12/12] [CI] More fixes for Fast GPU Tests on main (#9300) update --- tests/models/test_modeling_common.py | 3 +++ tests/models/transformers/test_models_transformer_flux.py | 3 +++ tests/pipelines/flux/test_pipeline_flux.py | 3 +++ tests/pipelines/pag/test_pag_sd3.py | 1 + tests/pipelines/stable_audio/test_stable_audio.py | 2 ++ 5 files changed, 12 insertions(+) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 0ce01fb93f..2437a5a55c 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -417,6 +417,9 @@ class ModelTesterMixin: @require_torch_gpu def test_set_attn_processor_for_determinism(self): + if self.uses_custom_attn_processor: + return + torch.use_deterministic_algorithms(False) if self.forward_requires_fresh_args: model = self.model_class(**self.init_dict) diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index 538d158cbc..6cf7a4f757 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -32,6 +32,9 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase): # We override the items here because the transformer under consideration is small. model_split_percents = [0.7, 0.6, 0.6] + # Skip setting testing with default: AttnProcessor + uses_custom_attn_processor = True + @property def dummy_input(self): batch_size = 1 diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index 57aacd1648..4caff40302 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -25,6 +25,9 @@ class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin): params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) batch_params = frozenset(["prompt"]) + # there is no xformers processor for Flux + test_xformers_attention = False + def get_dummy_components(self): torch.manual_seed(0) transformer = FluxTransformer2DModel( diff --git a/tests/pipelines/pag/test_pag_sd3.py b/tests/pipelines/pag/test_pag_sd3.py index 93260870e7..627d613ee2 100644 --- a/tests/pipelines/pag/test_pag_sd3.py +++ b/tests/pipelines/pag/test_pag_sd3.py @@ -37,6 +37,7 @@ class StableDiffusion3PAGPipelineFastTests(unittest.TestCase, PipelineTesterMixi ] ) batch_params = frozenset(["prompt", "negative_prompt"]) + test_xformers_attention = False def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/stable_audio/test_stable_audio.py b/tests/pipelines/stable_audio/test_stable_audio.py index fe8a684de0..41ac94891c 100644 --- a/tests/pipelines/stable_audio/test_stable_audio.py +++ b/tests/pipelines/stable_audio/test_stable_audio.py @@ -68,6 +68,8 @@ class StableAudioPipelineFastTests(PipelineTesterMixin, unittest.TestCase): "callback_steps", ] ) + # There is not xformers version of the StableAudioPipeline custom attention processor + test_xformers_attention = False def get_dummy_components(self): torch.manual_seed(0)