mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

make style;

Co-authored-by: Yuyang Zhao <43061147+HeliosZhao@users.noreply.github.com>
Author: junsong
Date: 2025-11-26 07:44:28 -08:00
Parent: 800c3cc28d
Commit: 5faf4e93f7
9 changed files with 95 additions and 67 deletions

View File

@@ -15,9 +15,9 @@ from diffusers import (
AutoencoderKLWan,
DPMSolverMultistepScheduler,
FlowMatchEulerDiscreteScheduler,
SanaVideoCausalTransformer3DModel,
SanaVideoPipeline,
SanaVideoTransformer3DModel,
SanaVideoCausalTransformer3DModel,
UniPCMultistepScheduler,
)
from diffusers.utils.import_utils import is_accelerate_available

View File

@@ -969,8 +969,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
QwenImageTransformer2DModel,
SanaControlNetModel,
SanaTransformer2DModel,
SanaVideoTransformer3DModel,
SanaVideoCausalTransformer3DModel,
SanaVideoTransformer3DModel,
SD3ControlNetModel,
SD3MultiControlNetModel,
SD3Transformer2DModel,
@@ -1208,6 +1208,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LDMTextToImagePipeline,
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
LongSanaVideoPipeline,
LTXConditionPipeline,
LTXImageToVideoPipeline,
LTXLatentUpsamplePipeline,
@@ -1245,7 +1246,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
SanaSprintImg2ImgPipeline,
SanaSprintPipeline,
SanaVideoPipeline,
LongSanaVideoPipeline,
SemanticStableDiffusionPipeline,
ShapEImg2ImgPipeline,
ShapEPipeline,

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import math
from typing import Any, Dict, Optional, Tuple, Union
@@ -765,7 +764,7 @@ class SanaVideoCausalTransformerBlock(nn.Module):
attn_output, kv_cache = attn_result
else:
attn_output = attn_result
hidden_states = hidden_states + gate_msa * attn_output
# 3. Cross Attention (no cache)
@@ -782,7 +781,7 @@ class SanaVideoCausalTransformerBlock(nn.Module):
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
norm_hidden_states = norm_hidden_states.unflatten(1, (frames, height, width))
# Cached conv always supports kv_cache
ff_result = self.ff(
norm_hidden_states,
@@ -793,7 +792,7 @@ class SanaVideoCausalTransformerBlock(nn.Module):
ff_output, kv_cache = ff_result
else:
ff_output = ff_result
ff_output = ff_output.flatten(1, 3)
hidden_states = hidden_states + gate_mlp * ff_output
@@ -1248,7 +1247,7 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
if kv_cache is not None:
logger.warning("KV cache is not supported with gradient checkpointing. Disabling KV cache.")
kv_cache = None
for index_block, block in enumerate(self.transformer_blocks):
hidden_states = self._gradient_checkpointing_func(
block,
@@ -1269,7 +1268,7 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
for index_block, block in enumerate(self.transformer_blocks):
# Get kv_cache for this block if available
block_kv_cache = kv_cache[index_block] if kv_cache is not None else None
block_result = block(
hidden_states,
attention_mask,
@@ -1283,7 +1282,7 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
save_kv_cache=save_kv_cache,
kv_cache=block_kv_cache,
)
# Handle return value (could be tensor or tuple)
if isinstance(block_result, tuple):
hidden_states, updated_kv_cache = block_result
@@ -1291,7 +1290,7 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
kv_cache[index_block] = updated_kv_cache
else:
hidden_states = block_result
if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
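The blocks above return either a bare tensor or a (hidden_states, kv_cache) tuple depending on whether caching is active, and the model branches on isinstance(block_result, tuple). A minimal standalone sketch of that calling convention, with illustrative names only (not the actual diffusers API):

from typing import Optional, Tuple, Union

import torch

def run_block(x: torch.Tensor, kv_cache: Optional[dict] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
    # Stand-in for the real attention + feed-forward work of a transformer block.
    y = x + 1.0
    if kv_cache is None:
        return y  # existing call sites keep receiving a plain tensor
    kv_cache["last_hidden"] = y.detach()  # caches are only materialized when requested
    return y, kv_cache

hidden = torch.zeros(1, 4)
result = run_block(hidden, kv_cache={})
if isinstance(result, tuple):  # mirrors the "Handle return value" branch in the diff
    hidden, cache = result
else:
    hidden = result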

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import math
from typing import Any, Dict, Optional, Tuple, Union
@@ -30,7 +29,7 @@ from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle, RMSNorm
from .transformer_sana_video import SanaVideoTransformerBlock
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -524,7 +523,7 @@ class SanaVideoCausalTransformerBlock(nn.Module):
attn_output, kv_cache = attn_result
else:
attn_output = attn_result
hidden_states = hidden_states + gate_msa * attn_output
# 3. Cross Attention (no cache)
@@ -541,7 +540,7 @@ class SanaVideoCausalTransformerBlock(nn.Module):
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
norm_hidden_states = norm_hidden_states.unflatten(1, (frames, height, width))
# Cached conv always supports kv_cache
ff_result = self.ff(
norm_hidden_states,
@@ -552,7 +551,7 @@ class SanaVideoCausalTransformerBlock(nn.Module):
ff_output, kv_cache = ff_result
else:
ff_output = ff_result
ff_output = ff_output.flatten(1, 3)
hidden_states = hidden_states + gate_mlp * ff_output
@@ -756,7 +755,7 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
if kv_cache is not None:
logger.warning("KV cache is not supported with gradient checkpointing. Disabling KV cache.")
kv_cache = None
for index_block, block in enumerate(self.transformer_blocks):
hidden_states = self._gradient_checkpointing_func(
block,
@@ -777,7 +776,7 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
for index_block, block in enumerate(self.transformer_blocks):
# Get kv_cache for this block if available
block_kv_cache = kv_cache[index_block] if kv_cache is not None else None
block_result = block(
hidden_states,
attention_mask,
@@ -791,7 +790,7 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
save_kv_cache=save_kv_cache,
kv_cache=block_kv_cache,
)
# Handle return value (could be tensor or tuple)
if isinstance(block_result, tuple):
hidden_states, updated_kv_cache = block_result
@@ -799,7 +798,7 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
kv_cache[index_block] = updated_kv_cache
else:
hidden_states = block_result
if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
hidden_states = hidden_states + controlnet_block_samples[index_block - 1]

View File

@@ -757,7 +757,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
SanaSprintImg2ImgPipeline,
SanaSprintPipeline,
)
from .sana_video import SanaImageToVideoPipeline, SanaVideoPipeline, LongSanaVideoPipeline
from .sana_video import LongSanaVideoPipeline, SanaImageToVideoPipeline, SanaVideoPipeline
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
from .stable_audio import StableAudioPipeline, StableAudioProjectionModel

View File

@@ -34,9 +34,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_longsana import LongSanaVideoPipeline
from .pipeline_sana_video import SanaVideoPipeline
from .pipeline_sana_video_i2v import SanaImageToVideoPipeline
from .pipeline_longsana import LongSanaVideoPipeline
else:
import sys

View File

@@ -209,7 +209,7 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
self.vae_scale_factor = self.vae_scale_factor_spatial
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
# LongSana specific parameters
self.base_chunk_frames = base_chunk_frames
self.num_cached_blocks = num_cached_blocks
@@ -680,11 +680,11 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
original_dtype = flow_pred.dtype
flow_pred_f64 = flow_pred.double()
xt_f64 = xt.double()
# Get sigma_t from scheduler
sigmas = self.scheduler.sigmas.double().to(flow_pred.device)
timesteps_sched = self.scheduler.timesteps.double().to(flow_pred.device)
# Find closest timestep index
# timestep is scalar or [B], we need to match it against scheduler timesteps
if timestep.dim() == 0:
@@ -692,10 +692,10 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
timestep_f64 = timestep.double()
timestep_id = torch.argmin((timesteps_sched.unsqueeze(0) - timestep_f64.unsqueeze(1)).abs(), dim=1)
sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1)
# x_0 = x_t - sigma_t * flow_pred
x0_pred = xt_f64 - sigma_t * flow_pred_f64
return x0_pred.to(original_dtype)
def _create_autoregressive_segments(self, total_frames: int, base_chunk_frames: int) -> List[int]:
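The hunk above turns a flow-matching velocity prediction into an x0 estimate via x_0 = x_t - sigma_t * flow_pred, with sigma_t looked up from the scheduler by nearest timestep and the arithmetic done in float64. A self-contained sketch of just that conversion (the scheduler lookup is omitted and the function name is an assumption), assuming the velocity convention flow_pred = noise - x_0 that the formula in the diff implies:

import torch

def flow_to_x0(flow_pred: torch.Tensor, xt: torch.Tensor, sigma_t: float) -> torch.Tensor:
    # With x_t = (1 - sigma_t) * x_0 + sigma_t * noise and flow_pred = noise - x_0,
    # rearranging gives x_0 = x_t - sigma_t * flow_pred.
    original_dtype = flow_pred.dtype
    x0 = xt.double() - sigma_t * flow_pred.double()  # float64 for stability, as in the pipeline
    return x0.to(original_dtype)

# Numeric check: x_0 = 0.5, noise = -1.0, sigma_t = 0.25 gives x_t = 0.125 and flow = -1.5,
# and the conversion recovers x_0 = 0.125 - 0.25 * (-1.5) = 0.5.
print(flow_to_x0(torch.tensor([-1.5]), torch.tensor([0.125]), 0.25))  # tensor([0.5000])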
@@ -751,16 +751,16 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
"""
if chunk_idx == 0:
return kv_cache[0]
cur_kv_cache = kv_cache[chunk_idx]
for block_id in range(num_blocks):
# Copy temporal cache from previous chunk
cur_kv_cache[block_id][2] = kv_cache[chunk_idx - 1][block_id][2]
# Accumulate spatial KV cache from previous chunks
cum_vk, cum_k_sum = None, None
start_chunk_idx = chunk_idx - self.num_cached_blocks if self.num_cached_blocks > 0 else 0
for i in range(start_chunk_idx, chunk_idx):
prev = kv_cache[i][block_id]
if prev[0] is not None and prev[1] is not None:
@@ -770,13 +770,13 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
else:
cum_vk += prev[0]
cum_k_sum += prev[1]
if chunk_idx > 0:
assert cum_vk is not None and cum_k_sum is not None, "KV cache accumulation failed"
cur_kv_cache[block_id][0] = cum_vk
cur_kv_cache[block_id][1] = cum_k_sum
return cur_kv_cache
def _get_num_transformer_blocks(self) -> int:
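The _accumulate_kv_cache hunk above assembles what a new chunk attends to: the temporal cache (slot 2 of each block's cache) is copied from the immediately preceding chunk, while the spatial entries (slots 0 and 1) are summed over a window of earlier chunks controlled by num_cached_blocks. A simplified sketch of that accumulation, assuming the per-block layout [vk, k_sum, temporal] implied by the indices in the diff:

from typing import List, Optional

import torch

def accumulate_kv_cache(
    kv_cache: List[List[List[Optional[torch.Tensor]]]],  # indexed as [chunk][block][vk, k_sum, temporal]
    chunk_idx: int,
    num_blocks: int,
    num_cached_chunks: int,
) -> List[List[Optional[torch.Tensor]]]:
    if chunk_idx == 0:
        return kv_cache[0]  # the first chunk has nothing to inherit

    cur = kv_cache[chunk_idx]
    # Window of earlier chunks to sum over (0 means "all previous chunks"); clamped here as a guard.
    start = max(chunk_idx - num_cached_chunks, 0) if num_cached_chunks > 0 else 0
    for block_id in range(num_blocks):
        # Temporal cache comes only from the immediately preceding chunk.
        cur[block_id][2] = kv_cache[chunk_idx - 1][block_id][2]

        # Spatial caches are accumulated by summation across the window.
        cum_vk, cum_k_sum = None, None
        for i in range(start, chunk_idx):
            prev = kv_cache[i][block_id]
            if prev[0] is None or prev[1] is None:
                continue
            cum_vk = prev[0].clone() if cum_vk is None else cum_vk + prev[0]
            cum_k_sum = prev[1].clone() if cum_k_sum is None else cum_k_sum + prev[1]
        cur[block_id][0] = cum_vk
        cur[block_id][1] = cum_k_sum
    return cur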
@@ -1053,51 +1053,51 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
if denoising_step_list is None:
# Use the standard timesteps from the scheduler
denoising_step_list = timesteps.cpu().tolist()
device = latents.device
batch_size_latents, _, total_frames, height_latent, width_latent = latents.shape
# Create autoregressive segments
chunk_indices = self._create_autoregressive_segments(total_frames, self.base_chunk_frames)
num_chunks = len(chunk_indices) - 1
# Get number of transformer blocks
num_blocks = self._get_num_transformer_blocks()
# Initialize KV cache for all chunks
kv_cache = self._initialize_kv_cache(num_chunks, num_blocks)
# Output tensor to store denoised results
output = torch.zeros_like(latents)
transformer_dtype = self.transformer.dtype
# Process each chunk
for chunk_idx in range(num_chunks):
start_f = chunk_indices[chunk_idx]
end_f = chunk_indices[chunk_idx + 1]
# Extract chunk latents
local_latent = latents[:, :, start_f:end_f].clone()
# Accumulate KV cache from previous chunks
chunk_kv_cache = self._accumulate_kv_cache(kv_cache, chunk_idx, num_blocks)
# Multi-step denoising for this chunk
with self.progress_bar(total=len(denoising_step_list)) as progress_bar:
for step_idx, current_timestep in enumerate(denoising_step_list):
if self.interrupt:
continue
# Prepare model input
latent_model_input = (
torch.cat([local_latent] * 2) if self.do_classifier_free_guidance else local_latent
)
# Create timestep tensor
t = torch.tensor([current_timestep], device=device, dtype=torch.long)
timestep = t.expand(latent_model_input.shape[0])
# Forward pass through transformer with KV cache
transformer_kwargs = {
"encoder_hidden_states": prompt_embeds.to(dtype=transformer_dtype),
@@ -1107,16 +1107,16 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
"save_kv_cache": False, # Don't save during denoising steps
"kv_cache": chunk_kv_cache, # Pass accumulated KV cache
}
if self.attention_kwargs is not None:
transformer_kwargs["attention_kwargs"] = self.attention_kwargs
# Predict flow
model_output = self.transformer(
latent_model_input.to(dtype=transformer_dtype),
**transformer_kwargs,
)
# Handle different output formats
if isinstance(model_output, tuple):
if len(model_output) == 2:
@@ -1128,23 +1128,23 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
flow_pred = model_output[0]
else:
flow_pred = model_output
flow_pred = flow_pred.float()
# Perform guidance on flow prediction
if self.do_classifier_free_guidance:
flow_pred_uncond, flow_pred_text = flow_pred.chunk(2)
flow_pred = flow_pred_uncond + guidance_scale * (flow_pred_text - flow_pred_uncond)
# Handle learned sigma
if self.transformer.config.out_channels // 2 == latent_channels:
flow_pred = flow_pred.chunk(2, dim=1)[0]
# Convert flow prediction to x0 prediction
# Need to rearrange dimensions: b c f h w -> b f c h w for conversion
flow_pred_bfchw = rearrange(flow_pred, "b c f h w -> b f c h w")
local_latent_bfchw = rearrange(local_latent, "b c f h w -> b f c h w")
# Convert to x0 (flatten batch and frames for conversion)
pred_x0_flat = self._convert_flow_pred_to_x0(
flow_pred=flow_pred_bfchw.flatten(0, 1),
@@ -1153,19 +1153,19 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
)
pred_x0_bfchw = pred_x0_flat.unflatten(0, (flow_pred_bfchw.shape[0], flow_pred_bfchw.shape[1]))
pred_x0 = rearrange(pred_x0_bfchw, "b f c h w -> b c f h w")
# Denoise: x_t -> x_0, then add noise for next timestep
if step_idx < len(denoising_step_list) - 1:
# Not the last step, add noise for next timestep
next_timestep = denoising_step_list[step_idx + 1]
next_t = torch.tensor([next_timestep], device=device, dtype=torch.float32)
# Rearrange for scale_noise: b c f h w -> b f c h w
pred_x0_for_noise = rearrange(pred_x0, "b c f h w -> b f c h w")
noise = randn_tensor(
pred_x0_for_noise.shape, generator=generator, device=device, dtype=pred_x0_for_noise.dtype
)
# Add noise using scale_noise: flatten batch and frames
# scale_noise formula: sigma * noise + (1 - sigma) * sample
local_latent_flat = self.scheduler.scale_noise(
@@ -1178,19 +1178,19 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
else:
# Last step, use x_0 as final result
local_latent = pred_x0
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
# Store the denoised chunk
output[:, :, start_f:end_f] = local_latent
# Update KV cache for this chunk by running forward pass at timestep 0
latent_for_cache = output[:, :, start_f:end_f]
timestep_zero = torch.zeros(latent_for_cache.shape[0], device=device, dtype=torch.long)
cache_kwargs = {
"encoder_hidden_states": prompt_embeds.to(dtype=transformer_dtype),
"encoder_attention_mask": prompt_attention_mask,
@@ -1199,25 +1199,25 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
"save_kv_cache": True, # Enable saving during cache update
"kv_cache": chunk_kv_cache,
}
if self.attention_kwargs is not None:
cache_kwargs["attention_kwargs"] = self.attention_kwargs
# Forward pass to update KV cache
cache_output = self.transformer(
latent_for_cache.to(dtype=transformer_dtype),
**cache_kwargs,
)
# Extract updated KV cache if returned
if isinstance(cache_output, tuple) and len(cache_output) == 2:
_, updated_kv_cache = cache_output
if updated_kv_cache is not None:
kv_cache[chunk_idx] = updated_kv_cache
if XLA_AVAILABLE:
xm.mark_step()
latents = output
if output_type == "latent":
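Between denoising steps the loop above re-noises the x0 prediction to the next timestep's level with the scheduler's scale_noise, whose formula the comment in the diff gives as sigma * noise + (1 - sigma) * sample. A minimal sketch of that single step (sigma lookup and the batch/frame flattening are omitted; names are assumptions):

from typing import Optional

import torch

def renoise_for_next_step(pred_x0: torch.Tensor, sigma_next: float, generator: Optional[torch.Generator] = None) -> torch.Tensor:
    # Reproduces the scale_noise formula noted in the diff for a single scalar sigma.
    noise = torch.randn(pred_x0.shape, generator=generator, device=pred_x0.device, dtype=pred_x0.dtype)
    return sigma_next * noise + (1.0 - sigma_next) * pred_x0

# On the final denoising step the pipeline skips this and keeps pred_x0 as the result.
x_next = renoise_for_next_step(torch.zeros(1, 16, 5, 8, 8), sigma_next=0.6)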

View File

@@ -1353,6 +1353,21 @@ class SanaTransformer2DModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class SanaVideoCausalTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class SanaVideoTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -1697,6 +1697,21 @@ class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class LongSanaVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class LTXConditionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
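These dummy classes let importing diffusers succeed without torch or transformers installed, while deferring a clear ImportError to the moment a guarded object is actually constructed. A simplified, hedged sketch of that guard (diffusers' real requires_backends and DummyObject are more elaborate):

import importlib.util

def requires_backends(obj, backends):
    # Raise a readable ImportError if any optional dependency is missing.
    missing = [name for name in backends if importlib.util.find_spec(name) is None]
    if missing:
        obj_name = obj.__name__ if isinstance(obj, type) else type(obj).__name__
        raise ImportError(f"{obj_name} requires the following backends: {', '.join(missing)}")

class PlaceholderPipeline:
    # Stand-in for one of the dummy classes above; construction fails fast when backends are absent.
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, self._backends)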