From a063503044c32c8f0bc386b833a750df3187f0bb Mon Sep 17 00:00:00 2001
From: Aryan
Date: Sat, 7 Sep 2024 14:46:29 +0200
Subject: [PATCH] apply suggestions from review

---
 examples/cogvideo/train_cogvideox_lora.py    | 45 ++++++++++---------
 .../transformers/cogvideox_transformer_3d.py | 22 ++++++++-
 .../pipelines/cogvideo/pipeline_cogvideox.py | 15 ++++++-
 3 files changed, 57 insertions(+), 25 deletions(-)

diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py
index c800cec6f3..535b81c06f 100644
--- a/examples/cogvideo/train_cogvideox_lora.py
+++ b/examples/cogvideo/train_cogvideox_lora.py
@@ -717,25 +717,11 @@ def log_validation(
                 }
             )
 
-    del pipe
-    clear_objs_and_retain_memory()
+    clear_objs_and_retain_memory([pipe])
 
     return videos
 
 
-def collate_fn(examples):
-    videos = [example["instance_video"] for example in examples]
-    prompts = [example["instance_prompt"] for example in examples]
-
-    videos = torch.stack(videos)
-    videos = videos.to(memory_format=torch.contiguous_format).float()
-
-    return {
-        "videos": videos,
-        "prompts": prompts,
-    }
-
-
 def _get_t5_prompt_embeds(
     tokenizer: T5Tokenizer,
     text_encoder: T5EncoderModel,
@@ -993,7 +979,6 @@ def main(args):
         weight_dtype = torch.float16
     elif accelerator.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
-    print("weight_dtype:", weight_dtype)
 
     if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
         # due to pytorch#99272, MPS does not yet support bfloat16.
@@ -1159,6 +1144,27 @@ def main(args):
         id_token=args.id_token,
     )
 
+    def encode_video(video):
+        print(video.shape)
+        video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
+        video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
+        latent_dist = vae.encode(video).latent_dist
+        return latent_dist
+
+    train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos]
+
+    def collate_fn(examples):
+        videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]
+        prompts = [example["instance_prompt"] for example in examples]
+
+        videos = torch.cat(videos)
+        videos = videos.to(memory_format=torch.contiguous_format).float()
+
+        return {
+            "videos": videos,
+            "prompts": prompts,
+        }
+
     train_dataloader = DataLoader(
         train_dataset,
         batch_size=args.train_batch_size,
@@ -1281,7 +1287,7 @@ def main(args):
                 models_to_accumulate.extend([text_encoder])
 
             with accelerator.accumulate(models_to_accumulate):
-                videos = batch["videos"].to(dtype=vae.dtype)
+                model_input = batch["videos"].permute(0, 2, 1, 3, 4).to(dtype=weight_dtype)  # [B, F, C, H, W]
                 prompts = batch["prompts"]
 
                 # encode prompts
@@ -1294,11 +1300,6 @@ def main(args):
                     requires_grad=args.train_text_encoder,
                 )
 
-                # Convert videos to latents
-                videos = videos.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
-                model_input = vae.encode(videos).latent_dist.sample() * vae.config.scaling_factor
-                model_input = model_input.permute(0, 2, 1, 3, 4).to(dtype=weight_dtype)  # [B, F, C, H, W]
-
                 # Sample noise that will be added to the latents
                 noise = torch.randn_like(model_input)
                 batch_size, num_frames, num_channels, height, width = model_input.shape
diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py
index 12435fa340..69f3240144 100644
--- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py
+++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py
@@ -20,7 +20,7 @@ from torch import nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
-from ...utils import is_torch_version, logging
+from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention, FeedForward
 from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
@@ -403,8 +403,24 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         timestep: Union[int, float, torch.LongTensor],
         timestep_cond: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ):
+        if attention_kwargs is not None:
+            attention_kwargs = attention_kwargs.copy()
+            lora_scale = attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+
         batch_size, num_frames, channels, height, width = hidden_states.shape
 
         # 1. Time embedding
@@ -470,6 +486,10 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
         output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
 
+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
         if not return_dict:
             return (output,)
         return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
index be83eef3bc..7e53fcb356 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
@@ -15,7 +15,7 @@
 
 import inspect
 import math
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 from transformers import T5EncoderModel, T5Tokenizer
@@ -486,6 +486,10 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
     def num_timesteps(self):
         return self._num_timesteps
 
+    @property
+    def attention_kwargs(self):
+        return self._attention_kwargs
+
     @property
     def interrupt(self):
         return self._interrupt
@@ -511,12 +515,12 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: str = "pil",
         return_dict: bool = True,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
         callback_on_step_end: Optional[
             Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
         ] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 226,
-        lora_scale: Optional[float] = None,
     ) -> Union[CogVideoXPipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
@@ -573,6 +577,10 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
                 of a plain tuple.
+            attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             callback_on_step_end (`Callable`, *optional*):
                 A function that calls at the end of each denoising steps during the inference. The function is called
                 with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
@@ -617,6 +625,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
             negative_prompt_embeds,
         )
         self._guidance_scale = guidance_scale
+        self._attention_kwargs = attention_kwargs
         self._interrupt = False
 
         # 2. Default call parameters
@@ -635,6 +644,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
         do_classifier_free_guidance = guidance_scale > 1.0
 
         # 3. Encode input prompt
+        lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
             negative_prompt,
@@ -699,6 +709,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                     encoder_hidden_states=prompt_embeds,
                     timestep=timestep,
                     image_rotary_emb=image_rotary_emb,
+                    attention_kwargs=attention_kwargs,
                     return_dict=False,
                 )[0]
                 noise_pred = noise_pred.float()
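
Usage sketch (not part of the diff): with this change, a LoRA scale is forwarded through `attention_kwargs` rather than the removed `lora_scale` argument. The snippet below assumes a LoRA checkpoint trained with examples/cogvideo/train_cogvideox_lora.py; the checkpoint path, adapter name, prompt, and generation settings are placeholders.

import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.bfloat16).to("cuda")
# hypothetical LoRA weights produced by the training script in this patch
pipe.load_lora_weights("path/to/cogvideox-lora", adapter_name="cogvideox-lora")

video = pipe(
    prompt="A panda playing a guitar in a bamboo forest",
    num_inference_steps=50,
    guidance_scale=6.0,
    # forwarded to the transformer, which pops "scale" and applies it with scale_lora_layers
    attention_kwargs={"scale": 0.8},
).frames[0]
export_to_video(video, "output.mp4", fps=8)

The transformer pops `scale` from `attention_kwargs`, scales the PEFT (LoRA) layers before running the blocks, and unscales them before returning, mirroring how other diffusers pipelines handle LoRA scaling.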