diff --git a/scripts/convert_sana_video_to_diffusers.py b/scripts/convert_sana_video_to_diffusers.py
index 4a0a878cc0..101fc42089 100644
--- a/scripts/convert_sana_video_to_diffusers.py
+++ b/scripts/convert_sana_video_to_diffusers.py
@@ -15,9 +15,9 @@ from diffusers import (
     AutoencoderKLWan,
     DPMSolverMultistepScheduler,
     FlowMatchEulerDiscreteScheduler,
+    SanaVideoCausalTransformer3DModel,
     SanaVideoPipeline,
     SanaVideoTransformer3DModel,
-    SanaVideoCausalTransformer3DModel,
     UniPCMultistepScheduler,
 )
 from diffusers.utils.import_utils import is_accelerate_available
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index cf567af802..f55287d9d2 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -969,8 +969,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         QwenImageTransformer2DModel,
         SanaControlNetModel,
         SanaTransformer2DModel,
-        SanaVideoTransformer3DModel,
         SanaVideoCausalTransformer3DModel,
+        SanaVideoTransformer3DModel,
         SD3ControlNetModel,
         SD3MultiControlNetModel,
         SD3Transformer2DModel,
@@ -1208,6 +1208,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         LDMTextToImagePipeline,
         LEditsPPPipelineStableDiffusion,
         LEditsPPPipelineStableDiffusionXL,
+        LongSanaVideoPipeline,
         LTXConditionPipeline,
         LTXImageToVideoPipeline,
         LTXLatentUpsamplePipeline,
@@ -1245,7 +1246,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         SanaSprintImg2ImgPipeline,
         SanaSprintPipeline,
         SanaVideoPipeline,
-        LongSanaVideoPipeline,
         SemanticStableDiffusionPipeline,
         ShapEImg2ImgPipeline,
         ShapEPipeline,
diff --git a/src/diffusers/models/transformers/transformer_sana_video.py b/src/diffusers/models/transformers/transformer_sana_video.py
index 0d94c61cc8..dc59df380d 100644
--- a/src/diffusers/models/transformers/transformer_sana_video.py
+++ b/src/diffusers/models/transformers/transformer_sana_video.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import inspect
 import math
 from typing import Any, Dict, Optional, Tuple, Union
 
@@ -765,7 +764,7 @@ class SanaVideoCausalTransformerBlock(nn.Module):
             attn_output, kv_cache = attn_result
         else:
             attn_output = attn_result
-        
+
         hidden_states = hidden_states + gate_msa * attn_output
 
         # 3. Cross Attention (no cache)
@@ -782,7 +781,7 @@ class SanaVideoCausalTransformerBlock(nn.Module):
         norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
 
         norm_hidden_states = norm_hidden_states.unflatten(1, (frames, height, width))
-        
+
         # Cached conv always supports kv_cache
         ff_result = self.ff(
             norm_hidden_states,
@@ -793,7 +792,7 @@ class SanaVideoCausalTransformerBlock(nn.Module):
             ff_output, kv_cache = ff_result
         else:
             ff_output = ff_result
-        
+
         ff_output = ff_output.flatten(1, 3)
         hidden_states = hidden_states + gate_mlp * ff_output
 
@@ -1248,7 +1247,7 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
             if kv_cache is not None:
                 logger.warning("KV cache is not supported with gradient checkpointing. Disabling KV cache.")
                 kv_cache = None
-            
+
             for index_block, block in enumerate(self.transformer_blocks):
                 hidden_states = self._gradient_checkpointing_func(
                     block,
@@ -1269,7 +1268,7 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
             for index_block, block in enumerate(self.transformer_blocks):
                 # Get kv_cache for this block if available
                 block_kv_cache = kv_cache[index_block] if kv_cache is not None else None
-                
+
                 block_result = block(
                     hidden_states,
                     attention_mask,
@@ -1283,7 +1282,7 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
                     save_kv_cache=save_kv_cache,
                     kv_cache=block_kv_cache,
                 )
-                
+
                 # Handle return value (could be tensor or tuple)
                 if isinstance(block_result, tuple):
                     hidden_states, updated_kv_cache = block_result
@@ -1291,7 +1290,7 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
                         kv_cache[index_block] = updated_kv_cache
                 else:
                     hidden_states = block_result
-                
+
                 if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
                     hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
 
diff --git a/src/diffusers/models/transformers/transformer_sana_video_causal.py b/src/diffusers/models/transformers/transformer_sana_video_causal.py
index 03d99aa2af..233c4b67b9 100644
--- a/src/diffusers/models/transformers/transformer_sana_video_causal.py
+++ b/src/diffusers/models/transformers/transformer_sana_video_causal.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import inspect
 import math
 from typing import Any, Dict, Optional, Tuple, Union
 
@@ -30,7 +29,7 @@ from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormSingle, RMSNorm
-from .transformer_sana_video import SanaVideoTransformerBlock
+
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -524,7 +523,7 @@ class SanaVideoCausalTransformerBlock(nn.Module):
             attn_output, kv_cache = attn_result
         else:
             attn_output = attn_result
-        
+
         hidden_states = hidden_states + gate_msa * attn_output
 
         # 3. Cross Attention (no cache)
@@ -541,7 +540,7 @@ class SanaVideoCausalTransformerBlock(nn.Module):
         norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
 
         norm_hidden_states = norm_hidden_states.unflatten(1, (frames, height, width))
-        
+
         # Cached conv always supports kv_cache
         ff_result = self.ff(
             norm_hidden_states,
@@ -552,7 +551,7 @@ class SanaVideoCausalTransformerBlock(nn.Module):
             ff_output, kv_cache = ff_result
         else:
             ff_output = ff_result
-        
+
         ff_output = ff_output.flatten(1, 3)
         hidden_states = hidden_states + gate_mlp * ff_output
 
@@ -756,7 +755,7 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
             if kv_cache is not None:
                 logger.warning("KV cache is not supported with gradient checkpointing. Disabling KV cache.")
                 kv_cache = None
-            
+
             for index_block, block in enumerate(self.transformer_blocks):
                 hidden_states = self._gradient_checkpointing_func(
                     block,
@@ -777,7 +776,7 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
             for index_block, block in enumerate(self.transformer_blocks):
                 # Get kv_cache for this block if available
                 block_kv_cache = kv_cache[index_block] if kv_cache is not None else None
-                
+
                 block_result = block(
                     hidden_states,
                     attention_mask,
@@ -791,7 +790,7 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
                     save_kv_cache=save_kv_cache,
                     kv_cache=block_kv_cache,
                 )
-                
+
                 # Handle return value (could be tensor or tuple)
                 if isinstance(block_result, tuple):
                     hidden_states, updated_kv_cache = block_result
@@ -799,7 +798,7 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
                         kv_cache[index_block] = updated_kv_cache
                 else:
                     hidden_states = block_result
-                
+
                 if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
                     hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
 
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index e745c5e22f..1163a6e4c0 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -757,7 +757,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         SanaSprintImg2ImgPipeline,
         SanaSprintPipeline,
     )
-    from .sana_video import SanaImageToVideoPipeline, SanaVideoPipeline, LongSanaVideoPipeline
+    from .sana_video import LongSanaVideoPipeline, SanaImageToVideoPipeline, SanaVideoPipeline
     from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
     from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
     from .stable_audio import StableAudioPipeline, StableAudioProjectionModel
diff --git a/src/diffusers/pipelines/sana_video/__init__.py b/src/diffusers/pipelines/sana_video/__init__.py
index 0a159f4565..7a60edcd61 100644
--- a/src/diffusers/pipelines/sana_video/__init__.py
+++ b/src/diffusers/pipelines/sana_video/__init__.py
@@ -34,9 +34,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     except OptionalDependencyNotAvailable:
         from ...utils.dummy_torch_and_transformers_objects import *
     else:
+        from .pipeline_longsana import LongSanaVideoPipeline
         from .pipeline_sana_video import SanaVideoPipeline
         from .pipeline_sana_video_i2v import SanaImageToVideoPipeline
-        from .pipeline_longsana import LongSanaVideoPipeline
 
 else:
     import sys
diff --git a/src/diffusers/pipelines/sana_video/pipeline_longsana.py b/src/diffusers/pipelines/sana_video/pipeline_longsana.py
index 07d5d410b5..641849a81b 100644
--- a/src/diffusers/pipelines/sana_video/pipeline_longsana.py
+++ b/src/diffusers/pipelines/sana_video/pipeline_longsana.py
@@ -209,7 +209,7 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         self.vae_scale_factor = self.vae_scale_factor_spatial
 
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
-        
+
         # LongSana specific parameters
         self.base_chunk_frames = base_chunk_frames
         self.num_cached_blocks = num_cached_blocks
@@ -680,11 +680,11 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         original_dtype = flow_pred.dtype
         flow_pred_f64 = flow_pred.double()
         xt_f64 = xt.double()
-        
+
         # Get sigma_t from scheduler
         sigmas = self.scheduler.sigmas.double().to(flow_pred.device)
         timesteps_sched = self.scheduler.timesteps.double().to(flow_pred.device)
-        
+
         # Find closest timestep index
         # timestep is scalar or [B], we need to match it against scheduler timesteps
         if timestep.dim() == 0:
@@ -692,10 +692,10 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         timestep_f64 = timestep.double()
         timestep_id = torch.argmin((timesteps_sched.unsqueeze(0) - timestep_f64.unsqueeze(1)).abs(), dim=1)
         sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1)
-        
+
         # x_0 = x_t - sigma_t * flow_pred
         x0_pred = xt_f64 - sigma_t * flow_pred_f64
-        
+
         return x0_pred.to(original_dtype)
 
     def _create_autoregressive_segments(self, total_frames: int, base_chunk_frames: int) -> List[int]:
@@ -751,16 +751,16 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         """
         if chunk_idx == 0:
             return kv_cache[0]
-        
+
         cur_kv_cache = kv_cache[chunk_idx]
         for block_id in range(num_blocks):
             # Copy temporal cache from previous chunk
             cur_kv_cache[block_id][2] = kv_cache[chunk_idx - 1][block_id][2]
-            
+
             # Accumulate spatial KV cache from previous chunks
             cum_vk, cum_k_sum = None, None
             start_chunk_idx = chunk_idx - self.num_cached_blocks if self.num_cached_blocks > 0 else 0
-            
+
             for i in range(start_chunk_idx, chunk_idx):
                 prev = kv_cache[i][block_id]
                 if prev[0] is not None and prev[1] is not None:
@@ -770,13 +770,13 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
                 else:
                     cum_vk += prev[0]
                     cum_k_sum += prev[1]
-            
+
             if chunk_idx > 0:
                 assert cum_vk is not None and cum_k_sum is not None, "KV cache accumulation failed"
-            
+
             cur_kv_cache[block_id][0] = cum_vk
             cur_kv_cache[block_id][1] = cum_k_sum
-        
+
         return cur_kv_cache
 
     def _get_num_transformer_blocks(self) -> int:
@@ -1053,51 +1053,51 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         if denoising_step_list is None:
             # Use the standard timesteps from the scheduler
             denoising_step_list = timesteps.cpu().tolist()
-        
+
         device = latents.device
         batch_size_latents, _, total_frames, height_latent, width_latent = latents.shape
-        
+
         # Create autoregressive segments
         chunk_indices = self._create_autoregressive_segments(total_frames, self.base_chunk_frames)
         num_chunks = len(chunk_indices) - 1
-        
+
         # Get number of transformer blocks
         num_blocks = self._get_num_transformer_blocks()
-        
+
         # Initialize KV cache for all chunks
         kv_cache = self._initialize_kv_cache(num_chunks, num_blocks)
-        
+
         # Output tensor to store denoised results
         output = torch.zeros_like(latents)
-        
+
         transformer_dtype = self.transformer.dtype
-        
+
         # Process each chunk
         for chunk_idx in range(num_chunks):
             start_f = chunk_indices[chunk_idx]
             end_f = chunk_indices[chunk_idx + 1]
-            
+
             # Extract chunk latents
             local_latent = latents[:, :, start_f:end_f].clone()
-            
+
             # Accumulate KV cache from previous chunks
             chunk_kv_cache = self._accumulate_kv_cache(kv_cache, chunk_idx, num_blocks)
-            
+
             # Multi-step denoising for this chunk
             with self.progress_bar(total=len(denoising_step_list)) as progress_bar:
                 for step_idx, current_timestep in enumerate(denoising_step_list):
                     if self.interrupt:
                         continue
-                    
+
                     # Prepare model input
                     latent_model_input = (
                         torch.cat([local_latent] * 2) if self.do_classifier_free_guidance else local_latent
                     )
-                    
+
                     # Create timestep tensor
                     t = torch.tensor([current_timestep], device=device, dtype=torch.long)
                     timestep = t.expand(latent_model_input.shape[0])
-                    
+
                     # Forward pass through transformer with KV cache
                     transformer_kwargs = {
                         "encoder_hidden_states": prompt_embeds.to(dtype=transformer_dtype),
@@ -1107,16 +1107,16 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
                         "save_kv_cache": False,  # Don't save during denoising steps
                         "kv_cache": chunk_kv_cache,  # Pass accumulated KV cache
                     }
-                    
+
                     if self.attention_kwargs is not None:
                         transformer_kwargs["attention_kwargs"] = self.attention_kwargs
-                    
+
                     # Predict flow
                     model_output = self.transformer(
                         latent_model_input.to(dtype=transformer_dtype),
                         **transformer_kwargs,
                     )
-                    
+
                     # Handle different output formats
                     if isinstance(model_output, tuple):
                         if len(model_output) == 2:
@@ -1128,23 +1128,23 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
                             flow_pred = model_output[0]
                     else:
                         flow_pred = model_output
-                    
+
                     flow_pred = flow_pred.float()
-                    
+
                     # Perform guidance on flow prediction
                     if self.do_classifier_free_guidance:
                         flow_pred_uncond, flow_pred_text = flow_pred.chunk(2)
                         flow_pred = flow_pred_uncond + guidance_scale * (flow_pred_text - flow_pred_uncond)
-                    
+
                     # Handle learned sigma
                     if self.transformer.config.out_channels // 2 == latent_channels:
                         flow_pred = flow_pred.chunk(2, dim=1)[0]
-                    
+
                     # Convert flow prediction to x0 prediction
                     # Need to rearrange dimensions: b c f h w -> b f c h w for conversion
                     flow_pred_bfchw = rearrange(flow_pred, "b c f h w -> b f c h w")
                     local_latent_bfchw = rearrange(local_latent, "b c f h w -> b f c h w")
-                    
+
                     # Convert to x0 (flatten batch and frames for conversion)
                     pred_x0_flat = self._convert_flow_pred_to_x0(
                         flow_pred=flow_pred_bfchw.flatten(0, 1),
@@ -1153,19 +1153,19 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
                     )
                     pred_x0_bfchw = pred_x0_flat.unflatten(0, (flow_pred_bfchw.shape[0], flow_pred_bfchw.shape[1]))
                     pred_x0 = rearrange(pred_x0_bfchw, "b f c h w -> b c f h w")
-                    
+
                     # Denoise: x_t -> x_0, then add noise for next timestep
                     if step_idx < len(denoising_step_list) - 1:
                         # Not the last step, add noise for next timestep
                         next_timestep = denoising_step_list[step_idx + 1]
                         next_t = torch.tensor([next_timestep], device=device, dtype=torch.float32)
-                        
+
                         # Rearrange for scale_noise: b c f h w -> b f c h w
                         pred_x0_for_noise = rearrange(pred_x0, "b c f h w -> b f c h w")
                         noise = randn_tensor(
                             pred_x0_for_noise.shape, generator=generator, device=device, dtype=pred_x0_for_noise.dtype
                         )
-                        
+
                         # Add noise using scale_noise: flatten batch and frames
                         # scale_noise formula: sigma * noise + (1 - sigma) * sample
                         local_latent_flat = self.scheduler.scale_noise(
@@ -1178,19 +1178,19 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
                     else:
                         # Last step, use x_0 as final result
                         local_latent = pred_x0
-                    
+
                     progress_bar.update()
-                    
+
                     if XLA_AVAILABLE:
                         xm.mark_step()
-            
+
             # Store the denoised chunk
             output[:, :, start_f:end_f] = local_latent
-            
+
             # Update KV cache for this chunk by running forward pass at timestep 0
             latent_for_cache = output[:, :, start_f:end_f]
             timestep_zero = torch.zeros(latent_for_cache.shape[0], device=device, dtype=torch.long)
-            
+
             cache_kwargs = {
                 "encoder_hidden_states": prompt_embeds.to(dtype=transformer_dtype),
                 "encoder_attention_mask": prompt_attention_mask,
@@ -1199,25 +1199,25 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
                 "save_kv_cache": True,  # Enable saving during cache update
                 "kv_cache": chunk_kv_cache,
             }
-            
+
            if self.attention_kwargs is not None:
                 cache_kwargs["attention_kwargs"] = self.attention_kwargs
-            
+
             # Forward pass to update KV cache
             cache_output = self.transformer(
                 latent_for_cache.to(dtype=transformer_dtype),
                 **cache_kwargs,
             )
-            
+
             # Extract updated KV cache if returned
             if isinstance(cache_output, tuple) and len(cache_output) == 2:
                 _, updated_kv_cache = cache_output
                 if updated_kv_cache is not None:
                     kv_cache[chunk_idx] = updated_kv_cache
-            
+
             if XLA_AVAILABLE:
                 xm.mark_step()
-        
+
         latents = output
 
         if output_type == "latent":
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 928f0b9774..36f8767ffe 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -1353,6 +1353,21 @@ class SanaTransformer2DModel(metaclass=DummyObject):
         requires_backends(cls, ["torch"])
 
 
+class SanaVideoCausalTransformer3DModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class SanaVideoTransformer3DModel(metaclass=DummyObject):
     _backends = ["torch"]
 
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index e6cf26a125..cfd95c9190 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -1697,6 +1697,21 @@ class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject):
         requires_backends(cls, ["torch", "transformers"])
 
 
+class LongSanaVideoPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
 class LTXConditionPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
 
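
Usage sketch (reviewer note, not part of the patch): a minimal end-to-end call of the LongSanaVideoPipeline this diff wires up. The checkpoint id, prompt, and generation parameters below are hypothetical placeholders, and the call signature is assumed to follow the existing SanaVideoPipeline conventions; check it against the final pipeline before relying on it.

    import torch

    from diffusers import LongSanaVideoPipeline
    from diffusers.utils import export_to_video

    # Hypothetical checkpoint id; substitute a real SANA-Video causal checkpoint.
    pipe = LongSanaVideoPipeline.from_pretrained(
        "org/sana-video-causal",
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    # Long videos are denoised chunk by chunk: each chunk is denoised against the
    # accumulated KV cache of earlier chunks, then a timestep-0 forward pass
    # saves its own cache for the chunks that follow.
    video = pipe(
        prompt="a red panda climbing a bamboo stalk, cinematic lighting",
        num_frames=193,  # several chunks of base_chunk_frames each
        num_inference_steps=20,
    ).frames[0]

    export_to_video(video, "longsana.mp4", fps=16)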