diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 214b2b953c..cea079251b 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -75,6 +75,101 @@ EXAMPLE_DOC_STRING = """ ``` """ +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +from typing import Any, Callable, Dict, List, Optional, Union + +import regex as re +import torch +from transformers import Qwen2TokenizerFast, Qwen2VLProcessor, Qwen2_5_VLForConditionalGeneration, AutoProcessor, CLIPTextModel, CLIPTokenizer +import torchvision +from torchvision.transforms import ToPILImage + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import KandinskyLoraLoaderMixin +from ...models import AutoencoderKLHunyuanVideo +from ...models.transformers import Kandinsky5Transformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import KandinskyPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +logger = logging.get_logger(__name__) + +EXAMPLE_DOC_STRING = """ + Examples: + + ```python + >>> import torch + >>> from diffusers import Kandinsky5T2VPipeline + >>> from diffusers.utils import export_to_video + + >>> pipe = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V") + >>> pipe = pipe.to("cuda") + + >>> prompt = "A cat and a dog baking a cake together in a kitchen." + >>> negative_prompt = "Bright tones, overexposed, static, blurred details" + + >>> output = pipe( + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... height=512, + ... width=768, + ... num_frames=25, + ... num_inference_steps=50, + ... guidance_scale=5.0, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=6) + ``` +""" + + +def basic_clean(text): + if is_ftfy_available(): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): r""" @@ -96,9 +191,11 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): Frozen CLIP text encoder. tokenizer_2 ([`CLIPTokenizer`]): Tokenizer for CLIP. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. """ - model_cpu_offload_seq = "text_encoder->transformer->vae" + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( @@ -125,6 +222,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): self.vae_scale_factor_temporal = vae.config.temporal_compression_ratio self.vae_scale_factor_spatial = vae.config.spatial_compression_ratio + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) def _encode_prompt_qwen( self, @@ -132,9 +230,12 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): device: Optional[torch.device] = None, num_videos_per_prompt: int = 1, max_sequence_length: int = 256, + dtype: Optional[torch.dtype] = None, ): device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(p) for p in prompt] # Kandinsky specific prompt template prompt_template = "\n".join([ @@ -180,16 +281,19 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): embeds = embeds.repeat(1, num_videos_per_prompt, 1) embeds = embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - return embeds, cu_seqlens + return embeds.to(dtype), cu_seqlens def _encode_prompt_clip( self, prompt: Union[str, List[str]], device: Optional[torch.device] = None, num_videos_per_prompt: int = 1, + dtype: Optional[torch.dtype] = None, ): device = device or self._execution_device + dtype = dtype or self.text_encoder_2.dtype prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(p) for p in prompt] inputs = self.tokenizer_2( prompt, @@ -208,7 +312,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): pooled_embed = pooled_embed.repeat(1, num_videos_per_prompt, 1) pooled_embed = pooled_embed.view(batch_size * num_videos_per_prompt, -1) - return pooled_embed + return pooled_embed.to(dtype) def encode_prompt( self, @@ -216,34 +320,151 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): negative_prompt: Optional[Union[str, List[str]]] = None, do_classifier_free_guidance: bool = True, num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length for text encoding. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ device = device or self._execution_device - prompt_embeds, prompt_cu_seqlens = self._encode_prompt_qwen(prompt, device, num_videos_per_prompt) - pooled_embed = self._encode_prompt_clip(prompt, device, num_videos_per_prompt) + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, (list, tuple)) else prompt_embeds.shape[0] - if do_classifier_free_guidance: + if prompt_embeds is None: + prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen( + prompt=prompt, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + dtype=dtype, + ) + prompt_embeds_clip = self._encode_prompt_clip( + prompt=prompt, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + dtype=dtype, + ) + else: + prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = prompt_embeds + + if do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - - negative_prompt_embeds, negative_cu_seqlens = self._encode_prompt_qwen(negative_prompt, device, num_videos_per_prompt) - negative_pooled_embed = self._encode_prompt_clip(negative_prompt, device, num_videos_per_prompt) + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds_qwen, negative_cu_seqlens = self._encode_prompt_qwen( + prompt=negative_prompt, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + dtype=dtype, + ) + negative_prompt_embeds_clip = self._encode_prompt_clip( + prompt=negative_prompt, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + dtype=dtype, + ) else: - negative_prompt_embeds = None - negative_pooled_embed = None + negative_prompt_embeds_qwen = None + negative_prompt_embeds_clip = None negative_cu_seqlens = None - text_embeds = { - "text_embeds": prompt_embeds, - "pooled_embed": pooled_embed, + prompt_embeds_dict = { + "text_embeds": prompt_embeds_qwen, + "pooled_embed": prompt_embeds_clip, } - negative_text_embeds = { - "text_embeds": negative_prompt_embeds, - "pooled_embed": negative_pooled_embed, + negative_prompt_embeds_dict = { + "text_embeds": negative_prompt_embeds_qwen, + "pooled_embed": negative_prompt_embeds_clip, } if do_classifier_free_guidance else None - return text_embeds, negative_text_embeds, prompt_cu_seqlens, negative_cu_seqlens + return prompt_embeds_dict, negative_prompt_embeds_dict, prompt_cu_seqlens, negative_cu_seqlens + + def check_inputs( + self, + prompt, + negative_prompt, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") def prepare_latents( self, @@ -252,34 +473,31 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): height: int = 480, width: int = 832, num_frames: int = 81, - visual_cond: bool = False, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: if latents is not None: - num_latent_frames = latents.shape[1] - latents = latents.to(device=device, dtype=dtype) + return latents.to(device=device, dtype=dtype) - else: - num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 - shape = ( - batch_size, - num_latent_frames, - int(height) // self.vae_scale_factor_spatial, - int(width) // self.vae_scale_factor_spatial, - num_channels_latents, + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + num_channels_latents, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - - if visual_cond: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + if self.transformer.visual_cond: # For visual conditioning, concatenate with zeros and mask visual_cond = torch.zeros_like(latents) visual_cond_mask = torch.zeros( @@ -291,26 +509,46 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): return latents + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - prompt: Union[str, List[str]], + prompt: Union[str, List[str]] = None, negative_prompt: Optional[Union[str, List[str]]] = None, height: int = 512, width: int = 768, - num_frames: int = 121, + num_frames: int = 25, num_inference_steps: int = 50, guidance_scale: float = 5.0, scheduler_scale: float = 10.0, - num_videos_per_prompt: int = 1, - generator: Optional[torch.Generator] = None, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, **kwargs, ): r""" @@ -318,9 +556,10 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): Args: prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the video generation. + The prompt or prompts to guide the video generation. If not defined, pass `prompt_embeds` instead. negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to avoid during video generation. + The prompt or prompts to avoid during video generation. If not defined, pass `negative_prompt_embeds` + instead. Ignored when not using guidance (`guidance_scale` < `1`). height (`int`, defaults to `512`): The height in pixels of the generated video. width (`int`, defaults to `768`): @@ -335,82 +574,109 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): Scale factor for the custom flow matching scheduler. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of videos to generate per prompt. - generator (`torch.Generator`, *optional*): + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A torch generator to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated video. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`KandinskyPipelineOutput`]. - callback_on_step_end (`Callable`, *optional*): + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): A function that is called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length for text encoding. Examples: Returns: [`~KandinskyPipelineOutput`] or `tuple`: If `return_dict` is `True`, [`KandinskyPipelineOutput`] is returned, otherwise a `tuple` is returned where - the first element is a list with the generated images and the second element is a list of `bool`s - indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + the first element is a list with the generated images. """ + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Reset embeddings dtype self.transformer.time_embeddings.reset_dtype() self.transformer.text_rope_embeddings.reset_dtype() self.transformer.visual_rope_embeddings.reset_dtype() - - dtype = self.transformer.dtype - if height % 16 != 0 or width % 16 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) - if isinstance(prompt, str): - batch_size = 1 - else: - batch_size = len(prompt) - - device = self._execution_device - do_classifier_free_guidance = guidance_scale > 1.0 - if num_frames % self.vae_scale_factor_temporal != 1: logger.warning( f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." ) num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 num_frames = max(num_frames, 1) - - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - text_embeds, negative_text_embeds, prompt_cu_seqlens, negative_cu_seqlens = self.encode_prompt( + self._guidance_scale = guidance_scale + self._interrupt = False + + device = self._execution_device + dtype = self.transformer.dtype + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, (list, tuple)) else prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds_dict, negative_prompt_embeds_dict, prompt_cu_seqlens, negative_cu_seqlens = self.encode_prompt( prompt=prompt, negative_prompt=negative_prompt, - do_classifier_free_guidance=do_classifier_free_guidance, + do_classifier_free_guidance=self.do_classifier_free_guidance, num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, device=device, + dtype=dtype, ) + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables num_channels_latents = 16 latents = self.prepare_latents( - batch_size=batch_size * num_videos_per_prompt, - num_channels_latents=16, - height=height, - width=width, - num_frames=num_frames, - visual_cond=self.transformer.visual_cond, - dtype=dtype, - device=device, - generator=generator, - latents=latents, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + dtype, + device, + generator, + latents, ) - - visual_cond = latents[:, :, :, :, 16:] + # 6. Prepare rope positions + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 visual_rope_pos = [ - torch.arange(num_frames // 4 + 1, device=device), - torch.arange(height // 8 // 2, device=device), - torch.arange(width // 8 // 2, device=device), + torch.arange(num_latent_frames, device=device), + torch.arange(height // self.vae_scale_factor_spatial // 2, device=device), + torch.arange(width // self.vae_scale_factor_spatial // 2, device=device), ] text_rope_pos = torch.arange(prompt_cu_seqlens[-1].item(), device=device) @@ -421,52 +687,72 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): else None ) + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): + if self.interrupt: + continue + timestep = t.unsqueeze(0).flatten() - with torch.autocast(device_type="cuda", dtype=dtype): - pred_velocity = self.transformer( - hidden_states=latents, - encoder_hidden_states=text_embeds["text_embeds"], - pooled_projections=text_embeds["pooled_embed"], - timestep=timestep, + + + # Predict noise residual + # with torch.autocast(device_type="cuda", dtype=dtype): + pred_velocity = self.transformer( + hidden_states=latents.to(dtype), + encoder_hidden_states=prompt_embeds_dict["text_embeds"].to(dtype), + pooled_projections=prompt_embeds_dict["pooled_embed"].to(dtype), + timestep=timestep.to(dtype), + visual_rope_pos=visual_rope_pos, + text_rope_pos=text_rope_pos, + scale_factor=(1, 2, 2), + sparse_params=None, + return_dict=True + ).sample + + if self.do_classifier_free_guidance and negative_prompt_embeds_dict is not None: + uncond_pred_velocity = self.transformer( + hidden_states=latents.to(dtype), + encoder_hidden_states=negative_prompt_embeds_dict["text_embeds"].to(dtype), + pooled_projections=negative_prompt_embeds_dict["pooled_embed"].to(dtype), + timestep=timestep.to(dtype), visual_rope_pos=visual_rope_pos, - text_rope_pos=text_rope_pos, - scale_factor=(1, 2, 2), + text_rope_pos=negative_text_rope_pos, + scale_factor=(1, 2, 2), sparse_params=None, return_dict=True ).sample - if guidance_scale > 1.0 and negative_text_embeds is not None: - uncond_pred_velocity = self.transformer( - hidden_states=latents, - encoder_hidden_states=negative_text_embeds["text_embeds"], - pooled_projections=negative_text_embeds["pooled_embed"], - timestep=timestep, - visual_rope_pos=visual_rope_pos, - text_rope_pos=negative_text_rope_pos, - scale_factor=(1, 2, 2), - sparse_params=None, - return_dict=True - ).sample - - pred_velocity = uncond_pred_velocity + guidance_scale * ( - pred_velocity - uncond_pred_velocity - ) + pred_velocity = uncond_pred_velocity + guidance_scale * ( + pred_velocity - uncond_pred_velocity + ) - latents[:, :, :, :, :16] = self.scheduler.step(pred_velocity, t, latents[:, :, :, :, :16], return_dict=False)[0] + # Compute previous sample + latents[:, :, :, :, :16] = self.scheduler.step( + pred_velocity, t, latents[:, :, :, :, :16], return_dict=False + )[0] if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, timestep, callback_kwargs) + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + latents = callback_outputs.pop("latents", latents) - + prompt_embeds_dict = callback_outputs.pop("prompt_embeds", prompt_embeds_dict) + negative_prompt_embeds_dict = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds_dict) + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - + + if XLA_AVAILABLE: + xm.mark_step() + + # 8. Post-processing latents = latents[:, :, :, :, :16] # 9. Decode latents to video @@ -477,26 +763,23 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): batch_size, num_videos_per_prompt, (num_frames - 1) // self.vae_scale_factor_temporal + 1, - height // 8, - width // 8, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, 16, ) video = video.permute(0, 1, 5, 2, 3, 4) # [batch, num_videos, channels, frames, height, width] - video = video.reshape(batch_size * num_videos_per_prompt, 16, (num_frames - 1) // self.vae_scale_factor_temporal + 1, height // 8, width // 8) + video = video.reshape( + batch_size * num_videos_per_prompt, + 16, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial + ) # Normalize and decode video = video / self.vae.config.scaling_factor video = self.vae.decode(video).sample - video = ((video.clamp(-1.0, 1.0) + 1.0) * 127.5).to(torch.uint8) - # Convert to output format - if output_type == "pil": - if num_frames == 1: - # Single image - video = [ToPILImage()(frame.squeeze(1)) for frame in video] - else: - # Video frames - video = [video[i] for i in range(video.shape[0])] - + video = self.video_processor.postprocess_video(video, output_type=output_type) else: video = latents