Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Add torch_xla and from_single_file support to TextToVideoZeroPipeline (#10445)
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -11,16 +11,30 @@ from torch.nn.functional import grid_sample
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
BaseOutput,
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..stable_diffusion import StableDiffusionSafetyChecker
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@@ -282,7 +296,11 @@ def create_motion_field_and_warp_latents(motion_field_strength_x, motion_field_s
|
||||
|
||||
|
||||
class TextToVideoZeroPipeline(
|
||||
DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
|
||||
DiffusionPipeline,
|
||||
StableDiffusionMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
FromSingleFileMixin,
|
||||
):
|
||||
r"""
|
||||
Pipeline for zero-shot text-to-video generation using Stable Diffusion.
|
||||
@@ -440,6 +458,10 @@ class TextToVideoZeroPipeline(
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
return latents.clone().detach()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_k_diffusion.pipeline_stable_diffusion_k_diffusion.StableDiffusionKDiffusionPipeline.check_inputs
|
||||
|
||||
Reference in New Issue
Block a user