Mirror of https://github.com/huggingface/diffusers.git

Commit: make style;
Co-authored-by: Yuyang Zhao <43061147+HeliosZhao@users.noreply.github.com>
@@ -15,9 +15,9 @@ from diffusers import (
     AutoencoderKLWan,
     DPMSolverMultistepScheduler,
     FlowMatchEulerDiscreteScheduler,
+    SanaVideoCausalTransformer3DModel,
     SanaVideoPipeline,
     SanaVideoTransformer3DModel,
-    SanaVideoCausalTransformer3DModel,
     UniPCMultistepScheduler,
 )
 from diffusers.utils.import_utils import is_accelerate_available
@@ -969,8 +969,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         QwenImageTransformer2DModel,
         SanaControlNetModel,
         SanaTransformer2DModel,
-        SanaVideoTransformer3DModel,
         SanaVideoCausalTransformer3DModel,
+        SanaVideoTransformer3DModel,
         SD3ControlNetModel,
         SD3MultiControlNetModel,
         SD3Transformer2DModel,
@@ -1208,6 +1208,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         LDMTextToImagePipeline,
         LEditsPPPipelineStableDiffusion,
         LEditsPPPipelineStableDiffusionXL,
+        LongSanaVideoPipeline,
         LTXConditionPipeline,
         LTXImageToVideoPipeline,
         LTXLatentUpsamplePipeline,
@@ -1245,7 +1246,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         SanaSprintImg2ImgPipeline,
         SanaSprintPipeline,
         SanaVideoPipeline,
-        LongSanaVideoPipeline,
         SemanticStableDiffusionPipeline,
         ShapEImg2ImgPipeline,
         ShapEPipeline,
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import inspect
 import math
 from typing import Any, Dict, Optional, Tuple, Union
@@ -765,7 +764,7 @@ class SanaVideoCausalTransformerBlock(nn.Module):
             attn_output, kv_cache = attn_result
         else:
             attn_output = attn_result

         hidden_states = hidden_states + gate_msa * attn_output

         # 3. Cross Attention (no cache)
@@ -782,7 +781,7 @@ class SanaVideoCausalTransformerBlock(nn.Module):
         norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

         norm_hidden_states = norm_hidden_states.unflatten(1, (frames, height, width))

         # Cached conv always supports kv_cache
         ff_result = self.ff(
             norm_hidden_states,
@@ -793,7 +792,7 @@ class SanaVideoCausalTransformerBlock(nn.Module):
             ff_output, kv_cache = ff_result
         else:
             ff_output = ff_result

         ff_output = ff_output.flatten(1, 3)
         hidden_states = hidden_states + gate_mlp * ff_output
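Both the attention and feed-forward calls above return either a bare tensor or a (tensor, kv_cache) tuple, depending on whether caching is active. A small helper expressing that convention (illustrative only; the diff unpacks inline rather than using such a helper):

from typing import Any, Optional, Tuple

import torch


def unpack_with_cache(result: Any) -> Tuple[torch.Tensor, Optional[Any]]:
    # Normalize a module output that is either `tensor` or `(tensor, cache)`,
    # so the caller can treat cached and cache-free paths uniformly.
    if isinstance(result, tuple):
        output, cache = result
        return output, cache
    return result, None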
@@ -1248,7 +1247,7 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin
             if kv_cache is not None:
                 logger.warning("KV cache is not supported with gradient checkpointing. Disabling KV cache.")
                 kv_cache = None

             for index_block, block in enumerate(self.transformer_blocks):
                 hidden_states = self._gradient_checkpointing_func(
                     block,
@@ -1269,7 +1268,7 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin
             for index_block, block in enumerate(self.transformer_blocks):
                 # Get kv_cache for this block if available
                 block_kv_cache = kv_cache[index_block] if kv_cache is not None else None

                 block_result = block(
                     hidden_states,
                     attention_mask,
@@ -1283,7 +1282,7 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin
                     save_kv_cache=save_kv_cache,
                     kv_cache=block_kv_cache,
                 )

                 # Handle return value (could be tensor or tuple)
                 if isinstance(block_result, tuple):
                     hidden_states, updated_kv_cache = block_result
@@ -1291,7 +1290,7 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin
                         kv_cache[index_block] = updated_kv_cache
                 else:
                     hidden_states = block_result

                 if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
                     hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
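The warning above exists because a checkpointed block is re-executed during the backward pass, so any cache it wrote in the forward pass would be written a second time with recomputed inputs. A generic sketch of the same guard using torch.utils.checkpoint (the `block` call signature here is an assumption, not the model's exact one):

import torch
from torch.utils.checkpoint import checkpoint


def run_block(block, hidden_states, kv_cache=None, use_checkpointing=False):
    if use_checkpointing:
        if kv_cache is not None:
            # Recomputation would replay every cache write, corrupting the cache.
            kv_cache = None
        # Non-reentrant checkpointing is the mode PyTorch recommends.
        return checkpoint(block, hidden_states, use_reentrant=False)
    return block(hidden_states, kv_cache=kv_cache)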
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import inspect
 import math
 from typing import Any, Dict, Optional, Tuple, Union
@@ -30,7 +29,7 @@ from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormSingle, RMSNorm
 from .transformer_sana_video import SanaVideoTransformerBlock


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -524,7 +523,7 @@ class SanaVideoCausalTransformerBlock(nn.Module):
             attn_output, kv_cache = attn_result
         else:
             attn_output = attn_result

         hidden_states = hidden_states + gate_msa * attn_output

         # 3. Cross Attention (no cache)
@@ -541,7 +540,7 @@ class SanaVideoCausalTransformerBlock(nn.Module):
         norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

         norm_hidden_states = norm_hidden_states.unflatten(1, (frames, height, width))

         # Cached conv always supports kv_cache
         ff_result = self.ff(
             norm_hidden_states,
@@ -552,7 +551,7 @@ class SanaVideoCausalTransformerBlock(nn.Module):
             ff_output, kv_cache = ff_result
         else:
             ff_output = ff_result

         ff_output = ff_output.flatten(1, 3)
         hidden_states = hidden_states + gate_mlp * ff_output
@@ -756,7 +755,7 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin
             if kv_cache is not None:
                 logger.warning("KV cache is not supported with gradient checkpointing. Disabling KV cache.")
                 kv_cache = None

             for index_block, block in enumerate(self.transformer_blocks):
                 hidden_states = self._gradient_checkpointing_func(
                     block,
@@ -777,7 +776,7 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin
             for index_block, block in enumerate(self.transformer_blocks):
                 # Get kv_cache for this block if available
                 block_kv_cache = kv_cache[index_block] if kv_cache is not None else None

                 block_result = block(
                     hidden_states,
                     attention_mask,
@@ -791,7 +790,7 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin
                     save_kv_cache=save_kv_cache,
                     kv_cache=block_kv_cache,
                 )

                 # Handle return value (could be tensor or tuple)
                 if isinstance(block_result, tuple):
                     hidden_states, updated_kv_cache = block_result
@@ -799,7 +798,7 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin
                         kv_cache[index_block] = updated_kv_cache
                 else:
                     hidden_states = block_result

                 if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
                     hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
@@ -757,7 +757,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
             SanaSprintImg2ImgPipeline,
             SanaSprintPipeline,
         )
-        from .sana_video import SanaImageToVideoPipeline, SanaVideoPipeline, LongSanaVideoPipeline
+        from .sana_video import LongSanaVideoPipeline, SanaImageToVideoPipeline, SanaVideoPipeline
         from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
         from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
         from .stable_audio import StableAudioPipeline, StableAudioProjectionModel
@@ -34,9 +34,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     except OptionalDependencyNotAvailable:
         from ...utils.dummy_torch_and_transformers_objects import *
     else:
+        from .pipeline_longsana import LongSanaVideoPipeline
         from .pipeline_sana_video import SanaVideoPipeline
         from .pipeline_sana_video_i2v import SanaImageToVideoPipeline
-        from .pipeline_longsana import LongSanaVideoPipeline
 else:
     import sys

@@ -209,7 +209,7 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         self.vae_scale_factor = self.vae_scale_factor_spatial

         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

         # LongSana specific parameters
         self.base_chunk_frames = base_chunk_frames
         self.num_cached_blocks = num_cached_blocks
@@ -680,11 +680,11 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         original_dtype = flow_pred.dtype
         flow_pred_f64 = flow_pred.double()
         xt_f64 = xt.double()

         # Get sigma_t from scheduler
         sigmas = self.scheduler.sigmas.double().to(flow_pred.device)
         timesteps_sched = self.scheduler.timesteps.double().to(flow_pred.device)

         # Find closest timestep index
         # timestep is scalar or [B], we need to match it against scheduler timesteps
         if timestep.dim() == 0:
@@ -692,10 +692,10 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         timestep_f64 = timestep.double()
         timestep_id = torch.argmin((timesteps_sched.unsqueeze(0) - timestep_f64.unsqueeze(1)).abs(), dim=1)
         sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1)

         # x_0 = x_t - sigma_t * flow_pred
         x0_pred = xt_f64 - sigma_t * flow_pred_f64

         return x0_pred.to(original_dtype)

     def _create_autoregressive_segments(self, total_frames: int, base_chunk_frames: int) -> List[int]:
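The conversion above is the standard flow-matching identity: with x_t = (1 - sigma_t) * x_0 + sigma_t * noise and the model predicting the flow v = noise - x_0, the clean sample is x_0 = x_t - sigma_t * v. A minimal standalone sketch, assuming a scheduler that exposes `sigmas` and `timesteps` tensors as FlowMatchEulerDiscreteScheduler does:

import torch


def flow_to_x0(flow_pred, xt, timestep, sigmas, timesteps):
    # Match the (possibly continuous) timestep to the scheduler's discrete grid.
    if timestep.dim() == 0:
        timestep = timestep.unsqueeze(0)
    timestep_id = torch.argmin((timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1)
    sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1)  # broadcast over (C, H, W)
    # x_0 = x_t - sigma_t * v
    return xt - sigma_t * flow_pred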
@@ -751,16 +751,16 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         """
         if chunk_idx == 0:
             return kv_cache[0]

         cur_kv_cache = kv_cache[chunk_idx]
         for block_id in range(num_blocks):
             # Copy temporal cache from previous chunk
             cur_kv_cache[block_id][2] = kv_cache[chunk_idx - 1][block_id][2]

             # Accumulate spatial KV cache from previous chunks
             cum_vk, cum_k_sum = None, None
             start_chunk_idx = chunk_idx - self.num_cached_blocks if self.num_cached_blocks > 0 else 0

             for i in range(start_chunk_idx, chunk_idx):
                 prev = kv_cache[i][block_id]
                 if prev[0] is not None and prev[1] is not None:
@@ -770,13 +770,13 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
                 else:
                     cum_vk += prev[0]
                     cum_k_sum += prev[1]

             if chunk_idx > 0:
                 assert cum_vk is not None and cum_k_sum is not None, "KV cache accumulation failed"

                 cur_kv_cache[block_id][0] = cum_vk
                 cur_kv_cache[block_id][1] = cum_k_sum

         return cur_kv_cache

     def _get_num_transformer_blocks(self) -> int:
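Judging from the prev[0]/prev[1]/[2] indexing above, each per-block cache entry holds [accumulated value-key product, accumulated key sum, temporal conv cache], i.e. linear-attention-style running sums that can simply be added across chunks. A sketch of the windowed accumulation under that assumption (names are illustrative):

from typing import Optional

import torch


def accumulate_spatial_cache(kv_cache, chunk_idx: int, block_id: int, window: int):
    # Sum the (vk, k_sum) statistics of up to `window` previous chunks;
    # window <= 0 means attend to every previous chunk.
    start = chunk_idx - window if window > 0 else 0
    cum_vk: Optional[torch.Tensor] = None
    cum_k_sum: Optional[torch.Tensor] = None
    for i in range(max(start, 0), chunk_idx):
        vk, k_sum = kv_cache[i][block_id][0], kv_cache[i][block_id][1]
        if vk is None or k_sum is None:
            continue
        cum_vk = vk.clone() if cum_vk is None else cum_vk + vk
        cum_k_sum = k_sum.clone() if cum_k_sum is None else cum_k_sum + k_sum
    return cum_vk, cum_k_sum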
@@ -1053,51 +1053,51 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         if denoising_step_list is None:
             # Use the standard timesteps from the scheduler
             denoising_step_list = timesteps.cpu().tolist()

         device = latents.device
         batch_size_latents, _, total_frames, height_latent, width_latent = latents.shape

         # Create autoregressive segments
         chunk_indices = self._create_autoregressive_segments(total_frames, self.base_chunk_frames)
         num_chunks = len(chunk_indices) - 1

         # Get number of transformer blocks
         num_blocks = self._get_num_transformer_blocks()

         # Initialize KV cache for all chunks
         kv_cache = self._initialize_kv_cache(num_chunks, num_blocks)

         # Output tensor to store denoised results
         output = torch.zeros_like(latents)

         transformer_dtype = self.transformer.dtype

         # Process each chunk
         for chunk_idx in range(num_chunks):
             start_f = chunk_indices[chunk_idx]
             end_f = chunk_indices[chunk_idx + 1]

             # Extract chunk latents
             local_latent = latents[:, :, start_f:end_f].clone()

             # Accumulate KV cache from previous chunks
             chunk_kv_cache = self._accumulate_kv_cache(kv_cache, chunk_idx, num_blocks)

             # Multi-step denoising for this chunk
             with self.progress_bar(total=len(denoising_step_list)) as progress_bar:
                 for step_idx, current_timestep in enumerate(denoising_step_list):
                     if self.interrupt:
                         continue

                     # Prepare model input
                     latent_model_input = (
                         torch.cat([local_latent] * 2) if self.do_classifier_free_guidance else local_latent
                     )

                     # Create timestep tensor
                     t = torch.tensor([current_timestep], device=device, dtype=torch.long)
                     timestep = t.expand(latent_model_input.shape[0])

                     # Forward pass through transformer with KV cache
                     transformer_kwargs = {
                         "encoder_hidden_states": prompt_embeds.to(dtype=transformer_dtype),
@@ -1107,16 +1107,16 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
                         "save_kv_cache": False,  # Don't save during denoising steps
                         "kv_cache": chunk_kv_cache,  # Pass accumulated KV cache
                     }

                     if self.attention_kwargs is not None:
                         transformer_kwargs["attention_kwargs"] = self.attention_kwargs

                     # Predict flow
                     model_output = self.transformer(
                         latent_model_input.to(dtype=transformer_dtype),
                         **transformer_kwargs,
                     )

                     # Handle different output formats
                     if isinstance(model_output, tuple):
                         if len(model_output) == 2:
@@ -1128,23 +1128,23 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
                             flow_pred = model_output[0]
                     else:
                         flow_pred = model_output

                     flow_pred = flow_pred.float()

                     # Perform guidance on flow prediction
                     if self.do_classifier_free_guidance:
                         flow_pred_uncond, flow_pred_text = flow_pred.chunk(2)
                         flow_pred = flow_pred_uncond + guidance_scale * (flow_pred_text - flow_pred_uncond)

                     # Handle learned sigma
                     if self.transformer.config.out_channels // 2 == latent_channels:
                         flow_pred = flow_pred.chunk(2, dim=1)[0]

                     # Convert flow prediction to x0 prediction
                     # Need to rearrange dimensions: b c f h w -> b f c h w for conversion
                     flow_pred_bfchw = rearrange(flow_pred, "b c f h w -> b f c h w")
                     local_latent_bfchw = rearrange(local_latent, "b c f h w -> b f c h w")

                     # Convert to x0 (flatten batch and frames for conversion)
                     pred_x0_flat = self._convert_flow_pred_to_x0(
                         flow_pred=flow_pred_bfchw.flatten(0, 1),
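Classifier-free guidance here operates on the flow prediction itself: the batch stacks the unconditional and text-conditioned halves, and the guided flow extrapolates from the unconditional prediction toward the conditional one. In isolation:

import torch


def apply_cfg(flow_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # First half of the batch is unconditional, second half text-conditioned.
    flow_uncond, flow_text = flow_pred.chunk(2)
    return flow_uncond + guidance_scale * (flow_text - flow_uncond)

With guidance_scale = 1.0 this reduces to the text-conditioned prediction; larger values push the sample further from the unconditional distribution.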
@@ -1153,19 +1153,19 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
                     )
                     pred_x0_bfchw = pred_x0_flat.unflatten(0, (flow_pred_bfchw.shape[0], flow_pred_bfchw.shape[1]))
                     pred_x0 = rearrange(pred_x0_bfchw, "b f c h w -> b c f h w")

                     # Denoise: x_t -> x_0, then add noise for next timestep
                     if step_idx < len(denoising_step_list) - 1:
                         # Not the last step, add noise for next timestep
                         next_timestep = denoising_step_list[step_idx + 1]
                         next_t = torch.tensor([next_timestep], device=device, dtype=torch.float32)

                         # Rearrange for scale_noise: b c f h w -> b f c h w
                         pred_x0_for_noise = rearrange(pred_x0, "b c f h w -> b f c h w")
                         noise = randn_tensor(
                             pred_x0_for_noise.shape, generator=generator, device=device, dtype=pred_x0_for_noise.dtype
                         )

                         # Add noise using scale_noise: flatten batch and frames
                         # scale_noise formula: sigma * noise + (1 - sigma) * sample
                         local_latent_flat = self.scheduler.scale_noise(
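The comment spells out the blend that scale_noise applies when re-noising the predicted x0 to the next timestep. Written out (this mirrors FlowMatchEulerDiscreteScheduler.scale_noise; the standalone function below is only illustrative):

import torch


def renoise(x0: torch.Tensor, noise: torch.Tensor, sigma: float) -> torch.Tensor:
    # Flow-matching forward process: x_t = sigma * noise + (1 - sigma) * x_0,
    # so sigma = 1.0 is pure noise and sigma = 0.0 returns the clean sample.
    return sigma * noise + (1.0 - sigma) * x0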
@@ -1178,19 +1178,19 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
                     else:
                         # Last step, use x_0 as final result
                         local_latent = pred_x0

                     progress_bar.update()

                     if XLA_AVAILABLE:
                         xm.mark_step()

             # Store the denoised chunk
             output[:, :, start_f:end_f] = local_latent

             # Update KV cache for this chunk by running forward pass at timestep 0
             latent_for_cache = output[:, :, start_f:end_f]
             timestep_zero = torch.zeros(latent_for_cache.shape[0], device=device, dtype=torch.long)

             cache_kwargs = {
                 "encoder_hidden_states": prompt_embeds.to(dtype=transformer_dtype),
                 "encoder_attention_mask": prompt_attention_mask,
@@ -1199,25 +1199,25 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
                 "save_kv_cache": True,  # Enable saving during cache update
                 "kv_cache": chunk_kv_cache,
             }

             if self.attention_kwargs is not None:
                 cache_kwargs["attention_kwargs"] = self.attention_kwargs

             # Forward pass to update KV cache
             cache_output = self.transformer(
                 latent_for_cache.to(dtype=transformer_dtype),
                 **cache_kwargs,
             )

             # Extract updated KV cache if returned
             if isinstance(cache_output, tuple) and len(cache_output) == 2:
                 _, updated_kv_cache = cache_output
                 if updated_kv_cache is not None:
                     kv_cache[chunk_idx] = updated_kv_cache

             if XLA_AVAILABLE:
                 xm.mark_step()

         latents = output

         if output_type == "latent":
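Putting the pieces together: each chunk is denoised with save_kv_cache=False against the accumulated cache of earlier chunks, then re-run once at timestep 0 with save_kv_cache=True so that later chunks can attend to its fully denoised latents. A condensed sketch of that control flow, with hypothetical denoise_chunk and update_cache callables standing in for the two transformer passes above:

def generate_long_video(latents, chunk_indices, kv_cache, denoise_chunk, update_cache):
    # denoise_chunk: multi-step denoising that reads (but does not write) the cache.
    # update_cache: single timestep-0 pass that writes this chunk's cache entry.
    output = latents.clone()
    for chunk_idx in range(len(chunk_indices) - 1):
        start_f, end_f = chunk_indices[chunk_idx], chunk_indices[chunk_idx + 1]
        local = denoise_chunk(latents[:, :, start_f:end_f].clone(), kv_cache, chunk_idx)
        output[:, :, start_f:end_f] = local
        kv_cache[chunk_idx] = update_cache(local, kv_cache, chunk_idx)
    return output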
@@ -1353,6 +1353,21 @@ class SanaTransformer2DModel(metaclass=DummyObject):
         requires_backends(cls, ["torch"])


+class SanaVideoCausalTransformer3DModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class SanaVideoTransformer3DModel(metaclass=DummyObject):
     _backends = ["torch"]

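These DummyObject stubs keep top-level imports working in environments without the backend; the failure surfaces only on first use, via requires_backends. Roughly (module path assumed from the ["torch"] backend list, since diffusers keeps torch-only dummies in dummy_pt_objects):

# With torch missing, importing still succeeds:
from diffusers.utils.dummy_pt_objects import SanaVideoCausalTransformer3DModel

# ...but any use raises, e.g.:
model = SanaVideoCausalTransformer3DModel()  # ImportError: requires the PyTorch library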
@@ -1697,6 +1697,21 @@ class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject):
         requires_backends(cls, ["torch", "transformers"])


+class LongSanaVideoPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
 class LTXConditionPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
