Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-29 07:22:12 +03:00
apply suggestions from review
@@ -717,25 +717,11 @@ def log_validation(
             }
         )

-    del pipe
-    clear_objs_and_retain_memory()
+    clear_objs_and_retain_memory([pipe])

     return videos


-def collate_fn(examples):
-    videos = [example["instance_video"] for example in examples]
-    prompts = [example["instance_prompt"] for example in examples]
-
-    videos = torch.stack(videos)
-    videos = videos.to(memory_format=torch.contiguous_format).float()
-
-    return {
-        "videos": videos,
-        "prompts": prompts,
-    }
-
-
 def _get_t5_prompt_embeds(
     tokenizer: T5Tokenizer,
     text_encoder: T5EncoderModel,
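For readers unfamiliar with the helper touched above: `clear_objs_and_retain_memory([pipe])` hands the objects to release to the utility instead of `del`-ing them at the call site. Below is a minimal sketch of what such a helper might do; the body is an assumption for illustration, not the actual diffusers implementation, which may also cover other backends.

import gc
import torch

def clear_objs_and_retain_memory(objs=None):
    # Sketch only: drop the references we were handed, then reclaim memory.
    # Assumes the caller keeps no other references to the objects.
    if objs is not None:
        objs.clear()
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available():
        torch.mps.empty_cache()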
@@ -993,7 +979,6 @@ def main(args):
         weight_dtype = torch.float16
     elif accelerator.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
-    print("weight_dtype:", weight_dtype)

     if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
         # due to pytorch#99272, MPS does not yet support bfloat16.
@@ -1159,6 +1144,27 @@ def main(args):
         id_token=args.id_token,
     )

+    def encode_video(video):
+        print(video.shape)
+        video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
+        video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
+        latent_dist = vae.encode(video).latent_dist
+        return latent_dist
+
+    train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos]
+
+    def collate_fn(examples):
+        videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]
+        prompts = [example["instance_prompt"] for example in examples]
+
+        videos = torch.cat(videos)
+        videos = videos.to(memory_format=torch.contiguous_format).float()
+
+        return {
+            "videos": videos,
+            "prompts": prompts,
+        }
+
     train_dataloader = DataLoader(
         train_dataset,
         batch_size=args.train_batch_size,
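Design note on the hunk above: `encode_video` caches the VAE's latent distribution per clip up front, so the training loop never runs the VAE again, while `collate_fn` calls `.sample()` so each epoch still draws a freshly sampled latent. A minimal sketch of the distinction, assuming the `train_dataset` and `vae` from the diff are in scope:

# Each cached item is a latent distribution (a DiagonalGaussianDistribution), not a tensor.
latent_dist = train_dataset.instance_videos[0]

stochastic = latent_dist.sample() * vae.config.scaling_factor   # fresh draw every call
deterministic = latent_dist.mode() * vae.config.scaling_factor  # fixed posterior mean

Using `.sample()` in `collate_fn` keeps some augmentation-like variety across epochs; switching to `.mode()` would make the cached latents fully deterministic.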
@@ -1281,7 +1287,7 @@ def main(args):
                 models_to_accumulate.extend([text_encoder])

             with accelerator.accumulate(models_to_accumulate):
-                videos = batch["videos"].to(dtype=vae.dtype)
+                model_input = batch["videos"].permute(0, 2, 1, 3, 4).to(dtype=weight_dtype)  # [B, F, C, H, W]
                 prompts = batch["prompts"]

                 # encode prompts
@@ -1294,11 +1300,6 @@ def main(args):
                     requires_grad=args.train_text_encoder,
                 )

-                # Convert videos to latents
-                videos = videos.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
-                model_input = vae.encode(videos).latent_dist.sample() * vae.config.scaling_factor
-                model_input = model_input.permute(0, 2, 1, 3, 4).to(dtype=weight_dtype)  # [B, F, C, H, W]
-
                 # Sample noise that will be added to the latents
                 noise = torch.rand_like(model_input)
                 batch_size, num_frames, num_channels, height, width = model_input.shape
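A quick layout check for the two hunks above: the VAE now sees [B, C, F, H, W] once inside `encode_video`, and the training loop only permutes the cached latents into the [B, F, C, H, W] layout the transformer expects. The shapes below are illustrative CogVideoX-2b-like latent sizes, not values taken from the diff.

import torch

latents = torch.randn(2, 16, 13, 60, 90)      # [B, C, F, H, W]: VAE/latent layout
model_input = latents.permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]: layout the transformer expects
assert model_input.shape == (2, 13, 16, 60, 90)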
@@ -20,7 +20,7 @@ from torch import nn

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
-from ...utils import is_torch_version, logging
+from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention, FeedForward
 from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
@@ -403,8 +403,24 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         timestep: Union[int, float, torch.LongTensor],
         timestep_cond: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ):
+        if attention_kwargs is not None:
+            attention_kwargs = attention_kwargs.copy()
+            lora_scale = attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+
         batch_size, num_frames, channels, height, width = hidden_states.shape

         # 1. Time embedding
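Two details of the block above worth noting: the incoming dict is copied before `scale` is popped, so the caller's `attention_kwargs` is never mutated, and the popped value is applied to every PEFT LoRA layer via `scale_lora_layers` (then removed again by `unscale_lora_layers` in the next hunk). A tiny self-contained illustration of the copy-then-pop behaviour:

attention_kwargs = {"scale": 0.5}

local_kwargs = attention_kwargs.copy()
lora_scale = local_kwargs.pop("scale", 1.0)

assert lora_scale == 0.5
assert attention_kwargs == {"scale": 0.5}  # the caller's dict is left intact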
@@ -470,6 +486,10 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
         output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
         if not return_dict:
             return (output,)
         return Transformer2DModelOutput(sample=output)
@@ -15,7 +15,7 @@

 import inspect
 import math
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 from transformers import T5EncoderModel, T5Tokenizer
@@ -486,6 +486,10 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
     def num_timesteps(self):
         return self._num_timesteps

+    @property
+    def attention_kwargs(self):
+        return self._attention_kwargs
+
     @property
     def interrupt(self):
         return self._interrupt
@@ -511,12 +515,12 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: str = "pil",
         return_dict: bool = True,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
         callback_on_step_end: Optional[
             Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
         ] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 226,
-        lora_scale: Optional[float] = None,
     ) -> Union[CogVideoXPipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
@@ -573,6 +577,10 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
                 of a plain tuple.
+            attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             callback_on_step_end (`Callable`, *optional*):
                 A function that calls at the end of each denoising steps during the inference. The function is called
                 with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
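Putting the pipeline changes together, a hedged end-to-end sketch of how the new `attention_kwargs` argument would be used at inference time; the LoRA weight path is a placeholder and the sampler settings are illustrative, not taken from this diff.

import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
pipe.load_lora_weights("path/to/cogvideox-lora")  # placeholder: weights produced by the training script

video = pipe(
    prompt="A panda playing a guitar in a bamboo forest",
    num_inference_steps=50,
    guidance_scale=6.0,
    attention_kwargs={"scale": 0.8},  # forwarded to the transformer, which scales its LoRA layers
).frames[0]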
@@ -617,6 +625,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
             negative_prompt_embeds,
         )
         self._guidance_scale = guidance_scale
+        self._attention_kwargs = attention_kwargs
         self._interrupt = False

         # 2. Default call parameters
@@ -635,6 +644,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
         do_classifier_free_guidance = guidance_scale > 1.0

         # 3. Encode input prompt
+        lora_scale = self.attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
             negative_prompt,
@@ -699,6 +709,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                     encoder_hidden_states=prompt_embeds,
                     timestep=timestep,
                     image_rotary_emb=image_rotary_emb,
+                    attention_kwargs=attention_kwargs,
                     return_dict=False,
                 )[0]
                 noise_pred = noise_pred.float()