1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

start folderizing the loaders.

This commit is contained in:
sayakpaul
2025-04-16 12:02:06 +05:30
parent ce1063acfa
commit 8267677a24
16 changed files with 79 additions and 48 deletions

View File

@@ -54,7 +54,7 @@ if is_transformers_available():
_import_structure = {}
if is_torch_available():
_import_structure["single_file_model"] = ["FromOriginalModelMixin"]
_import_structure["single_file.single_file_model"] = ["FromOriginalModelMixin"]
_import_structure["transformer_flux"] = ["FluxTransformer2DLoadersMixin"]
_import_structure["transformer_sd3"] = ["SD3Transformer2DLoadersMixin"]
_import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
@@ -77,6 +77,7 @@ if is_torch_available():
"SanaLoraLoaderMixin",
"Lumina2LoraLoaderMixin",
"WanLoraLoaderMixin",
"LoraBaseMixin",
]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = [
@@ -90,25 +91,21 @@ _import_structure["peft"] = ["PeftAdapterMixin"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
from .single_file_model import FromOriginalModelMixin
from .transformer_flux import FluxTransformer2DLoadersMixin
from .transformer_sd3 import SD3Transformer2DLoadersMixin
from .ip_adapter import FluxTransformer2DLoadersMixin, SD3Transformer2DLoadersMixin
from .single_file import FromOriginalModelMixin
from .unet import UNet2DConditionLoadersMixin
from .utils import AttnProcsLayers
if is_transformers_available():
from .ip_adapter import (
FluxIPAdapterMixin,
IPAdapterMixin,
SD3IPAdapterMixin,
)
from .lora_pipeline import (
from .ip_adapter import FluxIPAdapterMixin, IPAdapterMixin, SD3IPAdapterMixin
from .lora import (
AmusedLoraLoaderMixin,
AuraFlowLoraLoaderMixin,
CogVideoXLoraLoaderMixin,
CogView4LoraLoaderMixin,
FluxLoraLoaderMixin,
HunyuanVideoLoraLoaderMixin,
LoraBaseMixin,
LoraLoaderMixin,
LTXVideoLoraLoaderMixin,
Lumina2LoraLoaderMixin,

View File

@@ -0,0 +1,9 @@
from ...utils.import_utils import is_torch_available, is_transformers_available
if is_torch_available():
from .transformer_flux import FluxTransformer2DLoadersMixin
from .transformer_sd3 import SD3Transformer2DLoadersMixin
if is_transformers_available():
from .ip_adapter import FluxIPAdapterMixin, IPAdapterMixin, SD3IPAdapterMixin

View File

@@ -13,21 +13,11 @@
# limitations under the License.
from contextlib import nullcontext
from ..models.embeddings import (
ImageProjection,
MultiIPAdapterImageProjection,
)
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import (
is_accelerate_available,
is_torch_version,
logging,
)
from ...models.embeddings import ImageProjection, MultiIPAdapterImageProjection
from ...models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ...utils import is_accelerate_available, is_torch_version, logging
if is_accelerate_available():
pass
logger = logging.get_logger(__name__)
@@ -88,9 +78,7 @@ class FluxTransformer2DLoadersMixin:
return image_projection
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
from ..models.attention_processor import (
FluxIPAdapterJointAttnProcessor2_0,
)
from ...models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
if low_cpu_mem_usage:
if is_accelerate_available():

View File

@@ -14,10 +14,10 @@
from contextlib import nullcontext
from typing import Dict
from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
from ..models.embeddings import IPAdapterTimeImageProjection
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import is_accelerate_available, is_torch_version, logging
from ...models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
from ...models.embeddings import IPAdapterTimeImageProjection
from ...models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ...utils import is_accelerate_available, is_torch_version, logging
logger = logging.get_logger(__name__)

View File

@@ -0,0 +1,24 @@
from ...utils import is_peft_available, is_torch_available, is_transformers_available
if is_torch_available():
from .lora_base import LoraBaseMixin
if is_transformers_available():
from .lora_pipeline import (
AmusedLoraLoaderMixin,
AuraFlowLoraLoaderMixin,
CogVideoXLoraLoaderMixin,
CogView4LoraLoaderMixin,
FluxLoraLoaderMixin,
HunyuanVideoLoraLoaderMixin,
LoraLoaderMixin,
LTXVideoLoraLoaderMixin,
Lumina2LoraLoaderMixin,
Mochi1LoraLoaderMixin,
SanaLoraLoaderMixin,
SD3LoraLoaderMixin,
StableDiffusionLoraLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
WanLoraLoaderMixin,
)

View File

@@ -0,0 +1,8 @@
from ...utils import is_torch_available, is_transformers_available
if is_torch_available():
from .single_file_model import FromOriginalModelMixin
if is_transformers_available():
from .single_file import FromSingleFileMixin

View File

@@ -21,7 +21,7 @@ from huggingface_hub.utils import LocalEntryNotFoundError, validate_hf_hub_args
from packaging import version
from typing_extensions import Self
from ..utils import deprecate, is_transformers_available, logging
from ...utils import deprecate, is_transformers_available, logging
from .single_file_utils import (
SingleFileComponentError,
_is_legacy_scheduler_kwargs,

View File

@@ -21,9 +21,9 @@ import torch
from huggingface_hub.utils import validate_hf_hub_args
from typing_extensions import Self
from .. import __version__
from ..quantizers import DiffusersAutoQuantizer
from ..utils import deprecate, is_accelerate_available, logging
from ... import __version__
from ...quantizers import DiffusersAutoQuantizer
from ...utils import deprecate, is_accelerate_available, logging
from .single_file_utils import (
SingleFileComponentError,
convert_animatediff_checkpoint_to_diffusers,
@@ -58,7 +58,7 @@ logger = logging.get_logger(__name__)
if is_accelerate_available():
from accelerate import dispatch_model, init_empty_weights
from ..models.modeling_utils import load_model_dict_into_meta
from ...models.modeling_utils import load_model_dict_into_meta
SINGLE_FILE_LOADABLE_CLASSES = {

View File

@@ -25,8 +25,8 @@ import requests
import torch
import yaml
from ..models.modeling_utils import load_state_dict
from ..schedulers import (
from ...models.modeling_utils import load_state_dict
from ...schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EDMDPMSolverMultistepScheduler,
@@ -54,7 +54,7 @@ if is_transformers_available():
if is_accelerate_available():
from accelerate import init_empty_weights
from ..models.modeling_utils import load_model_dict_into_meta
from ...models.modeling_utils import load_model_dict_into_meta
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -0,0 +1,5 @@
from ...utils import is_torch_available
if is_torch_available():
from .unet import UNet2DConditionLoadersMixin

View File

@@ -22,7 +22,7 @@ import torch
import torch.nn.functional as F
from huggingface_hub.utils import validate_hf_hub_args
from ..models.embeddings import (
from ...models.embeddings import (
ImageProjection,
IPAdapterFaceIDImageProjection,
IPAdapterFaceIDPlusImageProjection,
@@ -30,8 +30,8 @@ from ..models.embeddings import (
IPAdapterPlusImageProjection,
MultiIPAdapterImageProjection,
)
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
from ..utils import (
from ...models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
from ...utils import (
USE_PEFT_BACKEND,
_get_model_file,
convert_unet_state_dict_to_peft,
@@ -43,9 +43,9 @@ from ..utils import (
is_torch_version,
logging,
)
from .lora_base import _func_optionally_disable_offloading
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
from .utils import AttnProcsLayers
from ..lora import _func_optionally_disable_offloading
from ..lora.lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
from ..utils import AttnProcsLayers
logger = logging.get_logger(__name__)
@@ -247,7 +247,7 @@ class UNet2DConditionLoadersMixin:
# Unsafe code />
def _process_custom_diffusion(self, state_dict):
from ..models.attention_processor import CustomDiffusionAttnProcessor
from ...models.attention_processor import CustomDiffusionAttnProcessor
attn_processors = {}
custom_diffusion_grouped_dict = defaultdict(dict)
@@ -451,7 +451,7 @@ class UNet2DConditionLoadersMixin:
pipeline.unet.save_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
```
"""
from ..models.attention_processor import (
from ...models.attention_processor import (
CustomDiffusionAttnProcessor,
CustomDiffusionAttnProcessor2_0,
CustomDiffusionXFormersAttnProcessor,
@@ -513,7 +513,7 @@ class UNet2DConditionLoadersMixin:
logger.info(f"Model weights saved in {save_path}")
def _get_custom_diffusion_state_dict(self):
from ..models.attention_processor import (
from ...models.attention_processor import (
CustomDiffusionAttnProcessor,
CustomDiffusionAttnProcessor2_0,
CustomDiffusionXFormersAttnProcessor,
@@ -759,7 +759,7 @@ class UNet2DConditionLoadersMixin:
return image_projection
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
from ..models.attention_processor import (
from ...models.attention_processor import (
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
IPAdapterXFormersAttnProcessor,