mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
DN6
2025-05-12 19:37:28 +05:30
parent ce642e92da
commit c8a7617536
21 changed files with 927 additions and 754 deletions

View File

@@ -761,8 +761,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LayerSkipConfig,
PyramidAttentionBroadcastConfig,
SmoothedEnergyGuidanceConfig,
apply_layer_skip,
apply_faster_cache,
apply_layer_skip,
apply_pyramid_attention_broadcast,
)
from .models import (
@@ -1085,6 +1085,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionSAGPipeline,
StableDiffusionUpscalePipeline,
StableDiffusionXLAdapterPipeline,
StableDiffusionXLAutoPipeline,
StableDiffusionXLControlNetImg2ImgPipeline,
StableDiffusionXLControlNetInpaintPipeline,
StableDiffusionXLControlNetPAGImg2ImgPipeline,
@@ -1102,7 +1103,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionXLPAGInpaintPipeline,
StableDiffusionXLPAGPipeline,
StableDiffusionXLPipeline,
StableDiffusionXLAutoPipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
StableVideoDiffusionPipeline,

View File

@@ -13,12 +13,13 @@
# limitations under the License.
import math
from typing import Optional, List, TYPE_CHECKING
from typing import TYPE_CHECKING, List, Optional
import torch
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..pipelines.modular_pipeline import BlockState
@@ -119,19 +120,19 @@ class AdaptiveProjectedGuidance(BaseGuidance):
def _is_apg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
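For reference, a minimal sketch of the step-window gating shared by these _is_*_enabled helpers (the concrete numbers are illustrative, not taken from the diff): guidance is active only for steps in [int(start * N), int(stop * N)), and is additionally treated as a no-op when the scale is close to 0.0 under the original formulation or 1.0 under the standard CFG formulation.

import math

start, stop, num_inference_steps = 0.1, 0.9, 50  # illustrative values
skip_start_step = int(start * num_inference_steps)  # 5
skip_stop_step = int(stop * num_inference_steps)    # 45
for step in (0, 5, 44, 45):
    # half-open active window: [5, 45)
    print(step, skip_start_step <= step < skip_stop_step)  # False, True, True, False

# A scale that is effectively a no-op also disables guidance:
scale_is_noop = math.isclose(1.0, 1.0)  # 1.0 under standard CFG, 0.0 under the original formulation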
@@ -156,25 +157,25 @@ def normalized_guidance(
):
diff = pred_cond - pred_uncond
dim = [-i for i in range(1, len(diff.shape))]
if momentum_buffer is not None:
momentum_buffer.update(diff)
diff = momentum_buffer.running_average
if norm_threshold > 0:
ones = torch.ones_like(diff)
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
diff = diff * scale_factor
v0, v1 = diff.double(), pred_cond.double()
v1 = torch.nn.functional.normalize(v1, dim=dim)
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
normalized_update = diff_orthogonal + eta * diff_parallel
pred = pred_cond if use_original_formulation else pred_uncond
pred = pred + guidance_scale * normalized_update
return pred
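For orientation, a hedged toy sketch of the projection performed by normalized_guidance above: the cond-uncond difference is split into components parallel and orthogonal to the normalized conditional prediction, and only the parallel part is rescaled by eta. All tensors and scales below are made up for illustration.

import torch

pred_cond = torch.randn(2, 4, 8, 8)
pred_uncond = torch.randn(2, 4, 8, 8)
eta, guidance_scale = 0.5, 7.5  # assumed values
dim = [-i for i in range(1, pred_cond.ndim)]  # all non-batch dims

diff = pred_cond - pred_uncond
v0, v1 = diff.double(), pred_cond.double()
v1 = torch.nn.functional.normalize(v1, dim=dim)
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel
normalized_update = v0_orthogonal.type_as(diff) + eta * v0_parallel.type_as(diff)
pred = pred_uncond + guidance_scale * normalized_update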

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import List, Optional, Union, TYPE_CHECKING
from typing import TYPE_CHECKING, List, Optional, Union
import torch
@@ -21,6 +21,7 @@ from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..pipelines.modular_pipeline import BlockState
@@ -113,13 +114,13 @@ class AutoGuidance(BaseGuidance):
if self._is_ag_enabled() and self.is_unconditional:
for name, config in zip(self._auto_guidance_hook_names, self.auto_guidance_config):
_apply_layer_skip_hook(denoiser, config, name=name)
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
if self._is_ag_enabled() and self.is_unconditional:
for name in self._auto_guidance_hook_names:
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
registry.remove_hook(name, recurse=True)
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
@@ -140,9 +141,9 @@ class AutoGuidance(BaseGuidance):
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1
@@ -157,17 +158,17 @@ class AutoGuidance(BaseGuidance):
def _is_ag_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close

View File

@@ -13,12 +13,13 @@
# limitations under the License.
import math
from typing import Optional, List, TYPE_CHECKING
from typing import TYPE_CHECKING, List, Optional
import torch
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..pipelines.modular_pipeline import BlockState
@@ -74,7 +75,7 @@ class ClassifierFreeGuidance(BaseGuidance):
self.guidance_scale = guidance_scale
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
@@ -112,17 +113,17 @@ class ClassifierFreeGuidance(BaseGuidance):
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close

View File

@@ -13,12 +13,13 @@
# limitations under the License.
import math
from typing import Optional, List, TYPE_CHECKING
from typing import TYPE_CHECKING, List, Optional
import torch
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..pipelines.modular_pipeline import BlockState
@@ -72,7 +73,7 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
self.zero_init_steps = zero_init_steps
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
@@ -102,7 +103,7 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1
@@ -117,19 +118,19 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close

View File

@@ -58,10 +58,10 @@ class BaseGuidance:
def disable(self):
self._enabled = False
def enable(self):
self._enabled = True
def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
self._step = step
self._num_inference_steps = num_inference_steps
@@ -104,14 +104,14 @@ class BaseGuidance:
f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}."
)
self._input_fields = kwargs
def prepare_models(self, denoiser: torch.nn.Module) -> None:
"""
Prepares the models for the guidance technique on a given batch of data. This method should be overridden in
subclasses to implement specific model preparation logic.
"""
self._count_prepared += 1
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
"""
Cleans up the models for the guidance technique after a given batch of data. This method should be overridden in
@@ -119,7 +119,7 @@ class BaseGuidance:
modifications made during `prepare_models`.
"""
pass
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")
@@ -139,15 +139,15 @@ class BaseGuidance:
@property
def is_conditional(self) -> bool:
raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.")
@property
def is_unconditional(self) -> bool:
return not self.is_conditional
@property
def num_conditions(self) -> int:
raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.")
@classmethod
def _prepare_batch(cls, input_fields: Dict[str, Union[str, Tuple[str, str]]], data: "BlockState", tuple_index: int, identifier: str) -> "BlockState":
"""

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import List, Optional, Union, TYPE_CHECKING
from typing import TYPE_CHECKING, List, Optional, Union
import torch
@@ -21,6 +21,7 @@ from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..pipelines.modular_pipeline import BlockState
@@ -148,14 +149,14 @@ class SkipLayerGuidance(BaseGuidance):
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
_apply_layer_skip_hook(denoiser, config, name=name)
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
# Remove the hooks after inference
for hook_name in self._skip_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
if self.num_conditions == 1:
tuple_indices = [0]
@@ -200,7 +201,7 @@ class SkipLayerGuidance(BaseGuidance):
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1 or self._count_prepared == 3
@@ -217,31 +218,31 @@ class SkipLayerGuidance(BaseGuidance):
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
def _is_slg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
is_within_range = skip_start_step < self._step < skip_stop_step
is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
return is_within_range and not is_zero

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import List, Optional, Union, TYPE_CHECKING
from typing import TYPE_CHECKING, List, Optional, Union
import torch
@@ -21,6 +21,7 @@ from ..hooks import HookRegistry
from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..pipelines.modular_pipeline import BlockState
@@ -141,14 +142,14 @@ class SmoothedEnergyGuidance(BaseGuidance):
if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
for name, config in zip(self._seg_layer_hook_names, self.seg_guidance_config):
_apply_smoothed_energy_guidance_hook(denoiser, config, self.seg_blur_sigma, name=name)
def cleanup_models(self, denoiser: torch.nn.Module):
if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
# Remove the hooks after inference
for hook_name in self._seg_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
if self.num_conditions == 1:
tuple_indices = [0]
@@ -193,7 +194,7 @@ class SmoothedEnergyGuidance(BaseGuidance):
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1 or self._count_prepared == 3
@@ -210,31 +211,31 @@ class SmoothedEnergyGuidance(BaseGuidance):
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
def _is_seg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self.seg_guidance_start * self._num_inference_steps)
skip_stop_step = int(self.seg_guidance_stop * self._num_inference_steps)
is_within_range = skip_start_step < self._step < skip_stop_step
is_zero = math.isclose(self.seg_guidance_scale, 0.0)
return is_within_range and not is_zero

View File

@@ -13,12 +13,13 @@
# limitations under the License.
import math
from typing import Optional, List, TYPE_CHECKING
from typing import TYPE_CHECKING, List, Optional
import torch
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..pipelines.modular_pipeline import BlockState
@@ -97,24 +98,24 @@ class TangentialClassifierFreeGuidance(BaseGuidance):
def _is_tcfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
def normalized_guidance(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float, use_original_formulation: bool = False) -> torch.Tensor:
cond_dtype = pred_cond.dtype
preds = torch.stack([pred_cond, pred_uncond], dim=1).float()
preds = preds.flatten(2)
U, S, Vh = torch.linalg.svd(preds, full_matrices=False)
@@ -125,9 +126,9 @@ def normalized_guidance(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guid
x_Vh = torch.matmul(uncond_flat, Vh.transpose(-2, -1))
x_Vh_V = torch.matmul(x_Vh, Vh_modified)
pred_uncond = x_Vh_V.reshape(pred_uncond.shape).to(cond_dtype)
pred = pred_cond if use_original_formulation else pred_uncond
shift = pred_cond - pred_uncond
pred = pred + guidance_scale * shift
return pred
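A toy illustration of the SVD projection mechanics in this normalized_guidance: the stacked predictions are flattened and decomposed, and the unconditional prediction is re-expressed through the right-singular basis. The hunk does not show how Vh_modified is built, so the sketch below reuses the unmodified Vh, for which the round trip is an identity projection; shapes are made up.

import torch

preds = torch.randn(1, 2, 16)  # (batch, [cond, uncond], flattened features) -- assumed layout
U, S, Vh = torch.linalg.svd(preds.float(), full_matrices=False)
uncond_flat = preds[:, 1:2, :]
x_Vh = torch.matmul(uncond_flat, Vh.transpose(-2, -1))
x_back = torch.matmul(x_Vh, Vh)  # with the unmodified Vh this recovers uncond_flat
print(torch.allclose(x_back, uncond_flat, atol=1e-4))  # True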

View File

@@ -20,7 +20,12 @@ import torch
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES, _get_submodule_from_fqn
from ._common import (
_ALL_TRANSFORMER_BLOCK_IDENTIFIERS,
_ATTENTION_CLASSES,
_FEEDFORWARD_CLASSES,
_get_submodule_from_fqn,
)
from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry
from .hooks import HookRegistry, ModelHook
@@ -196,15 +201,15 @@ def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, nam
for i, block in enumerate(transformer_blocks):
if i not in config.indices:
continue
blocks_found = True
if config.skip_attention and config.skip_ff:
logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'")
registry = HookRegistry.check_if_exists_or_initialize(block)
hook = TransformerBlockSkipHook(config.dropout)
registry.register_hook(hook, name)
elif config.skip_attention or config.skip_attention_scores:
for submodule_name, submodule in block.named_modules():
if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention:
@@ -213,7 +218,7 @@ def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, nam
registry = HookRegistry.check_if_exists_or_initialize(submodule)
hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores, config.dropout)
registry.register_hook(hook, name)
if config.skip_ff:
for submodule_name, submodule in block.named_modules():
if isinstance(submodule, _FEEDFORWARD_CLASSES):

View File

@@ -14,7 +14,7 @@
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
from typing import List, Optional
import torch
import torch.nn.functional as F
@@ -67,7 +67,7 @@ class SmoothedEnergyGuidanceHook(ModelHook):
def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: SmoothedEnergyGuidanceConfig, blur_sigma: float, name: Optional[str] = None) -> None:
name = name or _SMOOTHED_ENERGY_GUIDANCE_HOOK
if config.fqn == "auto":
for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
if hasattr(module, identifier):
@@ -78,18 +78,18 @@ def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: Smooth
"Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
"`fqn` (fully qualified name) that identifies a stack of transformer blocks."
)
if config._query_proj_identifiers is None:
config._query_proj_identifiers = ["to_q"]
transformer_blocks = _get_submodule_from_fqn(module, config.fqn)
blocks_found = False
for i, block in enumerate(transformer_blocks):
if i not in config.indices:
continue
blocks_found = True
for submodule_name, submodule in block.named_modules():
if not isinstance(submodule, _ATTENTION_CLASSES) or submodule.is_cross_attention:
continue
@@ -103,7 +103,7 @@ def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: Smooth
registry = HookRegistry.check_if_exists_or_initialize(query_proj)
hook = SmoothedEnergyGuidanceHook(blur_sigma)
registry.register_hook(hook, name)
if not blocks_found:
raise ValueError(
f"Could not find any transformer blocks matching the provided indices {config.indices} and "
@@ -124,7 +124,7 @@ def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma
in the future without warning or guarantee of reproducibility.
"""
assert query.ndim == 3
is_inf = sigma > sigma_threshold_inf
batch_size, seq_len, embed_dim = query.shape
@@ -133,7 +133,7 @@ def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma
query_slice = query[:, :num_square_tokens, :]
query_slice = query_slice.permute(0, 2, 1)
query_slice = query_slice.reshape(batch_size, embed_dim, seq_len_sqrt, seq_len_sqrt)
if is_inf:
kernel_size = min(kernel_size, seq_len_sqrt - (seq_len_sqrt % 2 - 1))
kernel_size_half = (kernel_size - 1) / 2
@@ -154,5 +154,5 @@ def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma
query_slice = query_slice.reshape(batch_size, embed_dim, num_square_tokens)
query_slice = query_slice.permute(0, 2, 1)
query[:, :num_square_tokens, :] = query_slice.clone()
return query
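A hedged sketch of the reshape-blur-restore pattern used by _gaussian_blur_2d: the leading square block of tokens is viewed as a 2D grid, blurred per channel, and written back. The box filter below is a stand-in for the real Gaussian kernel; shapes are illustrative.

import math

import torch
import torch.nn.functional as F

batch_size, seq_len, embed_dim = 2, 64, 8  # 64 tokens -> an 8x8 grid
query = torch.randn(batch_size, seq_len, embed_dim)
side = int(math.sqrt(seq_len))
grid = query.permute(0, 2, 1).reshape(batch_size, embed_dim, side, side)
kernel = torch.ones(embed_dim, 1, 3, 3) / 9.0  # depthwise box filter stand-in
blurred = F.conv2d(grid, kernel, padding=1, groups=embed_dim)
query = blurred.reshape(batch_size, embed_dim, seq_len).permute(0, 2, 1)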

View File

@@ -102,8 +102,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .ip_adapter import (
FluxIPAdapterMixin,
IPAdapterMixin,
SD3IPAdapterMixin,
ModularIPAdapterMixin,
SD3IPAdapterMixin,
)
from .lora_pipeline import (
AmusedLoraLoaderMixin,

View File

@@ -703,12 +703,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .stable_diffusion_safe import StableDiffusionPipelineSafe
from .stable_diffusion_sag import StableDiffusionSAGPipeline
from .stable_diffusion_xl import (
StableDiffusionXLAutoPipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLInstructPix2PixPipeline,
StableDiffusionXLModularLoader,
StableDiffusionXLPipeline,
StableDiffusionXLAutoPipeline,
)
from .stable_video_diffusion import StableVideoDiffusionPipeline
from .t2i_adapter import (

View File

@@ -12,21 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import time
from collections import OrderedDict
from itertools import combinations
from typing import List, Optional, Union, Dict, Any
import copy
from typing import Any, Dict, List, Optional, Union
import torch
import time
from dataclasses import dataclass
from ..utils import (
is_accelerate_available,
logging,
)
from ..models.modeling_utils import ModelMixin
from .modular_pipeline_utils import ComponentSpec
if is_accelerate_available():
@@ -231,17 +228,18 @@ class AutoOffloadStrategy:
from .modular_pipeline_utils import ComponentSpec
import uuid
class ComponentsManager:
def __init__(self):
self.components = OrderedDict()
self.added_time = OrderedDict()  # Store when components were added
self.collections = OrderedDict() # collection_name -> set of component_names
self.model_hooks = None
self._auto_offload_enabled = False
def _get_by_collection(self, collection: str):
"""
Select components by collection name.
@@ -252,8 +250,8 @@ class ComponentsManager:
for component_id in component_ids:
selected_components[component_id] = self.components[component_id]
return selected_components
def _get_by_load_id(self, load_id: str):
"""
Select components by their load_id.
@@ -263,8 +261,8 @@ class ComponentsManager:
if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id:
selected_components[name] = component
return selected_components
def add(self, name, component, collection: Optional[str] = None):
for comp_id, comp in self.components.items():
@@ -282,7 +280,7 @@ class ComponentsManager:
f"Component '{name}' has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. "
f"To remove a duplicate, call `components_manager.remove('<component_name>')`."
)
# add component to components manager
self.components[component_id] = component
@@ -293,8 +291,8 @@ class ComponentsManager:
self.collections[collection].add(component_id)
if self._auto_offload_enabled:
self.enable_auto_cpu_offload(self._auto_offload_device)
logger.info(f"Added component '{name}' to ComponentsManager as '{component_id}'")
return component_id
@@ -304,14 +302,14 @@ class ComponentsManager:
if name not in self.components:
logger.warning(f"Component '{name}' not found in ComponentsManager")
return
self.components.pop(name)
self.added_time.pop(name)
for collection in self.collections:
if name in self.collections[collection]:
self.collections[collection].remove(name)
if self._auto_offload_enabled:
self.enable_auto_cpu_offload(self._auto_offload_device)
@@ -341,7 +339,7 @@ class ComponentsManager:
Dictionary mapping component IDs to components,
or list of (base_name, component) tuples if as_name_component_tuples=True
"""
if collection:
if collection not in self.collections:
logger.warning(f"Collection '{collection}' not found in ComponentsManager")
@@ -360,16 +358,16 @@ class ComponentsManager:
if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]:
return '_'.join(parts[:-1])
return component_id
if names is None:
if as_name_component_tuples:
return [(get_base_name(comp_id), comp) for comp_id, comp in components.items()]
else:
return components
# Create mapping from component_id to base_name for all components
base_names = {comp_id: get_base_name(comp_id) for comp_id in components.keys()}
def matches_pattern(component_id, pattern, exact_match=False):
"""
Helper function to check if a component matches a pattern based on its base name.
@@ -380,124 +378,124 @@ class ComponentsManager:
exact_match: If True, only exact matches to base_name are considered
"""
base_name = base_names[component_id]
# Exact match with base name
if exact_match:
return pattern == base_name
# Prefix match (ends with *)
elif pattern.endswith('*'):
prefix = pattern[:-1]
return base_name.startswith(prefix)
# Contains match (starts with *)
elif pattern.startswith('*'):
search = pattern[1:-1] if pattern.endswith('*') else pattern[1:]
return search in base_name
# Exact match (no wildcards)
else:
return pattern == base_name
if isinstance(names, str):
# Check if this is a "not" pattern
is_not_pattern = names.startswith('!')
if is_not_pattern:
names = names[1:] # Remove the ! prefix
# Handle OR patterns (containing |)
if '|' in names:
terms = names.split('|')
matches = {}
for comp_id, comp in components.items():
# For OR patterns with exact names (no wildcards), we do exact matching on base names
exact_match = all(not (term.startswith('*') or term.endswith('*')) for term in terms)
# Check if any of the terms match this component
should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms)
# Flip the decision if this is a NOT pattern
if is_not_pattern:
should_include = not should_include
if should_include:
matches[comp_id] = comp
log_msg = "NOT " if is_not_pattern else ""
match_type = "exactly matching" if exact_match else "matching any of patterns"
logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}")
# Try exact match with a base name
elif any(names == base_name for base_name in base_names.values()):
# Find all components with this base name
matches = {
comp_id: comp for comp_id, comp in components.items()
if (base_names[comp_id] == names) != is_not_pattern
}
if is_not_pattern:
logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}")
else:
logger.info(f"Getting components with base name '{names}': {list(matches.keys())}")
# Prefix match (ends with *)
elif names.endswith('*'):
prefix = names[:-1]
matches = {
comp_id: comp for comp_id, comp in components.items()
if base_names[comp_id].startswith(prefix) != is_not_pattern
}
if is_not_pattern:
logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}")
else:
logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}")
# Contains match (starts with *)
elif names.startswith('*'):
search = names[1:-1] if names.endswith('*') else names[1:]
matches = {
comp_id: comp for comp_id, comp in components.items()
if (search in base_names[comp_id]) != is_not_pattern
}
if is_not_pattern:
logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}")
else:
logger.info(f"Getting components containing '{search}': {list(matches.keys())}")
# Substring match (no wildcards, but not an exact component name)
elif any(names in base_name for base_name in base_names.values()):
matches = {
comp_id: comp for comp_id, comp in components.items()
if (names in base_names[comp_id]) != is_not_pattern
}
if is_not_pattern:
logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}")
else:
logger.info(f"Getting components containing '{names}': {list(matches.keys())}")
else:
raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager")
if not matches:
raise ValueError(f"No components found matching pattern '{names}'")
if as_name_component_tuples:
return [(base_names[comp_id], comp) for comp_id, comp in matches.items()]
else:
return matches
elif isinstance(names, list):
results = {}
for name in names:
result = self.get(name, collection, load_id, as_name_component_tuples=False)
results.update(result)
if as_name_component_tuples:
return [(base_names[comp_id], comp) for comp_id, comp in results.items()]
else:
return results
else:
raise ValueError(f"Invalid type for names: {type(names)}")
@@ -558,14 +556,14 @@ class ComponentsManager:
raise ValueError(f"Component '{name}' not found in ComponentsManager")
component = self.components[name]
# Build complete info dict first
info = {
"model_id": name,
"added_time": self.added_time[name],
"collection": next((coll for coll, comps in self.collections.items() if name in comps), None),
}
# Additional info for torch.nn.Module components
if isinstance(component, torch.nn.Module):
# Check for hook information
@@ -573,7 +571,7 @@ class ComponentsManager:
execution_device = None
if has_hook and hasattr(component._hf_hook, "execution_device"):
execution_device = component._hf_hook.execution_device
info.update({
"class_name": component.__class__.__name__,
"size_gb": get_memory_footprint(component) / (1024**3),
@@ -594,8 +592,8 @@ class ComponentsManager:
if any("IPAdapter" in ptype for ptype in processor_types):
# Then get scales only from IP-Adapter processors
scales = {
k: v.scale
for k, v in processors.items()
if hasattr(v, "scale") and "IPAdapter" in v.__class__.__name__
}
if scales:
@@ -609,7 +607,7 @@ class ComponentsManager:
else:
# List of fields requested, return dict with just those fields
return {k: v for k, v in info.items() if k in fields}
return info
def __repr__(self):
@@ -622,13 +620,13 @@ class ComponentsManager:
if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]:
return '_'.join(parts[:-1])
return name
# Extract load_id if available
def get_load_id(component):
if hasattr(component, "_diffusers_load_id"):
return component._diffusers_load_id
return "N/A"
# Format device info compactly
def format_device(component, info):
if not info["has_hook"]:
@@ -637,24 +635,24 @@ class ComponentsManager:
device = str(getattr(component, 'device', 'N/A'))
exec_device = str(info['execution_device'] or 'N/A')
return f"{device}({exec_device})"
# Get all simple names to calculate width
simple_names = [get_simple_name(id) for id in self.components.keys()]
# Get max length of load_ids for models
load_ids = [
get_load_id(component)
for component in self.components.values()
if isinstance(component, torch.nn.Module) and hasattr(component, "_diffusers_load_id")
]
max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15
# Collection names
collection_names = [
next((coll for coll, comps in self.collections.items() if name in comps), "N/A")
for name in self.components.keys()
]
col_widths = {
"name": max(15, max(len(name) for name in simple_names)),
"class": max(25, max(len(component.__class__.__name__) for component in self.components.values())),
@@ -692,7 +690,7 @@ class ComponentsManager:
dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A"
load_id = get_load_id(component)
collection = info["collection"] or "N/A"
output += f"{simple_name:<{col_widths['name']}} | {info['class_name']:<{col_widths['class']}} | "
output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | "
output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {collection}\n"
@@ -712,7 +710,7 @@ class ComponentsManager:
info = self.get_model_info(name)
simple_name = get_simple_name(name)
collection = info["collection"] or "N/A"
output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {collection}\n"
output += dash_line
@@ -726,9 +724,9 @@ class ComponentsManager:
if info.get("adapters") is not None:
output += f" Adapters: {info['adapters']}\n"
if info.get("ip_adapter"):
output += f" IP-Adapter: Enabled\n"
output += " IP-Adapter: Enabled\n"
output += f" Added Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(info['added_time']))}\n"
return output
def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs):
@@ -759,13 +757,13 @@ class ComponentsManager:
from ..pipelines.pipeline_utils import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs)
for name, component in pipe.components.items():
if component is None:
continue
# Add prefix if specified
component_name = f"{prefix}_{name}" if prefix else name
if component_name not in self.components:
self.add(component_name, component)
else:
@@ -791,13 +789,13 @@ class ComponentsManager:
ValueError: If no components match or multiple components match
"""
results = self.get(name, collection, load_id)
if not results:
raise ValueError(f"No components found matching '{name}'")
if len(results) > 1:
raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}")
return next(iter(results.values()))
def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
@@ -823,17 +821,17 @@ def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
if value_tuple not in value_to_keys:
value_to_keys[value_tuple] = []
value_to_keys[value_tuple].append(key)
def find_common_prefix(keys: List[str]) -> str:
"""Find the shortest common prefix among a list of dot-separated keys."""
if not keys:
return ""
if len(keys) == 1:
return keys[0]
# Split all keys into parts
key_parts = [k.split('.') for k in keys]
# Find how many initial parts are common
common_length = 0
for parts in zip(*key_parts):
@@ -841,10 +839,10 @@ def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
common_length += 1
else:
break
if common_length == 0:
return ""
# Return the common prefix
return '.'.join(key_parts[0][:common_length])
@@ -858,5 +856,5 @@ def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
summary[prefix] = value
else:
summary[""] = value # Use empty string if no common prefix
return summary
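A worked example of what summarize_dict_by_value_and_parts aims at (the dict is made up, and the exact output shape is inferred from the helper above rather than shown in full in the hunk):

d = {
    "unet.down_blocks.0": "cuda:0",
    "unet.down_blocks.1": "cuda:0",
    "vae.decoder": "cpu",
}
# Equal values are grouped first, then each group's keys collapse to their
# shared dotted prefix, giving roughly:
# {"unet.down_blocks": "cuda:0", "vae.decoder": "cpu"}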

File diff suppressed because it is too large

View File

@@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import inspect
from dataclasses import dataclass, asdict, field, fields
from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal
import re
from dataclasses import dataclass, field, fields
from typing import Any, Dict, List, Literal, Optional, Type, Union
from ..configuration_utils import ConfigMixin, FrozenDict
from ..utils.import_utils import is_torch_available
from ..configuration_utils import FrozenDict, ConfigMixin
if is_torch_available():
import torch
@@ -56,50 +57,50 @@ class ComponentSpec:
variant: Optional[str] = field(default=None, metadata={"loading": True})
revision: Optional[str] = field(default=None, metadata={"loading": True})
default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained"
def __hash__(self):
"""Make ComponentSpec hashable, using load_id as the hash value."""
return hash((self.name, self.load_id, self.default_creation_method))
def __eq__(self, other):
"""Compare ComponentSpec objects based on name and load_id."""
if not isinstance(other, ComponentSpec):
return False
return (self.name == other.name and
self.load_id == other.load_id and
self.default_creation_method == other.default_creation_method)
@classmethod
def from_component(cls, name: str, component: torch.nn.Module) -> Any:
"""Create a ComponentSpec from a Component created by `create` method."""
if not hasattr(component, "_diffusers_load_id"):
raise ValueError("Component is not created by `create` method")
type_hint = component.__class__
if component._diffusers_load_id == "null" and isinstance(component, ConfigMixin):
config = component.config
else:
config = None
load_spec = cls.decode_load_id(component._diffusers_load_id)
return cls(name=name, type_hint=type_hint, config=config, **load_spec)
@classmethod
def from_load_id(cls, load_id: str, name: Optional[str] = None) -> Any:
"""Create a ComponentSpec from a load_id string."""
if load_id == "null":
raise ValueError("Cannot create ComponentSpec from null load_id")
# Decode the load_id into a dictionary of loading fields
load_fields = cls.decode_load_id(load_id)
# Create a new ComponentSpec instance with the decoded fields
return cls(name=name, **load_fields)
@classmethod
def loading_fields(cls) -> List[str]:
"""
@@ -107,8 +108,8 @@ class ComponentSpec:
(i.e. those whose field.metadata["loading"] is True).
"""
return [f.name for f in fields(cls) if f.metadata.get("loading", False)]
@property
def load_id(self) -> str:
"""
@@ -118,7 +119,7 @@ class ComponentSpec:
parts = [getattr(self, k) for k in self.loading_fields()]
parts = ["null" if p is None else p for p in parts]
return "|".join(p for p in parts if p)
@classmethod
def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]:
"""
@@ -139,29 +140,29 @@ class ComponentSpec:
If a segment value is "null", it's replaced with None.
Returns None if load_id is "null" (indicating component not loaded from pretrained).
"""
# Get all loading fields in order
loading_fields = cls.loading_fields()
result = {f: None for f in loading_fields}
if load_id == "null":
return result
# Split the load_id
parts = load_id.split("|")
# Map parts to loading fields by position
for i, part in enumerate(parts):
if i < len(loading_fields):
# Convert "null" string back to None
result[loading_fields[i]] = None if part == "null" else part
return result
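An illustrative round trip of the load_id encoding decoded above; the field order (repo, subfolder, variant, revision) is assumed from the loading fields in this file, and the repo value is made up.

load_id = "stabilityai/stable-diffusion-xl-base-1.0|unet|null|null"
loading_fields = ["repo", "subfolder", "variant", "revision"]  # assumed order
result = {f: None for f in loading_fields}
for i, part in enumerate(load_id.split("|")):
    if i < len(loading_fields):
        result[loading_fields[i]] = None if part == "null" else part
# result == {"repo": "stabilityai/stable-diffusion-xl-base-1.0",
#            "subfolder": "unet", "variant": None, "revision": None}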
# YiYi TODO: add validator
def create(self, **kwargs) -> Any:
"""Create the component using the preferred creation method."""
# from_pretrained creation
if self.default_creation_method == "from_pretrained":
return self.create_from_pretrained(**kwargs)
@@ -170,17 +171,17 @@ class ComponentSpec:
return self.create_from_config(**kwargs)
else:
raise ValueError(f"Invalid creation method: {self.default_creation_method}")
def create_from_config(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any:
"""Create component using from_config with config."""
if self.type_hint is None or not isinstance(self.type_hint, type):
raise ValueError(
f"`type_hint` is required when using from_config creation method."
"`type_hint` is required when using from_config creation method."
)
config = config or self.config or {}
if issubclass(self.type_hint, ConfigMixin):
component = self.type_hint.from_config(config, **kwargs)
else:
@@ -193,24 +194,24 @@ class ComponentSpec:
if k in signature_params:
init_kwargs[k] = v
component = self.type_hint(**init_kwargs)
component._diffusers_load_id = "null"
if hasattr(component, "config"):
self.config = component.config
return component
# YiYi TODO: add guard for type of model, if it is supported by from_pretrained
def create_from_pretrained(self, **kwargs) -> Any:
"""Create component using from_pretrained."""
passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs}
load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()}
# repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path
repo = load_kwargs.pop("repo", None)
if repo is None:
raise ValueError(f"`repo` info is required when using from_pretrained creation method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)")
raise ValueError("`repo` info is required when using from_pretrained creation method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)")
if self.type_hint is None:
try:
from diffusers import AutoModel
@@ -223,19 +224,19 @@ class ComponentSpec:
component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs)
except Exception as e:
raise ValueError(f"Error creating {self.name}[{self.type_hint.__name__}] from pretrained: {e}")
if repo != self.repo:
self.repo = repo
for k, v in passed_loading_kwargs.items():
if v is not None:
setattr(self, k, v)
component._diffusers_load_id = self.load_id
return component
@dataclass
class ConfigSpec:
"""Specification for a pipeline configuration parameter."""
name: str
@@ -254,7 +255,7 @@ class InputParam:
return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"
@dataclass
class OutputParam:
"""Specification for an output parameter."""
name: str
@@ -287,14 +288,14 @@ def format_inputs_short(inputs):
"""
required_inputs = [param for param in inputs if param.required]
optional_inputs = [param for param in inputs if not param.required]
required_str = ", ".join(param.name for param in required_inputs)
optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs)
inputs_str = required_str
if optional_str:
inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str
return inputs_str
@@ -321,18 +322,18 @@ def format_intermediates_short(intermediates_inputs, required_intermediates_inpu
input_parts.append(f"Required({inp.name})")
else:
input_parts.append(inp.name)
# Handle modified variables (appear in both inputs and outputs)
inputs_set = {inp.name for inp in intermediates_inputs}
modified_parts = []
new_output_parts = []
for out in intermediates_outputs:
if out.name in inputs_set:
modified_parts.append(out.name)
else:
new_output_parts.append(out.name)
result = []
if input_parts:
result.append(f" - inputs: {', '.join(input_parts)}")
@@ -340,7 +341,7 @@ def format_intermediates_short(intermediates_inputs, required_intermediates_inpu
result.append(f" - modified: {', '.join(modified_parts)}")
if new_output_parts:
result.append(f" - outputs: {', '.join(new_output_parts)}")
return "\n".join(result) if result else " (none)"
@@ -358,18 +359,18 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115):
"""
if not params:
return ""
base_indent = " " * indent_level
param_indent = " " * (indent_level + 4)
desc_indent = " " * (indent_level + 8)
formatted_params = []
def get_type_str(type_hint):
if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union:
types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__]
return f"Union[{', '.join(types)}]"
return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint)
def wrap_text(text, indent, max_length):
"""Wrap text while preserving markdown links and maintaining indentation."""
words = text.split()
@@ -379,7 +380,7 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115):
for word in words:
word_length = len(word) + (1 if current_line else 0)
if current_line and current_length + word_length > max_length:
lines.append(" ".join(current_line))
current_line = [word]
@@ -387,20 +388,20 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115):
else:
current_line.append(word)
current_length += word_length
if current_line:
lines.append(" ".join(current_line))
return f"\n{indent}".join(lines)
# Add the header
formatted_params.append(f"{base_indent}{header}:")
for param in params:
# Format parameter name and type
type_str = get_type_str(param.type_hint) if param.type_hint != Any else ""
param_str = f"{param_indent}{param.name} (`{type_str}`"
# Add optional tag and default value if parameter is an InputParam and optional
if hasattr(param, "required"):
if not param.required:
@@ -408,7 +409,7 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115):
if param.default is not None:
param_str += f", defaults to {param.default}"
param_str += "):"
# Add description on a new line with additional indentation and wrapping
if param.description:
desc = re.sub(
@@ -418,9 +419,9 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115):
)
wrapped_desc = wrap_text(desc, desc_indent, max_line_length)
param_str += f"\n{desc_indent}{wrapped_desc}"
formatted_params.append(param_str)
return "\n\n".join(formatted_params)
@@ -466,42 +467,42 @@ def format_components(components, indent_level=4, max_line_length=115, add_empty
"""
if not components:
return ""
base_indent = " " * indent_level
component_indent = " " * (indent_level + 4)
formatted_components = []
# Add the header
formatted_components.append(f"{base_indent}Components:")
if add_empty_lines:
formatted_components.append("")
# Add each component with optional empty lines between them
for i, component in enumerate(components):
# Get type name, handling special cases
type_name = component.type_hint.__name__ if hasattr(component.type_hint, "__name__") else str(component.type_hint)
component_desc = f"{component_indent}{component.name} (`{type_name}`)"
if component.description:
component_desc += f": {component.description}"
# Get the loading fields dynamically
loading_field_values = []
for field_name in component.loading_fields():
field_value = getattr(component, field_name)
if field_value is not None:
loading_field_values.append(f"{field_name}={field_value}")
# Add loading field information if available
if loading_field_values:
component_desc += f" [{', '.join(loading_field_values)}]"
formatted_components.append(component_desc)
# Add an empty line after each component except the last one
if add_empty_lines and i < len(components) - 1:
formatted_components.append("")
return "\n".join(formatted_components)
@@ -519,27 +520,27 @@ def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines
"""
if not configs:
return ""
base_indent = " " * indent_level
config_indent = " " * (indent_level + 4)
formatted_configs = []
# Add the header
formatted_configs.append(f"{base_indent}Configs:")
if add_empty_lines:
formatted_configs.append("")
# Add each config with optional empty lines between them
for i, config in enumerate(configs):
config_desc = f"{config_indent}{config.name} (default: {config.default})"
if config.description:
config_desc += f": {config.description}"
formatted_configs.append(config_desc)
# Add an empty line after each config except the last one
if add_empty_lines and i < len(configs) - 1:
formatted_configs.append("")
return "\n".join(formatted_configs)
@@ -584,9 +585,9 @@ def make_doc_string(inputs, intermediates_inputs, outputs, description="", class
# Add inputs section
output += format_input_params(inputs + intermediates_inputs, indent_level=2)
# Add outputs section
output += "\n\n"
output += format_output_params(outputs, indent_level=2)
return output

View File

@@ -334,6 +334,7 @@ def maybe_raise_or_warn(
# a simpler version of get_class_obj_and_candidates, it won't work with custom code
def simple_get_class_obj(library_name, class_name):
from diffusers import pipelines
is_pipeline_module = hasattr(pipelines, library_name)
if is_pipeline_module:
@@ -345,6 +346,7 @@ def simple_get_class_obj(library_name, class_name):
return class_obj
def get_class_obj_and_candidates(
library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
):

View File

@@ -1120,7 +1120,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
automatically detect the available accelerator and use it.
"""
self._maybe_raise_error_if_group_offload_active(raise_error=True)
is_pipeline_device_mapped = hasattr(self, "hf_device_map") and self.hf_device_map is not None and len(self.hf_device_map) > 1

View File

@@ -61,6 +61,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline
from .pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline
from .pipeline_stable_diffusion_xl_modular import (
StableDiffusionXLAutoPipeline,
StableDiffusionXLControlNetDenoiseStep,
StableDiffusionXLDecodeLatentsStep,
StableDiffusionXLDenoiseStep,
@@ -70,7 +71,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionXLPrepareLatentsStep,
StableDiffusionXLSetTimestepsStep,
StableDiffusionXLTextEncoderStep,
StableDiffusionXLAutoPipeline,
)
try:

View File

@@ -14,6 +14,7 @@
# limitations under the License.
"""Utilities to dynamically load objects from the Hub."""
import hashlib
import importlib
import inspect
import json
@@ -21,8 +22,9 @@ import os
import re
import shutil
import sys
import threading
from pathlib import Path
from typing import Dict, Optional, Union
from types import ModuleType
from typing import Dict, Optional, Union
from urllib import request
from huggingface_hub import hf_hub_download, model_info
@@ -37,6 +39,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# See https://huggingface.co/datasets/diffusers/community-pipelines-mirror
COMMUNITY_PIPELINES_MIRROR_ID = "diffusers/community-pipelines-mirror"
_HF_REMOTE_CODE_LOCK = threading.Lock()
def get_diffusers_versions():
@@ -154,15 +157,132 @@ def check_imports(filename):
return get_relative_imports(filename)
def get_class_in_module(class_name, module_path):
def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code):
if trust_remote_code is None:
if has_local_code:
trust_remote_code = False
elif has_remote_code and TIME_OUT_REMOTE_CODE > 0:
prev_sig_handler = None
try:
prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error)
signal.alarm(TIME_OUT_REMOTE_CODE)
while trust_remote_code is None:
answer = input(
f"The repository for {model_name} contains custom code which must be executed to correctly "
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
f"Do you wish to run the custom code? [y/N] "
)
if answer.lower() in ["yes", "y", "1"]:
trust_remote_code = True
elif answer.lower() in ["no", "n", "0", ""]:
trust_remote_code = False
signal.alarm(0)
except Exception:
# OS which does not support signal.SIGALRM
raise ValueError(
f"The repository for {model_name} contains custom code which must be executed to correctly "
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
)
finally:
if prev_sig_handler is not None:
signal.signal(signal.SIGALRM, prev_sig_handler)
signal.alarm(0)
elif has_remote_code:
# For the CI which puts the timeout at 0
_raise_timeout_error(None, None)
if has_remote_code and not has_local_code and not trust_remote_code:
raise ValueError(
f"Loading {model_name} requires you to execute the configuration file in that"
" repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
" set the option `trust_remote_code=True` to remove this error."
)
return trust_remote_code
def get_class_in_modular_module(
class_name: str,
module_path: Union[str, os.PathLike],
*,
force_reload: bool = False,
) -> type:
"""
Import a module on the cache directory for modules and extract a class from it.
Args:
class_name (`str`): The name of the class to import.
module_path (`str` or `os.PathLike`): The path to the module to import.
force_reload (`bool`, *optional*, defaults to `False`):
Whether to reload the dynamic module from file if it already exists in `sys.modules`.
Otherwise, the module is only reloaded if the file has changed.
Returns:
`typing.Type`: The class looked for.
"""
name = os.path.normpath(module_path)
if name.endswith(".py"):
name = name[:-3]
name = name.replace(os.path.sep, ".")
module_file: Path = Path(HF_MODULES_CACHE) / module_path
with _HF_REMOTE_CODE_LOCK:
if force_reload:
sys.modules.pop(name, None)
importlib.invalidate_caches()
cached_module: Optional[ModuleType] = sys.modules.get(name)
module_spec = importlib.util.spec_from_file_location(name, location=module_file)
# Hash the module file and all its relative imports to check if we need to reload it
module_files: list[Path] = [module_file] + sorted(map(Path, get_relative_import_files(module_file)))
module_hash: str = hashlib.sha256(b"".join(bytes(f) + f.read_bytes() for f in module_files)).hexdigest()
module: ModuleType
if cached_module is None:
module = importlib.util.module_from_spec(module_spec)
# insert it into sys.modules before any loading begins
sys.modules[name] = module
else:
module = cached_module
# reload in both cases, unless the module is already imported and the hash hits
if getattr(module, "__transformers_module_hash__", "") != module_hash:
module_spec.loader.exec_module(module)
module.__transformers_module_hash__ = module_hash
return getattr(module, class_name)
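The reload guard above keys off a digest of the module file plus all of its relative imports. A minimal sketch of that recipe (function name and signature are hypothetical):

import hashlib
from pathlib import Path
from typing import Iterable

def module_digest(module_file: Path, relative_imports: Iterable[Path]) -> str:
    # Concatenate each file's path bytes and contents, then hash, so any
    # edit to the module or its relative imports changes the digest and
    # forces a re-exec of the cached module.
    files = [module_file] + sorted(relative_imports)
    return hashlib.sha256(b"".join(bytes(f) + f.read_bytes() for f in files)).hexdigest()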
def get_class_in_module(class_name, module_path, force_reload=False):
"""
Import a module on the cache directory for modules and extract a class from it.
"""
module_path = module_path.replace(os.path.sep, ".")
module = importlib.import_module(module_path)
name = os.path.normpath(module_path)
if name.endswith(".py"):
name = name[:-3]
name = name.replace(os.path.sep, ".")
module_file: Path = Path(HF_MODULES_CACHE) / module_path
with _HF_REMOTE_CODE_LOCK:
if force_reload:
sys.modules.pop(name, None)
importlib.invalidate_caches()
cached_module: Optional[ModuleType] = sys.modules.get(name)
module_spec = importlib.util.spec_from_file_location(name, location=module_file)
module: ModuleType
if cached_module is None:
module = importlib.util.module_from_spec(module_spec)
# insert it into sys.modules before any loading begins
sys.modules[name] = module
else:
module = cached_module
module_spec.loader.exec_module(module)
if class_name is None:
return find_pipeline_class(module)
return getattr(module, class_name)
@@ -203,6 +323,7 @@ def get_cached_module_file(
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
is_modular: bool = False,
):
"""
Prepares and downloads a module from a local folder or a distant repo and returns its path inside the cached
@@ -257,7 +378,7 @@ def get_cached_module_file(
if os.path.isfile(module_file_or_url):
resolved_module_file = module_file_or_url
submodule = "local"
elif pretrained_model_name_or_path.count("/") == 0:
elif pretrained_model_name_or_path.count("/") == 0 and not is_modular:
available_versions = get_diffusers_versions()
# cut ".dev0"
latest_version = "v" + ".".join(__version__.split(".")[:3])
@@ -297,6 +418,24 @@ def get_cached_module_file(
except EnvironmentError:
logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
raise
elif is_modular:
try:
# Load from URL or cache if already cached
resolved_module_file = hf_hub_download(
pretrained_model_name_or_path,
module_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
)
submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
except EnvironmentError:
logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
raise
else:
try:
# Load from URL or cache if already cached
@@ -381,6 +520,7 @@ def get_class_from_dynamic_module(
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
is_modular: bool = False,
**kwargs,
):
"""
@@ -453,5 +593,7 @@ def get_class_from_dynamic_module(
token=token,
revision=revision,
local_files_only=local_files_only,
is_modular=is_modular,
)
return get_class_in_module(class_name, final_module.replace(".py", ""))
return get_class_in_module(class_name, final_module)