mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[I2vGenXL] clean up things (#6845)
* remove _to_tensor * remove _to_tensor definition * remove _collapse_frames_into_batch * remove lora for not bloating the code. * remove sample_size. * simplify code a bit more * ensure timesteps are always in tensor.
This commit is contained in:
@@ -48,29 +48,6 @@ from .unet_3d_condition import UNet3DConditionOutput
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def _to_tensor(inputs, device):
|
||||
if not torch.is_tensor(inputs):
|
||||
# TODO: this requires sync between CPU and GPU. So try to pass `inputs` as tensors if you can
|
||||
# This would be a good case for the `match` statement (Python 3.10+)
|
||||
is_mps = device.type == "mps"
|
||||
if isinstance(inputs, float):
|
||||
dtype = torch.float32 if is_mps else torch.float64
|
||||
else:
|
||||
dtype = torch.int32 if is_mps else torch.int64
|
||||
inputs = torch.tensor([inputs], dtype=dtype, device=device)
|
||||
elif len(inputs.shape) == 0:
|
||||
inputs = inputs[None].to(device)
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
def _collapse_frames_into_batch(sample: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, channels, num_frames, height, width = sample.shape
|
||||
sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
class I2VGenXLTransformerTemporalEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -174,8 +151,6 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.sample_size = sample_size
|
||||
|
||||
# Check inputs
|
||||
if len(down_block_types) != len(up_block_types):
|
||||
raise ValueError(
|
||||
@@ -543,7 +518,18 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
forward_upsample_size = True
|
||||
|
||||
# 1. time
|
||||
timesteps = _to_tensor(timestep, sample.device)
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
# TODO: this requires sync between CPU and GPU. So try to pass `timesteps` as tensors if you can
|
||||
# This would be a good case for the `match` statement (Python 3.10+)
|
||||
is_mps = sample.device.type == "mps"
|
||||
if isinstance(timesteps, float):
|
||||
dtype = torch.float32 if is_mps else torch.float64
|
||||
else:
|
||||
dtype = torch.int32 if is_mps else torch.int64
|
||||
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
||||
elif len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
@@ -572,7 +558,13 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
context_emb = sample.new_zeros(batch_size, 0, self.config.cross_attention_dim)
|
||||
context_emb = torch.cat([context_emb, encoder_hidden_states], dim=1)
|
||||
|
||||
image_latents_context_embs = _collapse_frames_into_batch(image_latents[:, :, :1, :])
|
||||
image_latents_for_context_embds = image_latents[:, :, :1, :]
|
||||
image_latents_context_embs = image_latents_for_context_embds.permute(0, 2, 1, 3, 4).reshape(
|
||||
image_latents_for_context_embds.shape[0] * image_latents_for_context_embds.shape[2],
|
||||
image_latents_for_context_embds.shape[1],
|
||||
image_latents_for_context_embds.shape[3],
|
||||
image_latents_for_context_embds.shape[4],
|
||||
)
|
||||
image_latents_context_embs = self.image_latents_context_embedding(image_latents_context_embs)
|
||||
|
||||
_batch_size, _channels, _height, _width = image_latents_context_embs.shape
|
||||
@@ -586,7 +578,12 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
context_emb = torch.cat([context_emb, image_emb], dim=1)
|
||||
context_emb = context_emb.repeat_interleave(repeats=num_frames, dim=0)
|
||||
|
||||
image_latents = _collapse_frames_into_batch(image_latents)
|
||||
image_latents = image_latents.permute(0, 2, 1, 3, 4).reshape(
|
||||
image_latents.shape[0] * image_latents.shape[2],
|
||||
image_latents.shape[1],
|
||||
image_latents.shape[3],
|
||||
image_latents.shape[4],
|
||||
)
|
||||
image_latents = self.image_latents_proj_in(image_latents)
|
||||
image_latents = (
|
||||
image_latents[None, :]
|
||||
|
||||
@@ -22,18 +22,13 @@ import torch
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import LoraLoaderMixin
|
||||
from ...models import AutoencoderKL
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...models.unets.unet_i2vgen_xl import I2VGenXLUNet
|
||||
from ...schedulers import DDIMScheduler
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
BaseOutput,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
@@ -207,7 +202,6 @@ class I2VGenXLPipeline(DiffusionPipeline):
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
):
|
||||
r"""
|
||||
@@ -233,23 +227,10 @@ class I2VGenXLPipeline(DiffusionPipeline):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
lora_scale (`float`, *optional*):
|
||||
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
"""
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
@@ -380,10 +361,6 @@ class I2VGenXLPipeline(DiffusionPipeline):
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def _encode_image(self, image, device, num_videos_per_prompt):
|
||||
@@ -706,9 +683,6 @@ class I2VGenXLPipeline(DiffusionPipeline):
|
||||
self._guidance_scale = guidance_scale
|
||||
|
||||
# 3.1 Encode input text prompt
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
)
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
@@ -716,7 +690,6 @@ class I2VGenXLPipeline(DiffusionPipeline):
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
clip_skip=clip_skip,
|
||||
)
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
|
||||
Reference in New Issue
Block a user