mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
diffusers/src/diffusers/utils/import_utils.py
dg845 c10bdd9b73 Add LTX 2.0 Video Pipelines (#12915)
2026-01-07 21:24:27 -08:00

1030 lines
34 KiB
Python

# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Import utilities: Utilities related to imports and our lazy inits.
"""
import importlib.util
import inspect
import operator as op
import os
import sys
from collections import OrderedDict, defaultdict
from functools import lru_cache as cache
from itertools import chain
from types import ModuleType
from typing import Any, Tuple, Union
from huggingface_hub.utils import is_jinja_available # noqa: F401
from packaging.version import Version, parse
from . import logging
# The package importlib_metadata is in a different place, depending on the python version.
if sys.version_info < (3, 8):
    import importlib_metadata
else:
    import importlib.metadata as importlib_metadata
try:
    _package_map = importlib_metadata.packages_distributions()  # load-once to avoid expensive calls
except Exception:
    _package_map = None
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
USE_SAFETENSORS = os.environ.get("USE_SAFETENSORS", "AUTO").upper()
DIFFUSERS_SLOW_IMPORT = os.environ.get("DIFFUSERS_SLOW_IMPORT", "FALSE").upper()
DIFFUSERS_SLOW_IMPORT = DIFFUSERS_SLOW_IMPORT in ENV_VARS_TRUE_VALUES
STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
_is_google_colab = "google.colab" in sys.modules or any(k.startswith("COLAB_") for k in os.environ)


def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[bool, str]:
    global _package_map
    pkg_exists = importlib.util.find_spec(pkg_name) is not None
    pkg_version = "N/A"

    if pkg_exists:
        if _package_map is None:
            _package_map = defaultdict(list)
            try:
                # Fallback for Python < 3.10
                for dist in importlib_metadata.distributions():
                    _top_level_declared = (dist.read_text("top_level.txt") or "").split()
                    # Infer top-level package names from file structure
                    _inferred_opt_names = {
                        f.parts[0] if len(f.parts) > 1 else inspect.getmodulename(f) for f in (dist.files or [])
                    } - {None}
                    _top_level_inferred = filter(lambda name: "." not in name, _inferred_opt_names)
                    for pkg in _top_level_declared or _top_level_inferred:
                        _package_map[pkg].append(dist.metadata["Name"])
            except Exception:
                pass
        try:
            if get_dist_name and pkg_name in _package_map and _package_map[pkg_name]:
                if len(_package_map[pkg_name]) > 1:
                    logger.warning(
                        f"Multiple distributions found for package {pkg_name}. Picked distribution: {_package_map[pkg_name][0]}"
                    )
                pkg_name = _package_map[pkg_name][0]
            pkg_version = importlib_metadata.version(pkg_name)
            logger.debug(f"Successfully imported {pkg_name} version {pkg_version}")
        except (ImportError, importlib_metadata.PackageNotFoundError):
            pkg_exists = False

    return pkg_exists, pkg_version
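

# Illustrative sketch only, not part of the diffusers API: the availability probe above boils down to
# "can the module be found, and does its distribution report a version?". The helper name
# `_sketch_probe_package` is hypothetical. The sketch assumes the import name equals the distribution
# name, which is the very assumption that `get_dist_name=True` exists to work around.
def _sketch_probe_package(pkg_name):
    import importlib.metadata
    import importlib.util

    # find_spec only locates the module; it does not import it.
    if importlib.util.find_spec(pkg_name) is None:
        return False, "N/A"
    try:
        return True, importlib.metadata.version(pkg_name)
    except importlib.metadata.PackageNotFoundError:
        # The module is importable, but no distribution metadata matches the name.
        return False, "N/A"
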
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
    _torch_available, _torch_version = _is_package_available("torch")
else:
    logger.info("Disabling PyTorch because USE_TORCH is set")
    _torch_available = False
    _torch_version = "N/A"

_jax_version = "N/A"
_flax_version = "N/A"
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
    _flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
    if _flax_available:
        try:
            _jax_version = importlib_metadata.version("jax")
            _flax_version = importlib_metadata.version("flax")
            logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
        except importlib_metadata.PackageNotFoundError:
            _flax_available = False
else:
    _flax_available = False

if USE_SAFETENSORS in ENV_VARS_TRUE_AND_AUTO_VALUES:
    _safetensors_available, _safetensors_version = _is_package_available("safetensors")
else:
    logger.info("Disabling Safetensors because USE_SAFETENSORS is set")
    _safetensors_available = False

_onnxruntime_version = "N/A"
_onnx_available = importlib.util.find_spec("onnxruntime") is not None
if _onnx_available:
    candidates = (
        "onnxruntime",
        "onnxruntime-cann",
        "onnxruntime-directml",
        "ort_nightly_directml",
        "onnxruntime-gpu",
        "ort_nightly_gpu",
        "onnxruntime-migraphx",
        "onnxruntime-openvino",
        "onnxruntime-qnn",
        "onnxruntime-rocm",
        "onnxruntime-training",
        "onnxruntime-vitisai",
    )
    _onnxruntime_version = None
    # For the metadata, we have to look for both onnxruntime and onnxruntime-x
    for pkg in candidates:
        try:
            _onnxruntime_version = importlib_metadata.version(pkg)
            break
        except importlib_metadata.PackageNotFoundError:
            pass
    _onnx_available = _onnxruntime_version is not None
    if _onnx_available:
        logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}")

# (sayakpaul): importlib.util.find_spec("opencv-python") returns None even when it's installed.
# _opencv_available = importlib.util.find_spec("opencv-python") is not None
try:
    candidates = (
        "opencv-python",
        "opencv-contrib-python",
        "opencv-python-headless",
        "opencv-contrib-python-headless",
    )
    _opencv_version = None
    for pkg in candidates:
        try:
            _opencv_version = importlib_metadata.version(pkg)
            break
        except importlib_metadata.PackageNotFoundError:
            pass
    _opencv_available = _opencv_version is not None
    if _opencv_available:
        logger.debug(f"Successfully imported cv2 version {_opencv_version}")
except importlib_metadata.PackageNotFoundError:
    _opencv_available = False

_bs4_available = importlib.util.find_spec("bs4") is not None
try:
    # importlib metadata under different name
    _bs4_version = importlib_metadata.version("beautifulsoup4")
    logger.debug(f"Successfully imported bs4 version {_bs4_version}")
except importlib_metadata.PackageNotFoundError:
    _bs4_available = False

_invisible_watermark_available = importlib.util.find_spec("imwatermark") is not None
try:
    _invisible_watermark_version = importlib_metadata.version("invisible-watermark")
    logger.debug(f"Successfully imported invisible-watermark version {_invisible_watermark_version}")
except importlib_metadata.PackageNotFoundError:
    _invisible_watermark_available = False

_torch_xla_available, _torch_xla_version = _is_package_available("torch_xla")
_torch_npu_available, _torch_npu_version = _is_package_available("torch_npu")
_torch_mlu_available, _torch_mlu_version = _is_package_available("torch_mlu")
_transformers_available, _transformers_version = _is_package_available("transformers")
_hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub")
_kernels_available, _kernels_version = _is_package_available("kernels")
_inflect_available, _inflect_version = _is_package_available("inflect")
_unidecode_available, _unidecode_version = _is_package_available("unidecode")
_k_diffusion_available, _k_diffusion_version = _is_package_available("k_diffusion")
_note_seq_available, _note_seq_version = _is_package_available("note_seq")
_wandb_available, _wandb_version = _is_package_available("wandb")
_tensorboard_available, _tensorboard_version = _is_package_available("tensorboard")
_compel_available, _compel_version = _is_package_available("compel")
_sentencepiece_available, _sentencepiece_version = _is_package_available("sentencepiece")
_torchsde_available, _torchsde_version = _is_package_available("torchsde")
_peft_available, _peft_version = _is_package_available("peft")
_torchvision_available, _torchvision_version = _is_package_available("torchvision")
_matplotlib_available, _matplotlib_version = _is_package_available("matplotlib")
_timm_available, _timm_version = _is_package_available("timm")
_bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes")
_imageio_available, _imageio_version = _is_package_available("imageio")
_ftfy_available, _ftfy_version = _is_package_available("ftfy")
_scipy_available, _scipy_version = _is_package_available("scipy")
_librosa_available, _librosa_version = _is_package_available("librosa")
_accelerate_available, _accelerate_version = _is_package_available("accelerate")
_xformers_available, _xformers_version = _is_package_available("xformers")
_gguf_available, _gguf_version = _is_package_available("gguf")
_torchao_available, _torchao_version = _is_package_available("torchao")
_bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes")
_optimum_quanto_available, _optimum_quanto_version = _is_package_available("optimum", get_dist_name=True)
_pytorch_retinaface_available, _pytorch_retinaface_version = _is_package_available("pytorch_retinaface")
_better_profanity_available, _better_profanity_version = _is_package_available("better_profanity")
_nltk_available, _nltk_version = _is_package_available("nltk")
_cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("cosmos_guardrail")
_sageattention_available, _sageattention_version = _is_package_available("sageattention")
_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
_flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")
_aiter_available, _aiter_version = _is_package_available("aiter")
_kornia_available, _kornia_version = _is_package_available("kornia")
_nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
_av_available, _av_version = _is_package_available("av")


def is_torch_available():
    return _torch_available


def is_torch_xla_available():
    return _torch_xla_available


def is_torch_npu_available():
    return _torch_npu_available


def is_torch_mlu_available():
    return _torch_mlu_available


def is_flax_available():
    return _flax_available


def is_transformers_available():
    return _transformers_available


def is_inflect_available():
    return _inflect_available


def is_unidecode_available():
    return _unidecode_available


def is_onnx_available():
    return _onnx_available


def is_opencv_available():
    return _opencv_available


def is_scipy_available():
    return _scipy_available


def is_librosa_available():
    return _librosa_available


def is_xformers_available():
    return _xformers_available


def is_accelerate_available():
    return _accelerate_available


def is_kernels_available():
    return _kernels_available


def is_k_diffusion_available():
    return _k_diffusion_available


def is_note_seq_available():
    return _note_seq_available


def is_wandb_available():
    return _wandb_available


def is_tensorboard_available():
    return _tensorboard_available


def is_compel_available():
    return _compel_available


def is_ftfy_available():
    return _ftfy_available


def is_bs4_available():
    return _bs4_available


def is_torchsde_available():
    return _torchsde_available


def is_invisible_watermark_available():
    return _invisible_watermark_available


def is_peft_available():
    return _peft_available


def is_torchvision_available():
    return _torchvision_available


def is_matplotlib_available():
    return _matplotlib_available


def is_safetensors_available():
    return _safetensors_available


def is_bitsandbytes_available():
    return _bitsandbytes_available


def is_google_colab():
    return _is_google_colab


def is_sentencepiece_available():
    return _sentencepiece_available


def is_imageio_available():
    return _imageio_available


def is_gguf_available():
    return _gguf_available


def is_torchao_available():
    return _torchao_available


def is_optimum_quanto_available():
    return _optimum_quanto_available


def is_nvidia_modelopt_available():
    return _nvidia_modelopt_available


def is_timm_available():
    return _timm_available


def is_pytorch_retinaface_available():
    return _pytorch_retinaface_available


def is_better_profanity_available():
    return _better_profanity_available


def is_nltk_available():
    return _nltk_available


def is_cosmos_guardrail_available():
    return _cosmos_guardrail_available


def is_hpu_available():
    return all(importlib.util.find_spec(lib) for lib in ("habana_frameworks", "habana_frameworks.torch"))


def is_sageattention_available():
    return _sageattention_available


def is_flash_attn_available():
    return _flash_attn_available


def is_flash_attn_3_available():
    return _flash_attn_3_available


def is_aiter_available():
    return _aiter_available


def is_kornia_available():
    return _kornia_available


def is_av_available():
    return _av_available


# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Check out the instructions on the
installation page: https://github.com/google/flax and follow the ones that match your environment.
"""
# docstyle-ignore
INFLECT_IMPORT_ERROR = """
{0} requires the inflect library but it was not found in your environment. You can install it with pip: `pip install
inflect`
"""
# docstyle-ignore
PYTORCH_IMPORT_ERROR = """
{0} requires the PyTorch library but it was not found in your environment. Check out the instructions on the
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
"""
# docstyle-ignore
ONNX_IMPORT_ERROR = """
{0} requires the onnxruntime library but it was not found in your environment. You can install it with pip: `pip
install onnxruntime`
"""
# docstyle-ignore
OPENCV_IMPORT_ERROR = """
{0} requires the OpenCV library but it was not found in your environment. You can install it with pip: `pip
install opencv-python`
"""
# docstyle-ignore
SCIPY_IMPORT_ERROR = """
{0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install
scipy`
"""
# docstyle-ignore
LIBROSA_IMPORT_ERROR = """
{0} requires the librosa library but it was not found in your environment. Check out the instructions on the
installation page: https://librosa.org/doc/latest/install.html and follow the ones that match your environment.
"""
# docstyle-ignore
TRANSFORMERS_IMPORT_ERROR = """
{0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip
install transformers`
"""
# docstyle-ignore
UNIDECODE_IMPORT_ERROR = """
{0} requires the unidecode library but it was not found in your environment. You can install it with pip: `pip install
Unidecode`
"""
# docstyle-ignore
K_DIFFUSION_IMPORT_ERROR = """
{0} requires the k-diffusion library but it was not found in your environment. You can install it with pip: `pip
install k-diffusion`
"""
# docstyle-ignore
NOTE_SEQ_IMPORT_ERROR = """
{0} requires the note-seq library but it was not found in your environment. You can install it with pip: `pip
install note-seq`
"""
# docstyle-ignore
WANDB_IMPORT_ERROR = """
{0} requires the wandb library but it was not found in your environment. You can install it with pip: `pip
install wandb`
"""
# docstyle-ignore
TENSORBOARD_IMPORT_ERROR = """
{0} requires the tensorboard library but it was not found in your environment. You can install it with pip: `pip
install tensorboard`
"""
# docstyle-ignore
COMPEL_IMPORT_ERROR = """
{0} requires the compel library but it was not found in your environment. You can install it with pip: `pip install compel`
"""
# docstyle-ignore
BS4_IMPORT_ERROR = """
{0} requires the Beautiful Soup library but it was not found in your environment. You can install it with pip:
`pip install beautifulsoup4`. Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
FTFY_IMPORT_ERROR = """
{0} requires the ftfy library but it was not found in your environment. Check out the instructions on the
installation section: https://github.com/rspeer/python-ftfy/tree/master#installing and follow the ones
that match your environment. Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
TORCHSDE_IMPORT_ERROR = """
{0} requires the torchsde library but it was not found in your environment. You can install it with pip: `pip install torchsde`
"""
# docstyle-ignore
INVISIBLE_WATERMARK_IMPORT_ERROR = """
{0} requires the invisible-watermark library but it was not found in your environment. You can install it with pip: `pip install invisible-watermark>=0.2.0`
"""
# docstyle-ignore
PEFT_IMPORT_ERROR = """
{0} requires the peft library but it was not found in your environment. You can install it with pip: `pip install peft`
"""
# docstyle-ignore
SAFETENSORS_IMPORT_ERROR = """
{0} requires the safetensors library but it was not found in your environment. You can install it with pip: `pip install safetensors`
"""
# docstyle-ignore
SENTENCEPIECE_IMPORT_ERROR = """
{0} requires the sentencepiece library but it was not found in your environment. You can install it with pip: `pip install sentencepiece`
"""
# docstyle-ignore
BITSANDBYTES_IMPORT_ERROR = """
{0} requires the bitsandbytes library but it was not found in your environment. You can install it with pip: `pip install bitsandbytes`
"""
# docstyle-ignore
IMAGEIO_IMPORT_ERROR = """
{0} requires the imageio library and ffmpeg but it was not found in your environment. You can install it with pip: `pip install imageio imageio-ffmpeg`
"""
# docstyle-ignore
GGUF_IMPORT_ERROR = """
{0} requires the gguf library but it was not found in your environment. You can install it with pip: `pip install gguf`
"""
TORCHAO_IMPORT_ERROR = """
{0} requires the torchao library but it was not found in your environment. You can install it with pip: `pip install
torchao`
"""
QUANTO_IMPORT_ERROR = """
{0} requires the optimum-quanto library but it was not found in your environment. You can install it with pip: `pip
install optimum-quanto`
"""
# docstyle-ignore
PYTORCH_RETINAFACE_IMPORT_ERROR = """
{0} requires the pytorch_retinaface library but it was not found in your environment. You can install it with pip: `pip install pytorch_retinaface`
"""
# docstyle-ignore
BETTER_PROFANITY_IMPORT_ERROR = """
{0} requires the better_profanity library but it was not found in your environment. You can install it with pip: `pip install better_profanity`
"""
# docstyle-ignore
NLTK_IMPORT_ERROR = """
{0} requires the nltk library but it was not found in your environment. You can install it with pip: `pip install nltk`
"""
BACKENDS_MAPPING = OrderedDict(
    [
        ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
        ("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
        ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)),
        ("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)),
        ("opencv", (is_opencv_available, OPENCV_IMPORT_ERROR)),
        ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
        ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
        ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)),
        ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)),
        ("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)),
        ("k_diffusion", (is_k_diffusion_available, K_DIFFUSION_IMPORT_ERROR)),
        ("note_seq", (is_note_seq_available, NOTE_SEQ_IMPORT_ERROR)),
        ("wandb", (is_wandb_available, WANDB_IMPORT_ERROR)),
        ("tensorboard", (is_tensorboard_available, TENSORBOARD_IMPORT_ERROR)),
        ("compel", (is_compel_available, COMPEL_IMPORT_ERROR)),
        ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)),
        ("torchsde", (is_torchsde_available, TORCHSDE_IMPORT_ERROR)),
        ("invisible_watermark", (is_invisible_watermark_available, INVISIBLE_WATERMARK_IMPORT_ERROR)),
        ("peft", (is_peft_available, PEFT_IMPORT_ERROR)),
        ("safetensors", (is_safetensors_available, SAFETENSORS_IMPORT_ERROR)),
        ("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)),
        ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),
        ("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)),
        ("gguf", (is_gguf_available, GGUF_IMPORT_ERROR)),
        ("torchao", (is_torchao_available, TORCHAO_IMPORT_ERROR)),
        ("quanto", (is_optimum_quanto_available, QUANTO_IMPORT_ERROR)),
        ("pytorch_retinaface", (is_pytorch_retinaface_available, PYTORCH_RETINAFACE_IMPORT_ERROR)),
        ("better_profanity", (is_better_profanity_available, BETTER_PROFANITY_IMPORT_ERROR)),
        ("nltk", (is_nltk_available, NLTK_IMPORT_ERROR)),
    ]
)


def requires_backends(obj, backends):
    if not isinstance(backends, (list, tuple)):
        backends = [backends]

    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    checks = (BACKENDS_MAPPING[backend] for backend in backends)
    failed = [msg.format(name) for available, msg in checks if not available()]
    if failed:
        raise ImportError("".join(failed))

    if name in [
        "VersatileDiffusionTextToImagePipeline",
        "VersatileDiffusionPipeline",
        "VersatileDiffusionDualGuidedPipeline",
        "StableDiffusionImageVariationPipeline",
        "UnCLIPPipeline",
    ] and is_transformers_version("<", "4.25.0"):
        raise ImportError(
            f"You need to install `transformers>=4.25` in order to use {name}: \n```\n pip install"
            " --upgrade transformers \n```"
        )

    if name in ["StableDiffusionDepth2ImgPipeline", "StableDiffusionPix2PixZeroPipeline"] and is_transformers_version(
        "<", "4.26.0"
    ):
        raise ImportError(
            f"You need to install `transformers>=4.26` in order to use {name}: \n```\n pip install"
            " --upgrade transformers \n```"
        )


class DummyObject(type):
    """
    Metaclass for the dummy objects. Any class that uses it as its metaclass raises the ImportError generated by
    `requires_backends` each time a user tries to access any public attribute or method of that class.
    """

    def __getattr__(cls, key):
        if key.startswith("_") and key not in ["_load_connected_pipes", "_is_onnx"]:
            return super().__getattr__(cls, key)
        requires_backends(cls, cls._backends)
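

# Illustrative sketch only, not the diffusers implementation: how a metaclass in the spirit of
# `DummyObject` can turn attribute access on a placeholder class into an ImportError. The names
# `_SketchDummyMeta` and `_SketchPipeline`, and the backend "surely_missing_backend_xyz", are
# hypothetical and exist only for this example.
class _SketchDummyMeta(type):
    def __getattr__(cls, key):
        # Python calls this hook only after normal class attribute lookup has failed.
        if key.startswith("_"):
            raise AttributeError(key)
        import importlib.util

        missing = [b for b in cls._backends if importlib.util.find_spec(b) is None]
        if missing:
            raise ImportError(f"{cls.__name__} requires the following missing backends: {missing}")
        raise AttributeError(key)


class _SketchPipeline(metaclass=_SketchDummyMeta):
    _backends = ["surely_missing_backend_xyz"]

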
# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319
def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
    """
    Compares a library version to some requirement using a given operation.

    Args:
        library_or_version (`str` or `packaging.version.Version`):
            A library name or a version to check.
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`.
        requirement_version (`str`):
            The version to compare the library version against
    """
    if operation not in STR_OPERATION_TO_FUNC.keys():
        raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}")
    operation = STR_OPERATION_TO_FUNC[operation]
    if isinstance(library_or_version, str):
        library_or_version = parse(importlib_metadata.version(library_or_version))
    return operation(library_or_version, parse(requirement_version))
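

# Illustrative sketch only: the same "operator string -> function" dispatch that `compare_versions`
# uses, but with a deliberately simplified version parser. The helper `_sketch_compare` is
# hypothetical, and plain integer tuples stand in for `packaging.version.parse`, so pre-release
# or local version suffixes (e.g. "2.1.0.dev0", "2.1.0+cu121") are NOT supported here.
def _sketch_compare(installed, operation, required):
    import operator

    ops = {
        ">": operator.gt,
        ">=": operator.ge,
        "==": operator.eq,
        "!=": operator.ne,
        "<=": operator.le,
        "<": operator.lt,
    }
    if operation not in ops:
        raise ValueError(f"`operation` must be one of {list(ops)}, received {operation}")

    def as_tuple(version):
        return tuple(int(part) for part in version.split("."))

    left, right = as_tuple(installed), as_tuple(required)
    # Pad with zeros so that "2.0" and "2.0.0" compare as equal.
    width = max(len(left), len(right))
    left += (0,) * (width - len(left))
    right += (0,) * (width - len(right))
    return ops[operation](left, right)

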
# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338
@cache
def is_torch_version(operation: str, version: str):
    """
    Compares the current PyTorch version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A string version of PyTorch
    """
    return compare_versions(parse(_torch_version), operation, version)


@cache
def is_torch_xla_version(operation: str, version: str):
    """
    Compares the current torch_xla version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A string version of torch_xla
    """
    # Call the helper: the bare function object is always truthy, so `not is_torch_xla_available` never triggers.
    if not is_torch_xla_available():
        return False
    return compare_versions(parse(_torch_xla_version), operation, version)


@cache
def is_transformers_version(operation: str, version: str):
    """
    Compares the current Transformers version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A version string
    """
    if not _transformers_available:
        return False
    return compare_versions(parse(_transformers_version), operation, version)


@cache
def is_hf_hub_version(operation: str, version: str):
    """
    Compares the current Hugging Face Hub version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A version string
    """
    if not _hf_hub_available:
        return False
    return compare_versions(parse(_hf_hub_version), operation, version)


@cache
def is_accelerate_version(operation: str, version: str):
    """
    Compares the current Accelerate version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A version string
    """
    if not _accelerate_available:
        return False
    return compare_versions(parse(_accelerate_version), operation, version)


@cache
def is_peft_version(operation: str, version: str):
    """
    Compares the current PEFT version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A version string
    """
    if not _peft_available:
        return False
    return compare_versions(parse(_peft_version), operation, version)


@cache
def is_bitsandbytes_version(operation: str, version: str):
    """
    Compares the current bitsandbytes version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A version string
    """
    if not _bitsandbytes_available:
        return False
    return compare_versions(parse(_bitsandbytes_version), operation, version)
@cache
def is_gguf_version(operation: str, version: str):
"""
Compares the current gguf version to a given reference with an operation.
Args:
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A version string
"""
if not _gguf_available:
return False
return compare_versions(parse(_gguf_version), operation, version)
@cache
def is_torchao_version(operation: str, version: str):
"""
Compares the current torchao version to a given reference with an operation.
Args:
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A version string
"""
if not _torchao_available:
return False
return compare_versions(parse(_torchao_version), operation, version)
@cache
def is_k_diffusion_version(operation: str, version: str):
"""
Compares the current k-diffusion version to a given reference with an operation.
Args:
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A version string
"""
if not _k_diffusion_available:
return False
return compare_versions(parse(_k_diffusion_version), operation, version)
@cache
def is_optimum_quanto_version(operation: str, version: str):
"""
Compares the current optimum-quanto version to a given reference with an operation.
Args:
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A version string
"""
if not _optimum_quanto_available:
return False
return compare_versions(parse(_optimum_quanto_version), operation, version)
@cache
def is_nvidia_modelopt_version(operation: str, version: str):
"""
Compares the current Nvidia ModelOpt version to a given reference with an operation.
Args:
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A version string
"""
if not _nvidia_modelopt_available:
return False
return compare_versions(parse(_nvidia_modelopt_version), operation, version)
@cache
def is_xformers_version(operation: str, version: str):
"""
Compares the current xformers version to a given reference with an operation.
Args:
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A version string
"""
if not _xformers_available:
return False
return compare_versions(parse(_xformers_version), operation, version)
@cache
def is_sageattention_version(operation: str, version: str):
"""
Compares the current sageattention version to a given reference with an operation.
Args:
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A version string
"""
if not _sageattention_available:
return False
return compare_versions(parse(_sageattention_version), operation, version)
@cache
def is_flash_attn_version(operation: str, version: str):
"""
Compares the current flash-attention version to a given reference with an operation.
Args:
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A version string
"""
if not _flash_attn_available:
return False
return compare_versions(parse(_flash_attn_version), operation, version)
@cache
def is_aiter_version(operation: str, version: str):
"""
Compares the current aiter version to a given reference with an operation.
Args:
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A version string
"""
if not _aiter_available:
return False
return compare_versions(parse(_aiter_version), operation, version)
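# Illustrative sketch (not part of the library): the `is_*_version` helpers above all follow the
# same pattern of "return False if the package is missing, otherwise compare parsed versions".
# The names `_example_available`, `_example_version`, `_as_tuple` and `is_example_version` are
# hypothetical, and the tuple-based parser is a simplified stand-in for `packaging.version.parse`
# that only handles plain dotted numeric versions.

```python
import operator
from functools import cache

# Hypothetical stand-ins for the module-level flags that are set once at import time.
_example_available = True
_example_version = "0.17.2"

# Maps the operator string to the comparison function it denotes.
_OPS = {
    ">": operator.gt,
    ">=": operator.ge,
    "==": operator.eq,
    "!=": operator.ne,
    "<=": operator.le,
    "<": operator.lt,
}


def _as_tuple(version: str) -> tuple:
    # Simplification: no pre-releases, local versions, or zero-padding ("1.0" != "1.0.0" here).
    return tuple(int(part) for part in version.split("."))


@cache
def is_example_version(operation: str, version: str) -> bool:
    # A missing package never satisfies a version constraint.
    if not _example_available:
        return False
    if operation not in _OPS:
        raise ValueError(f"`operation` must be one of {list(_OPS)}, received {operation}")
    return _OPS[operation](_as_tuple(_example_version), _as_tuple(version))
```

# Comparing parsed components rather than raw strings matters: as strings, "0.17.2" < "0.9.0",
# while numerically 17 > 9. Because of `@cache`, each (operation, version) pair is evaluated once.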
def get_objects_from_module(module):
"""
Returns a dict of object names and values in a module, skipping private/internal objects.
Args:
module (ModuleType):
Module to extract the objects from.
Returns:
dict: Dictionary of object names and corresponding values
"""
objects = {}
for name in dir(module):
if name.startswith("_"):
continue
objects[name] = getattr(module, name)
return objects
class OptionalDependencyNotAvailable(BaseException):
"""
An error indicating that an optional dependency of Diffusers was not found in the environment.
"""
class _LazyModule(ModuleType):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
# Very heavily inspired by optuna.integration._IntegrationModule
# https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None):
super().__init__(name)
self._modules = set(import_structure.keys())
self._class_to_module = {}
for key, values in import_structure.items():
for value in values:
self._class_to_module[value] = key
# Needed for autocompletion in an IDE
self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values()))
self.__file__ = module_file
self.__spec__ = module_spec
self.__path__ = [os.path.dirname(module_file)]
self._objects = {} if extra_objects is None else extra_objects
self._name = name
self._import_structure = import_structure
# Needed for autocompletion in an IDE
def __dir__(self):
result = super().__dir__()
# The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether
# they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir.
for attr in self.__all__:
if attr not in result:
result.append(attr)
return result
def __getattr__(self, name: str) -> Any:
if name in self._objects:
return self._objects[name]
if name in self._modules:
value = self._get_module(name)
elif name in self._class_to_module.keys():
module = self._get_module(self._class_to_module[name])
value = getattr(module, name)
else:
raise AttributeError(f"module {self.__name__} has no attribute {name}")
setattr(self, name, value)
return value
def _get_module(self, module_name: str):
try:
return importlib.import_module("." + module_name, self.__name__)
except Exception as e:
raise RuntimeError(
f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its"
f" traceback):\n{e}"
) from e
def __reduce__(self):
return (self.__class__, (self._name, self.__file__, self._import_structure))
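# Illustrative sketch (not part of the library): a stripped-down version of the lazy-loading idea
# behind `_LazyModule`. `LazyNamespace` and its `attr_to_module` argument are hypothetical, and it
# resolves names from standard-library modules by absolute name instead of using relative
# submodule imports, so it can run on its own.

```python
import importlib
from types import ModuleType


class LazyNamespace(ModuleType):
    """Module-like object that imports the owner of an attribute on first access."""

    def __init__(self, name, attr_to_module):
        super().__init__(name)
        # Maps each exposed attribute name to the module that defines it.
        self._attr_to_module = dict(attr_to_module)

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails, i.e. before the value is cached.
        try:
            module_name = self._attr_to_module[name]
        except KeyError:
            raise AttributeError(f"module {self.__name__} has no attribute {name}") from None
        value = getattr(importlib.import_module(module_name), name)
        # Cache on the instance so later lookups bypass __getattr__ entirely.
        setattr(self, name, value)
        return value


lazy = LazyNamespace("lazy_demo", {"sqrt": "math", "dumps": "json"})
```

# Nothing is imported when `lazy` is created. The first `lazy.sqrt` access imports `math` and
# stores the function in the module's `__dict__`, which mirrors the `setattr(self, name, value)`
# call in `_LazyModule.__getattr__`.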