1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

merge part1

This commit is contained in:
yiyixuxu
2025-05-20 18:53:04 +02:00
25 changed files with 5859 additions and 4245 deletions

View File

@@ -39,6 +39,7 @@ _import_structure = {
"loaders": ["FromOriginalModelMixin"],
"models": [],
"pipelines": [],
"modular_pipelines": [],
"quantizers.quantization_config": [],
"schedulers": [],
"utils": [
@@ -254,13 +255,19 @@ else:
"KarrasVePipeline",
"LDMPipeline",
"LDMSuperResolutionPipeline",
"ModularLoader",
"PNDMPipeline",
"RePaintPipeline",
"ScoreSdeVePipeline",
"StableDiffusionMixin",
]
)
_import_structure["modular_pipelines"].extend(
[
"ModularLoader",
"ComponentSpec",
"ComponentsManager",
]
)
_import_structure["quantizers"] = ["DiffusersQuantizer"]
_import_structure["schedulers"].extend(
[
@@ -509,12 +516,10 @@ else:
"StableDiffusionXLImg2ImgPipeline",
"StableDiffusionXLInpaintPipeline",
"StableDiffusionXLInstructPix2PixPipeline",
"StableDiffusionXLModularLoader",
"StableDiffusionXLPAGImg2ImgPipeline",
"StableDiffusionXLPAGInpaintPipeline",
"StableDiffusionXLPAGPipeline",
"StableDiffusionXLPipeline",
"StableDiffusionXLAutoPipeline",
"StableUnCLIPImg2ImgPipeline",
"StableUnCLIPPipeline",
"StableVideoDiffusionPipeline",
@@ -541,6 +546,24 @@ else:
]
)
try:
if not (is_torch_available() and is_transformers_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_torch_and_transformers_objects # noqa F403
_import_structure["utils.dummy_torch_and_transformers_objects"] = [
name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_")
]
else:
_import_structure["modular_pipelines"].extend(
[
"StableDiffusionXLAutoPipeline",
"StableDiffusionXLModularLoader",
]
)
try:
if not (is_torch_available() and is_transformers_available() and is_opencv_available()):
raise OptionalDependencyNotAvailable()
@@ -864,12 +887,16 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
KarrasVePipeline,
LDMPipeline,
LDMSuperResolutionPipeline,
ModularLoader,
PNDMPipeline,
RePaintPipeline,
ScoreSdeVePipeline,
StableDiffusionMixin,
)
from .modular_pipelines import (
ModularLoader,
ComponentSpec,
ComponentsManager,
)
from .quantizers import DiffusersQuantizer
from .schedulers import (
AmusedScheduler,
@@ -1097,12 +1124,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLInstructPix2PixPipeline,
StableDiffusionXLModularLoader,
StableDiffusionXLPAGImg2ImgPipeline,
StableDiffusionXLPAGInpaintPipeline,
StableDiffusionXLPAGPipeline,
StableDiffusionXLPipeline,
StableDiffusionXLAutoPipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
StableVideoDiffusionPipeline,
@@ -1127,7 +1152,16 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
WuerstchenDecoderPipeline,
WuerstchenPriorPipeline,
)
try:
if not (is_torch_available() and is_transformers_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .modular_pipelines import (
StableDiffusionXLAutoPipeline,
StableDiffusionXLModularLoader,
)
try:
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
raise OptionalDependencyNotAvailable()

View File

@@ -13,14 +13,14 @@
# limitations under the License.
import math
from typing import Optional, List, TYPE_CHECKING
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
import torch
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..pipelines.modular_pipeline import BlockState
from ..modular_pipelines.modular_pipeline import BlockState
class AdaptiveProjectedGuidance(BaseGuidance):
@@ -73,14 +73,18 @@ class AdaptiveProjectedGuidance(BaseGuidance):
self.use_original_formulation = use_original_formulation
self.momentum_buffer = None
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import List, Optional, Union, TYPE_CHECKING
from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple
import torch
@@ -22,7 +22,7 @@ from ..hooks.layer_skip import _apply_layer_skip_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..pipelines.modular_pipeline import BlockState
from ..modular_pipelines.modular_pipeline import BlockState
class AutoGuidance(BaseGuidance):
@@ -120,11 +120,15 @@ class AutoGuidance(BaseGuidance):
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
registry.remove_hook(name, recurse=True)
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches

View File

@@ -13,14 +13,14 @@
# limitations under the License.
import math
from typing import Optional, List, TYPE_CHECKING
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
import torch
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..pipelines.modular_pipeline import BlockState
from ..modular_pipelines.modular_pipeline import BlockState
class ClassifierFreeGuidance(BaseGuidance):
@@ -75,11 +75,15 @@ class ClassifierFreeGuidance(BaseGuidance):
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches

View File

@@ -13,14 +13,14 @@
# limitations under the License.
import math
from typing import Optional, List, TYPE_CHECKING
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
import torch
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..pipelines.modular_pipeline import BlockState
from ..modular_pipelines.modular_pipeline import BlockState
class ClassifierFreeZeroStarGuidance(BaseGuidance):
@@ -73,11 +73,15 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches

View File

@@ -20,7 +20,7 @@ from ..utils import get_logger
if TYPE_CHECKING:
from ..pipelines.modular_pipeline import BlockState
from ..modular_pipelines.modular_pipeline import BlockState
logger = get_logger(__name__) # pylint: disable=invalid-name
@@ -171,10 +171,10 @@ class BaseGuidance:
Returns:
`BlockState`: The prepared batch of data.
"""
from ..pipelines.modular_pipeline import BlockState
from ..modular_pipelines.modular_pipeline import BlockState
if input_fields is None:
raise ValueError("Input fields have not been set. Please call `set_input_fields` before preparing inputs.")
raise ValueError("Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs.")
data_batch = {}
for key, value in input_fields.items():
try:
@@ -186,7 +186,7 @@ class BaseGuidance:
# We've already checked that value is a string or a tuple of strings with length 2
pass
except AttributeError:
raise ValueError(f"Expected `data` to have attribute(s) {value}, but it does not. Please check the input data.")
logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
data_batch[cls._identifier_key] = identifier
return BlockState(**data_batch)

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import List, Optional, Union, TYPE_CHECKING
from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple
import torch
@@ -22,7 +22,7 @@ from ..hooks.layer_skip import _apply_layer_skip_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..pipelines.modular_pipeline import BlockState
from ..modular_pipelines.modular_pipeline import BlockState
class SkipLayerGuidance(BaseGuidance):
@@ -156,7 +156,11 @@ class SkipLayerGuidance(BaseGuidance):
for hook_name in self._skip_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
@@ -168,7 +172,7 @@ class SkipLayerGuidance(BaseGuidance):
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], input_predictions[i])
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
data_batches.append(data_batch)
return data_batches

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import List, Optional, Union, TYPE_CHECKING
from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple
import torch
@@ -22,7 +22,7 @@ from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig,
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..pipelines.modular_pipeline import BlockState
from ..modular_pipelines.modular_pipeline import BlockState
class SmoothedEnergyGuidance(BaseGuidance):
@@ -149,7 +149,11 @@ class SmoothedEnergyGuidance(BaseGuidance):
for hook_name in self._seg_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
@@ -161,7 +165,7 @@ class SmoothedEnergyGuidance(BaseGuidance):
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], input_predictions[i])
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
data_batches.append(data_batch)
return data_batches

View File

@@ -13,14 +13,14 @@
# limitations under the License.
import math
from typing import Optional, List, TYPE_CHECKING
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
import torch
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..pipelines.modular_pipeline import BlockState
from ..modular_pipelines.modular_pipeline import BlockState
class TangentialClassifierFreeGuidance(BaseGuidance):
@@ -62,11 +62,15 @@ class TangentialClassifierFreeGuidance(BaseGuidance):
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches

View File

@@ -30,6 +30,8 @@ logger = get_logger(__name__) # pylint: disable=invalid-name
_LAYER_SKIP_HOOK = "layer_skip_hook"
# Aryan/YiYi TODO: we need to make guider class a config mixin so I think this is not needed
# either remove or make it serializable
@dataclass
class LayerSkipConfig:
r"""

View File

@@ -0,0 +1,82 @@
from typing import TYPE_CHECKING
from ..utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
# These modules contain pipelines from multiple libraries/frameworks
_dummy_objects = {}
_import_structure = {}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils import dummy_pt_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_pt_objects))
else:
_import_structure["modular_pipeline"] = [
"ModularPipelineMixin",
"PipelineBlock",
"AutoPipelineBlocks",
"SequentialPipelineBlocks",
"LoopSequentialPipelineBlocks",
"ModularLoader",
"PipelineState",
"BlockState",
]
_import_structure["modular_pipeline_utils"] = [
"ComponentSpec",
"ConfigSpec",
"InputParam",
"OutputParam",
]
_import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoPipeline", "StableDiffusionXLModularLoader"]
_import_structure["components_manager"] = ["ComponentsManager"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_pt_objects import * # noqa F403
else:
from .modular_pipeline import (
AutoPipelineBlocks,
BlockState,
LoopSequentialPipelineBlocks,
ModularLoader,
ModularPipelineMixin,
PipelineBlock,
PipelineState,
SequentialPipelineBlocks,
)
from .modular_pipeline_utils import (
ComponentSpec,
ConfigSpec,
InputParam,
OutputParam,
)
from .stable_diffusion_xl import (
StableDiffusionXLAutoPipeline,
StableDiffusionXLModularLoader,
)
from .components_manager import ComponentsManager
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -29,6 +29,9 @@ from ..models.modeling_utils import ModelMixin
from .modular_pipeline_utils import ComponentSpec
import uuid
if is_accelerate_available():
from accelerate.hooks import ModelHook, add_hook_to_module, remove_hook_from_module
from accelerate.state import PartialState
@@ -231,8 +234,6 @@ class AutoOffloadStrategy:
from .modular_pipeline_utils import ComponentSpec
import uuid
class ComponentsManager:
def __init__(self):
self.components = OrderedDict()
@@ -242,78 +243,122 @@ class ComponentsManager:
self._auto_offload_enabled = False
def _get_by_collection(self, collection: str):
def _lookup_ids(self, name=None, collection=None, load_id=None, components: OrderedDict = None):
"""
Select components by collection name.
Lookup component_ids by name, collection, or load_id.
"""
selected_components = {}
if collection in self.collections:
component_ids = self.collections[collection]
for component_id in component_ids:
selected_components[component_id] = self.components[component_id]
return selected_components
if components is None:
components = self.components
if name:
ids_by_name = set()
for component_id, component in components.items():
comp_name = self._id_to_name(component_id)
if comp_name == name:
ids_by_name.add(component_id)
else:
ids_by_name = set(components.keys())
if collection:
ids_by_collection = set()
for component_id, component in components.items():
if component_id in self.collections[collection]:
ids_by_collection.add(component_id)
else:
ids_by_collection = set(components.keys())
if load_id:
ids_by_load_id = set()
for name, component in components.items():
if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id:
ids_by_load_id.add(name)
else:
ids_by_load_id = set(components.keys())
ids = ids_by_name.intersection(ids_by_collection).intersection(ids_by_load_id)
return ids
def _get_by_load_id(self, load_id: str):
"""
Select components by its load_id.
"""
selected_components = {}
for name, component in self.components.items():
if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id:
selected_components[name] = component
return selected_components
@staticmethod
def _id_to_name(component_id: str):
return "_".join(component_id.split("_")[:-1])
def add(self, name, component, collection: Optional[str] = None):
for comp_id, comp in self.components.items():
if comp == component:
logger.warning(f"Component '{name}' already exists in ComponentsManager")
return comp_id
component_id = f"{name}_{uuid.uuid4()}"
# check for duplicated components
for comp_id, comp in self.components.items():
if comp == component:
comp_name = self._id_to_name(comp_id)
if comp_name == name:
logger.warning(
f"component '{name}' already exists as '{comp_id}'"
)
component_id = comp_id
break
else:
logger.warning(
f"Adding component '{name}' as '{component_id}', but it is duplicate of '{comp_id}'"
f"To remove a duplicate, call `components_manager.remove('<component_id>')`."
)
# check for duplicated load_id and warn (we do not delete for you)
if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null":
components_with_same_load_id = self._get_by_load_id(component._diffusers_load_id)
components_with_same_load_id = self._lookup_ids(load_id=component._diffusers_load_id)
components_with_same_load_id = [id for id in components_with_same_load_id if id != component_id]
if components_with_same_load_id:
existing = ", ".join(components_with_same_load_id.keys())
existing = ", ".join(components_with_same_load_id)
logger.warning(
f"Component '{name}' has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. "
f"To remove a duplicate, call `components_manager.remove('<component_name>')`."
f"Adding component '{component_id}', but it has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. "
f"To remove a duplicate, call `components_manager.remove('<component_id>')`."
)
# add component to components manager
self.components[component_id] = component
self.added_time[component_id] = time.time()
if collection:
if collection not in self.collections:
self.collections[collection] = set()
self.collections[collection].add(component_id)
if not component_id in self.collections[collection]:
comp_ids_in_collection = self._lookup_ids(name=name, collection=collection)
for comp_id in comp_ids_in_collection:
logger.info(f"Removing existing {name} from collection '{collection}': {comp_id}")
self.remove(comp_id)
self.collections[collection].add(component_id)
logger.info(f"Added component '{name}' in collection '{collection}': {component_id}")
else:
logger.info(f"Added component '{name}' as '{component_id}'")
if self._auto_offload_enabled:
self.enable_auto_cpu_offload(self._auto_offload_device)
logger.info(f"Added component '{name}' to ComponentsManager as '{component_id}'")
return component_id
def remove(self, name: Union[str, List[str]]):
def remove(self, component_id: str = None):
if name not in self.components:
logger.warning(f"Component '{name}' not found in ComponentsManager")
if component_id not in self.components:
logger.warning(f"Component '{component_id}' not found in ComponentsManager")
return
self.components.pop(name)
self.added_time.pop(name)
component = self.components.pop(component_id)
self.added_time.pop(component_id)
for collection in self.collections:
if name in self.collections[collection]:
self.collections[collection].remove(name)
if component_id in self.collections[collection]:
self.collections[collection].remove(component_id)
if self._auto_offload_enabled:
self.enable_auto_cpu_offload(self._auto_offload_device)
else:
if isinstance(component, torch.nn.Module):
component.to("cpu")
del component
import gc
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def get(self, names: Union[str, List[str]] = None, collection: Optional[str] = None, load_id: Optional[str] = None,
as_name_component_tuples: bool = False):
@@ -342,16 +387,8 @@ class ComponentsManager:
or list of (base_name, component) tuples if as_name_component_tuples=True
"""
if collection:
if collection not in self.collections:
logger.warning(f"Collection '{collection}' not found in ComponentsManager")
return [] if as_name_component_tuples else {}
components = self._get_by_collection(collection)
else:
components = self.components
if load_id:
components = self._get_by_load_id(load_id)
selected_ids = self._lookup_ids(collection=collection, load_id=load_id)
components = {k: self.components[k] for k in selected_ids}
# Helper to extract base name from component_id
def get_base_name(component_id):
@@ -541,11 +578,11 @@ class ComponentsManager:
self._auto_offload_enabled = False
# YiYi TODO: add quantization info
def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = None) -> Optional[Dict[str, Any]]:
def get_model_info(self, component_id: str, fields: Optional[Union[str, List[str]]] = None) -> Optional[Dict[str, Any]]:
"""Get comprehensive information about a component.
Args:
name: Name of the component to get info for
component_id: Name of the component to get info for
fields: Optional field(s) to return. Can be a string for single field or list of fields.
If None, returns all fields.
@@ -554,16 +591,16 @@ class ComponentsManager:
If fields is specified, returns only those fields.
If a single field is requested as string, returns just that field's value.
"""
if name not in self.components:
raise ValueError(f"Component '{name}' not found in ComponentsManager")
if component_id not in self.components:
raise ValueError(f"Component '{component_id}' not found in ComponentsManager")
component = self.components[name]
component = self.components[component_id]
# Build complete info dict first
info = {
"model_id": name,
"added_time": self.added_time[name],
"collection": next((coll for coll, comps in self.collections.items() if name in comps), None),
"model_id": component_id,
"added_time": self.added_time[component_id],
"collection": ", ".join([coll for coll, comps in self.collections.items() if component_id in comps]) or None,
}
# Additional info for torch.nn.Module components
@@ -649,11 +686,19 @@ class ComponentsManager:
]
max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15
# Collection names
collection_names = [
next((coll for coll, comps in self.collections.items() if name in comps), "N/A")
for name in self.components.keys()
]
# Get all collections for each component
component_collections = {}
for name in self.components.keys():
component_collections[name] = []
for coll, comps in self.collections.items():
if name in comps:
component_collections[name].append(coll)
if not component_collections[name]:
component_collections[name] = ["N/A"]
# Find the maximum collection name length
all_collections = [coll for colls in component_collections.values() for coll in colls]
max_collection_len = max(10, max(len(str(c)) for c in all_collections)) if all_collections else 10
col_widths = {
"name": max(15, max(len(name) for name in simple_names)),
@@ -662,7 +707,7 @@ class ComponentsManager:
"dtype": 15,
"size": 10,
"load_id": max_load_id_len,
"collection": max(10, max(len(str(c)) for c in collection_names))
"collection": max_collection_len
}
# Create the header lines
@@ -691,11 +736,21 @@ class ComponentsManager:
device_str = format_device(component, info)
dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A"
load_id = get_load_id(component)
collection = info["collection"] or "N/A"
# Print first collection on the main line
first_collection = component_collections[name][0] if component_collections[name] else "N/A"
output += f"{simple_name:<{col_widths['name']}} | {info['class_name']:<{col_widths['class']}} | "
output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | "
output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {collection}\n"
output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {first_collection}\n"
# Print additional collections on separate lines if they exist
for i in range(1, len(component_collections[name])):
collection = component_collections[name][i]
output += f"{'':<{col_widths['name']}} | {'':<{col_widths['class']}} | "
output += f"{'':<{col_widths['device']}} | {'':<{col_widths['dtype']}} | "
output += f"{'':<{col_widths['size']}} | {'':<{col_widths['load_id']}} | {collection}\n"
output += dash_line
# Other components section
@@ -711,9 +766,17 @@ class ComponentsManager:
for name, component in others.items():
info = self.get_model_info(name)
simple_name = get_simple_name(name)
collection = info["collection"] or "N/A"
output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {collection}\n"
# Print first collection on the main line
first_collection = component_collections[name][0] if component_collections[name] else "N/A"
output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {first_collection}\n"
# Print additional collections on separate lines if they exist
for i in range(1, len(component_collections[name])):
collection = component_collections[name][i]
output += f"{'':<{col_widths['name']}} | {'':<{col_widths['class']}} | {collection}\n"
output += dash_line
# Add additional component info
@@ -775,7 +838,7 @@ class ComponentsManager:
f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')"
)
def get_one(self, name: Optional[str] = None, collection: Optional[str] = None, load_id: Optional[str] = None) -> Any:
def get_one(self, component_id: Optional[str] = None, name: Optional[str] = None, collection: Optional[str] = None, load_id: Optional[str] = None) -> Any:
"""
Get a single component by name. Raises an error if multiple components match or none are found.
@@ -790,6 +853,15 @@ class ComponentsManager:
Raises:
ValueError: If no components match or multiple components match
"""
# if component_id is provided, return the component
if component_id is not None and (name is not None or collection is not None or load_id is not None):
raise ValueError(" if component_id is provided, name, collection, and load_id must be None")
elif component_id is not None:
if component_id not in self.components:
raise ValueError(f"Component '{component_id}' not found in ComponentsManager")
return self.components[component_id]
results = self.get(name, collection, load_id)
if not results:

View File

@@ -71,34 +71,31 @@ class ComponentSpec:
self.default_creation_method == other.default_creation_method)
@classmethod
def from_component(cls, name: str, component: torch.nn.Module) -> Any:
"""Create a ComponentSpec from a Component created by `create` method."""
def from_component(cls, name: str, component: Any) -> Any:
"""Create a ComponentSpec from a Component created by `create` or `load` method."""
if not hasattr(component, "_diffusers_load_id"):
raise ValueError("Component is not created by `create` method")
raise ValueError("Component is not created by `create` or `load` method")
# throw a error if component is created with `create` method but not a subclass of ConfigMixin
# YiYi TODO: remove this check if we remove support for non configmixin in `create()` method
if component._diffusers_load_id == "null" and not isinstance(component, ConfigMixin):
raise ValueError(
"We currently only support creating ComponentSpec from a component with "
"created with `ComponentSpec.load` method"
"or created with `ComponentSpec.create` and a subclass of ConfigMixin"
)
type_hint = component.__class__
default_creation_method = "from_config" if component._diffusers_load_id == "null" else "from_pretrained"
if component._diffusers_load_id == "null" and isinstance(component, ConfigMixin):
if isinstance(component, ConfigMixin):
config = component.config
else:
config = None
load_spec = cls.decode_load_id(component._diffusers_load_id)
return cls(name=name, type_hint=type_hint, config=config, **load_spec)
@classmethod
def from_load_id(cls, load_id: str, name: Optional[str] = None) -> Any:
"""Create a ComponentSpec from a load_id string."""
if load_id == "null":
raise ValueError("Cannot create ComponentSpec from null load_id")
# Decode the load_id into a dictionary of loading fields
load_fields = cls.decode_load_id(load_id)
# Create a new ComponentSpec instance with the decoded fields
return cls(name=name, **load_fields)
return cls(name=name, type_hint=type_hint, config=config, default_creation_method=default_creation_method, **load_spec)
@classmethod
def loading_fields(cls) -> List[str]:
@@ -137,7 +134,7 @@ class ComponentSpec:
"revision": "revision"
}
If a segment value is "null", it's replaced with None.
Returns None if load_id is "null" (indicating component not loaded from pretrained).
Returns None if load_id is "null" (indicating component not created with `load` method).
"""
# Get all loading fields in order
@@ -158,20 +155,12 @@ class ComponentSpec:
return result
# YiYi TODO: add validator
def create(self, **kwargs) -> Any:
"""Create the component using the preferred creation method."""
# from_pretrained creation
if self.default_creation_method == "from_pretrained":
return self.create_from_pretrained(**kwargs)
elif self.default_creation_method == "from_config":
# from_config creation
return self.create_from_config(**kwargs)
else:
raise ValueError(f"Invalid creation method: {self.default_creation_method}")
def create_from_config(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any:
# YiYi TODO: I think we should only support ConfigMixin for this method (after we make guider and image_processors config mixin)
# otherwise we cannot do spec -> spec.create() -> component -> ComponentSpec.from_component(component)
# the config info is lost in the process
# remove error check in from_component spec and ModularLoader.update() if we remove support for non configmixin in `create()` method
def create(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any:
"""Create component using from_config with config."""
if self.type_hint is None or not isinstance(self.type_hint, type):
@@ -201,34 +190,35 @@ class ComponentSpec:
return component
# YiYi TODO: add guard for type of model, if it is supported by from_pretrained
def create_from_pretrained(self, **kwargs) -> Any:
"""Create component using from_pretrained."""
def load(self, **kwargs) -> Any:
"""Load component using from_pretrained."""
# select loading fields from kwargs passed from user: e.g. repo, subfolder, variant, revision, note the list could change
passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs}
# merge loading field value in the spec with user passed values to create load_kwargs
load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()}
# repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path
repo = load_kwargs.pop("repo", None)
if repo is None:
raise ValueError(f"`repo` info is required when using from_pretrained creation method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)")
raise ValueError(f"`repo` info is required when using `load` method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)")
if self.type_hint is None:
try:
from diffusers import AutoModel
component = AutoModel.from_pretrained(repo, **load_kwargs, **kwargs)
except Exception as e:
raise ValueError(f"Error creating {self.name} without `type_hint` from pretrained: {e}")
raise ValueError(f"Unable to load {self.name} without `type_hint`: {e}")
# update type_hint if AutoModel load successfully
self.type_hint = component.__class__
else:
try:
component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs)
except Exception as e:
raise ValueError(f"Error creating {self.name}[{self.type_hint.__name__}] from pretrained: {e}")
raise ValueError(f"Unable to load {self.name} using load method: {e}")
if repo != self.repo:
self.repo = repo
for k, v in passed_loading_kwargs.items():
if v is not None:
setattr(self, k, v)
self.repo = repo
for k, v in load_kwargs.items():
setattr(self, k, v)
component._diffusers_load_id = self.load_id
return component
@@ -241,14 +231,22 @@ class ConfigSpec:
name: str
default: Any
description: Optional[str] = None
# YiYi Notes: both inputs and intermediates_inputs are InputParam objects
# however some fields are not relevant for intermediates_inputs
# e.g. unlike inputs, required only used in docstring for intermediate_inputs, we do not check if a required intermediate inputs is passed
# default is not used for intermediates_inputs, we only use default from inputs, so it is ignored if it is set for intermediates_inputs
# -> should we use different class for inputs and intermediates_inputs?
@dataclass
class InputParam:
"""Specification for an input parameter."""
name: str
name: str = None
type_hint: Any = None
default: Any = None
required: bool = False
description: str = ""
kwargs_type: str = None # YiYi Notes: experimenting with this, not sure if we should keep it
def __repr__(self):
return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"
@@ -260,6 +258,7 @@ class OutputParam:
name: str
type_hint: Any = None
description: str = ""
kwargs_type: str = None
def __repr__(self):
return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>"
@@ -320,7 +319,11 @@ def format_intermediates_short(intermediates_inputs, required_intermediates_inpu
if inp.name in required_intermediates_inputs:
input_parts.append(f"Required({inp.name})")
else:
input_parts.append(inp.name)
if inp.name is None and inp.kwargs_type is not None:
inp_name = "*_" + inp.kwargs_type
else:
inp_name = inp.name
input_parts.append(inp_name)
# Handle modified variables (appear in both inputs and outputs)
inputs_set = {inp.name for inp in intermediates_inputs}

View File

@@ -0,0 +1,51 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["modular_pipeline_presets"] = ["StableDiffusionXLAutoPipeline"]
_import_structure["modular_loader"] = ["StableDiffusionXLModularLoader"]
_import_structure["encoders"] = ["StableDiffusionXLAutoIPAdapterStep", "StableDiffusionXLTextEncoderStep", "StableDiffusionXLAutoVaeEncoderStep"]
_import_structure["decoders"] = ["StableDiffusionXLAutoDecodeStep"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .modular_pipeline_presets import StableDiffusionXLAutoPipeline
from .modular_loader import StableDiffusionXLModularLoader
from .encoders import StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoVaeEncoderStep
from .decoders import StableDiffusionXLAutoDecodeStep
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,215 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, List, Optional, Tuple, Union, Dict
import PIL
import torch
import numpy as np
from collections import OrderedDict
from ...image_processor import VaeImageProcessor, PipelineImageInput
from ...models import AutoencoderKL
from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
from ...utils import logging
from ...pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from ...configuration_utils import FrozenDict
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from ..modular_pipeline import (
AutoPipelineBlocks,
PipelineBlock,
PipelineState,
SequentialPipelineBlocks,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class StableDiffusionXLDecodeStep(PipelineBlock):
model_name = "stable-diffusion-xl"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKL),
ComponentSpec(
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 8}),
default_creation_method="from_config"),
]
@property
def description(self) -> str:
return "Step that decodes the denoised latents into images"
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("output_type", default="pil"),
]
@property
def intermediates_inputs(self) -> List[str]:
return [InputParam("latents", required=True, type_hint=torch.Tensor, description="The denoised latents from the denoising step")]
@property
def intermediates_outputs(self) -> List[str]:
return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array")]
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae with self -> components
@staticmethod
def upcast_vae(components):
dtype = components.vae.dtype
components.vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = isinstance(
components.vae.decoder.mid_block.attentions[0].processor,
(
AttnProcessor2_0,
XFormersAttnProcessor,
),
)
# if xformers or torch_2_0 is used attention block does not need
# to be in float32 which can save lots of memory
if use_torch_2_0_or_xformers:
components.vae.post_quant_conv.to(dtype)
components.vae.decoder.conv_in.to(dtype)
components.vae.decoder.mid_block.to(dtype)
@torch.no_grad()
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
if not block_state.output_type == "latent":
latents = block_state.latents
# make sure the VAE is in float32 mode, as it overflows in float16
block_state.needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast
if block_state.needs_upcasting:
self.upcast_vae(components)
latents = latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype)
elif latents.dtype != components.vae.dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
components.vae = components.vae.to(latents.dtype)
# unscale/denormalize the latents
# denormalize with the mean and std if available and not None
block_state.has_latents_mean = (
hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None
)
block_state.has_latents_std = (
hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None
)
if block_state.has_latents_mean and block_state.has_latents_std:
block_state.latents_mean = (
torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
)
block_state.latents_std = (
torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
)
latents = latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean
else:
latents = latents / components.vae.config.scaling_factor
block_state.images = components.vae.decode(latents, return_dict=False)[0]
# cast back to fp16 if needed
if block_state.needs_upcasting:
components.vae.to(dtype=torch.float16)
else:
block_state.images = block_state.latents
# apply watermark if available
if hasattr(components, "watermark") and components.watermark is not None:
block_state.images = components.watermark.apply_watermark(block_state.images)
block_state.images = components.image_processor.postprocess(block_state.images, output_type=block_state.output_type)
self.add_block_state(state, block_state)
return components, state
class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
model_name = "stable-diffusion-xl"
@property
def description(self) -> str:
return "A post-processing step that overlays the mask on the image (inpainting task only).\n" + \
"only needed when you are using the `padding_mask_crop` option when pre-processing the image and mask"
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("image", required=True),
InputParam("mask_image", required=True),
InputParam("padding_mask_crop"),
]
@property
def intermediates_inputs(self) -> List[str]:
return [
InputParam("images", required=True, type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images from the decode step"),
InputParam("crops_coords", required=True, type_hint=Tuple[int, int], description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.")
]
@property
def intermediates_outputs(self) -> List[str]:
return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images with the mask overlayed")]
@torch.no_grad()
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
if block_state.padding_mask_crop is not None and block_state.crops_coords is not None:
block_state.images = [components.image_processor.apply_overlay(block_state.mask_image, block_state.image, i, block_state.crops_coords) for i in block_state.images]
self.add_block_state(state, block_state)
return components, state
class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks):
block_classes = [StableDiffusionXLDecodeStep, StableDiffusionXLInpaintOverlayMaskStep]
block_names = ["decode", "mask_overlay"]
@property
def description(self):
return "Inpaint decode step that decode the denoised latents into images outputs.\n" + \
"This is a sequential pipeline blocks:\n" + \
" - `StableDiffusionXLDecodeStep` is used to decode the denoised latents into images\n" + \
" - `StableDiffusionXLInpaintOverlayMaskStep` is used to overlay the mask on the image"
class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks):
block_classes = [StableDiffusionXLInpaintDecodeStep, StableDiffusionXLDecodeStep]
block_names = ["inpaint", "non-inpaint"]
block_trigger_inputs = ["padding_mask_crop", None]
@property
def description(self):
return "Decode step that decode the denoised latents into images outputs.\n" + \
"This is an auto pipeline block that works for inpainting and non-inpainting tasks.\n" + \
" - `StableDiffusionXLInpaintDecodeStep` (inpaint) is used when `padding_mask_crop` is provided.\n" + \
" - `StableDiffusionXLDecodeStep` (non-inpaint) is used when `padding_mask_crop` is not provided."

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,858 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, List, Optional, Tuple, Union, Dict
import PIL
import torch
from collections import OrderedDict
from ...image_processor import VaeImageProcessor, PipelineImageInput
from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin
from ...models import ControlNetModel, ImageProjection, UNet2DConditionModel, AutoencoderKL, ControlNetUnionModel
from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
from ...models.lora import adjust_lora_scale_text_encoder
from ...utils import (
USE_PEFT_BACKEND,
logging,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor, unwrap_module
from ...pipelines.controlnet.multicontrolnet import MultiControlNetModel
from ...configuration_utils import FrozenDict
from transformers import (
CLIPTextModel,
CLIPImageProcessor,
CLIPTextModelWithProjection,
CLIPTokenizer,
CLIPVisionModelWithProjection,
)
from ...schedulers import EulerDiscreteScheduler
from ...guiders import ClassifierFreeGuidance
from .modular_loader import StableDiffusionXLModularLoader
from ..modular_pipeline import PipelineBlock, PipelineState, AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam, ConfigSpec
import numpy as np
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
class StableDiffusionXLIPAdapterStep(PipelineBlock):
model_name = "stable-diffusion-xl"
@property
def description(self) -> str:
return (
"IP Adapter step that handles all the ip adapter related tasks: Load/unload ip adapter weights into unet, prepare ip adapter image embeddings, etc"
" See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)"
" for more details"
)
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("image_encoder", CLIPVisionModelWithProjection),
ComponentSpec("feature_extractor", CLIPImageProcessor, config=FrozenDict({"size": 224, "crop_size": 224}), default_creation_method="from_config"),
ComponentSpec("unet", UNet2DConditionModel),
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 7.5}),
default_creation_method="from_config"),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
"ip_adapter_image",
PipelineImageInput,
required=True,
description="The image(s) to be used as ip adapter"
)
]
@property
def intermediates_outputs(self) -> List[OutputParam]:
return [
OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"),
OutputParam("negative_ip_adapter_embeds", type_hint=torch.Tensor, description="Negative IP adapter image embeddings")
]
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self -> components
@staticmethod
def encode_image(components, image, device, num_images_per_prompt, output_hidden_states=None):
dtype = next(components.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = components.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
if output_hidden_states:
image_enc_hidden_states = components.image_encoder(image, output_hidden_states=True).hidden_states[-2]
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_enc_hidden_states = components.image_encoder(
torch.zeros_like(image), output_hidden_states=True
).hidden_states[-2]
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
num_images_per_prompt, dim=0
)
return image_enc_hidden_states, uncond_image_enc_hidden_states
else:
image_embeds = components.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds
# modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, prepare_unconditional_embeds
):
image_embeds = []
if prepare_unconditional_embeds:
negative_image_embeds = []
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
if len(ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
components, single_ip_adapter_image, device, 1, output_hidden_state
)
image_embeds.append(single_image_embeds[None, :])
if prepare_unconditional_embeds:
negative_image_embeds.append(single_negative_image_embeds[None, :])
else:
for single_image_embeds in ip_adapter_image_embeds:
if prepare_unconditional_embeds:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
negative_image_embeds.append(single_negative_image_embeds)
image_embeds.append(single_image_embeds)
ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
if prepare_unconditional_embeds:
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)
return ip_adapter_image_embeds
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
block_state.device = components._execution_device
block_state.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds(
components,
ip_adapter_image=block_state.ip_adapter_image,
ip_adapter_image_embeds=None,
device=block_state.device,
num_images_per_prompt=1,
prepare_unconditional_embeds=block_state.prepare_unconditional_embeds,
)
if block_state.prepare_unconditional_embeds:
block_state.negative_ip_adapter_embeds = []
for i, image_embeds in enumerate(block_state.ip_adapter_embeds):
negative_image_embeds, image_embeds = image_embeds.chunk(2)
block_state.negative_ip_adapter_embeds.append(negative_image_embeds)
block_state.ip_adapter_embeds[i] = image_embeds
self.add_block_state(state, block_state)
return components, state
class StableDiffusionXLTextEncoderStep(PipelineBlock):
model_name = "stable-diffusion-xl"
@property
def description(self) -> str:
return(
"Text Encoder step that generate text_embeddings to guide the image generation"
)
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("text_encoder", CLIPTextModel),
ComponentSpec("text_encoder_2", CLIPTextModelWithProjection),
ComponentSpec("tokenizer", CLIPTokenizer),
ComponentSpec("tokenizer_2", CLIPTokenizer),
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 7.5}),
default_creation_method="from_config"),
]
@property
def expected_configs(self) -> List[ConfigSpec]:
return [ConfigSpec("force_zeros_for_empty_prompt", True)]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("prompt_2"),
InputParam("negative_prompt"),
InputParam("negative_prompt_2"),
InputParam("cross_attention_kwargs"),
InputParam("clip_skip"),
]
@property
def intermediates_outputs(self) -> List[OutputParam]:
return [
OutputParam("prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields",description="text embeddings used to guide the image generation"),
OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative text embeddings used to guide the image generation"),
OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="pooled text embeddings used to guide the image generation"),
OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative pooled text embeddings used to guide the image generation"),
]
@staticmethod
def check_inputs(block_state):
if block_state.prompt is not None and (not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
elif block_state.prompt_2 is not None and (not isinstance(block_state.prompt_2, str) and not isinstance(block_state.prompt_2, list)):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(block_state.prompt_2)}")
@staticmethod
def encode_prompt(
components,
prompt: str,
prompt_2: Optional[str] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prepare_unconditional_embeds: bool = True,
negative_prompt: Optional[str] = None,
negative_prompt_2: Optional[str] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
pooled_prompt_embeds: Optional[torch.Tensor] = None,
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
lora_scale: Optional[float] = None,
clip_skip: Optional[int] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in both text-encoders
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
prepare_unconditional_embeds (`bool`):
whether to use prepare unconditional embeddings or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
pooled_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
lora_scale (`float`, *optional*):
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
"""
device = device or components._execution_device
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(components, StableDiffusionXLLoraLoaderMixin):
components._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if components.text_encoder is not None:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(components.text_encoder, lora_scale)
else:
scale_lora_layers(components.text_encoder, lora_scale)
if components.text_encoder_2 is not None:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(components.text_encoder_2, lora_scale)
else:
scale_lora_layers(components.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# Define tokenizers and text encoders
tokenizers = [components.tokenizer, components.tokenizer_2] if components.tokenizer is not None else [components.tokenizer_2]
text_encoders = (
[components.text_encoder, components.text_encoder_2] if components.text_encoder is not None else [components.text_encoder_2]
)
if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
# textual inversion: process multi-vector tokens if necessary
prompt_embeds_list = []
prompts = [prompt, prompt_2]
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
if isinstance(components, TextualInversionLoaderMixin):
prompt = components.maybe_convert_prompt(prompt, tokenizer)
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {tokenizer.model_max_length} tokens: {removed_text}"
)
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
# "2" because SDXL always indexes from the penultimate layer.
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
prompt_embeds_list.append(prompt_embeds)
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
# get unconditional embeddings for classifier free guidance
zero_out_negative_prompt = negative_prompt is None and components.config.force_zeros_for_empty_prompt
if prepare_unconditional_embeds and negative_prompt_embeds is None and zero_out_negative_prompt:
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
elif prepare_unconditional_embeds and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt_2 = negative_prompt_2 or negative_prompt
# normalize str to list
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
negative_prompt_2 = (
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
)
uncond_tokens: List[str]
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = [negative_prompt, negative_prompt_2]
negative_prompt_embeds_list = []
for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
if isinstance(components, TextualInversionLoaderMixin):
negative_prompt = components.maybe_convert_prompt(negative_prompt, tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = tokenizer(
negative_prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
negative_prompt_embeds = text_encoder(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
if components.text_encoder_2 is not None:
prompt_embeds = prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device)
else:
prompt_embeds = prompt_embeds.to(dtype=components.unet.dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
if prepare_unconditional_embeds:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
if components.text_encoder_2 is not None:
negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device)
else:
negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.unet.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
if prepare_unconditional_embeds:
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
if components.text_encoder is not None:
if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(components.text_encoder, lora_scale)
if components.text_encoder_2 is not None:
if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(components.text_encoder_2, lora_scale)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState:
# Get inputs and intermediates
block_state = self.get_block_state(state)
self.check_inputs(block_state)
block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
block_state.device = components._execution_device
# Encode input prompt
block_state.text_encoder_lora_scale = (
block_state.cross_attention_kwargs.get("scale", None) if block_state.cross_attention_kwargs is not None else None
)
(
block_state.prompt_embeds,
block_state.negative_prompt_embeds,
block_state.pooled_prompt_embeds,
block_state.negative_pooled_prompt_embeds,
) = self.encode_prompt(
components,
block_state.prompt,
block_state.prompt_2,
block_state.device,
1,
block_state.prepare_unconditional_embeds,
block_state.negative_prompt,
block_state.negative_prompt_2,
prompt_embeds=None,
negative_prompt_embeds=None,
pooled_prompt_embeds=None,
negative_pooled_prompt_embeds=None,
lora_scale=block_state.text_encoder_lora_scale,
clip_skip=block_state.clip_skip,
)
# Add outputs
self.add_block_state(state, block_state)
return components, state
class StableDiffusionXLVaeEncoderStep(PipelineBlock):
model_name = "stable-diffusion-xl"
@property
def description(self) -> str:
return (
"Vae Encoder step that encode the input image into a latent representation"
)
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKL),
ComponentSpec(
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 8}),
default_creation_method="from_config"),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("image", required=True),
InputParam("height"),
InputParam("width"),
]
@property
def intermediates_inputs(self) -> List[InputParam]:
return [
InputParam("generator"),
InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
InputParam("preprocess_kwargs", type_hint=Optional[dict], description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]")]
@property
def intermediates_outputs(self) -> List[OutputParam]:
return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation")]
# Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
# YiYi TODO: update the _encode_vae_image so that we can use #Coped from
def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator):
latents_mean = latents_std = None
if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)
dtype = image.dtype
if components.vae.config.force_upcast:
image = image.float()
components.vae.to(dtype=torch.float32)
if isinstance(generator, list):
image_latents = [
retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(components.vae.encode(image), generator=generator)
if components.vae.config.force_upcast:
components.vae.to(dtype)
image_latents = image_latents.to(dtype)
if latents_mean is not None and latents_std is not None:
latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype)
latents_std = latents_std.to(device=image_latents.device, dtype=dtype)
image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std
else:
image_latents = components.vae.config.scaling_factor * image_latents
return image_latents
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.preprocess_kwargs = block_state.preprocess_kwargs or {}
block_state.device = components._execution_device
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs)
block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
block_state.batch_size = block_state.image.shape[0]
# if generator is a list, make sure the length of it matches the length of images (both should be batch_size)
if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators."
)
block_state.image_latents = self._encode_vae_image(components, image=block_state.image, generator=block_state.generator)
self.add_block_state(state, block_state)
return components, state
class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
model_name = "stable-diffusion-xl"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKL),
ComponentSpec(
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 8}),
default_creation_method="from_config"),
ComponentSpec(
"mask_processor",
VaeImageProcessor,
config=FrozenDict({"do_normalize": False, "vae_scale_factor": 8, "do_binarize": True, "do_convert_grayscale": True}),
default_creation_method="from_config"),
]
@property
def description(self) -> str:
return (
"Vae encoder step that prepares the image and mask for the inpainting process"
)
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("height"),
InputParam("width"),
InputParam("image", required=True),
InputParam("mask_image", required=True),
InputParam("padding_mask_crop"),
]
@property
def intermediates_inputs(self) -> List[InputParam]:
return [
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
InputParam("generator"),
]
@property
def intermediates_outputs(self) -> List[OutputParam]:
return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"),
OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"),
OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting process (only for inpainting-specifid unet)"),
OutputParam("crops_coords", type_hint=Optional[Tuple[int, int]], description="The crop coordinates to use for the preprocess/postprocess of the image and mask")]
# Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
# YiYi TODO: update the _encode_vae_image so that we can use #Coped from
def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator):
latents_mean = latents_std = None
if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)
dtype = image.dtype
if components.vae.config.force_upcast:
image = image.float()
components.vae.to(dtype=torch.float32)
if isinstance(generator, list):
image_latents = [
retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(components.vae.encode(image), generator=generator)
if components.vae.config.force_upcast:
components.vae.to(dtype)
image_latents = image_latents.to(dtype)
if latents_mean is not None and latents_std is not None:
latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype)
latents_std = latents_std.to(device=image_latents.device, dtype=dtype)
image_latents = (image_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
else:
image_latents = components.vae.config.scaling_factor * image_latents
return image_latents
# modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents
# do not accept do_classifier_free_guidance
def prepare_mask_latents(
self, components, mask, masked_image, batch_size, height, width, dtype, device, generator
):
# resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
mask = torch.nn.functional.interpolate(
mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor)
)
mask = mask.to(device=device, dtype=dtype)
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
if not batch_size % mask.shape[0] == 0:
raise ValueError(
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
" of masks that you pass is divisible by the total requested batch size."
)
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
if masked_image is not None and masked_image.shape[1] == 4:
masked_image_latents = masked_image
else:
masked_image_latents = None
if masked_image is not None:
if masked_image_latents is None:
masked_image = masked_image.to(device=device, dtype=dtype)
masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator)
if masked_image_latents.shape[0] < batch_size:
if not batch_size % masked_image_latents.shape[0] == 0:
raise ValueError(
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
" Make sure the number of images that you pass is divisible by the total requested batch size."
)
masked_image_latents = masked_image_latents.repeat(
batch_size // masked_image_latents.shape[0], 1, 1, 1
)
# aligning device to prevent device errors when concating it with the latent model input
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
return mask, masked_image_latents
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
block_state.device = components._execution_device
if block_state.padding_mask_crop is not None:
block_state.crops_coords = components.mask_processor.get_crop_region(block_state.mask_image, block_state.width, block_state.height, pad=block_state.padding_mask_crop)
block_state.resize_mode = "fill"
else:
block_state.crops_coords = None
block_state.resize_mode = "default"
block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, crops_coords=block_state.crops_coords, resize_mode=block_state.resize_mode)
block_state.image = block_state.image.to(dtype=torch.float32)
block_state.mask = components.mask_processor.preprocess(block_state.mask_image, height=block_state.height, width=block_state.width, resize_mode=block_state.resize_mode, crops_coords=block_state.crops_coords)
block_state.masked_image = block_state.image * (block_state.mask < 0.5)
block_state.batch_size = block_state.image.shape[0]
block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
block_state.image_latents = self._encode_vae_image(components, image=block_state.image, generator=block_state.generator)
# 7. Prepare mask latent variables
block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents(
components,
block_state.mask,
block_state.masked_image,
block_state.batch_size,
block_state.height,
block_state.width,
block_state.dtype,
block_state.device,
block_state.generator,
)
self.add_block_state(state, block_state)
return components, state
# auto blocks (YiYi TODO: maybe move all the auto blocks to a separate file)
# Encode
class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks):
block_classes = [StableDiffusionXLInpaintVaeEncoderStep, StableDiffusionXLVaeEncoderStep]
block_names = ["inpaint", "img2img"]
block_trigger_inputs = ["mask_image", "image"]
@property
def description(self):
return "Vae encoder step that encode the image inputs into their latent representations.\n" + \
"This is an auto pipeline block that works for both inpainting and img2img tasks.\n" + \
" - `StableDiffusionXLInpaintVaeEncoderStep` (inpaint) is used when both `mask_image` and `image` are provided.\n" + \
" - `StableDiffusionXLVaeEncoderStep` (img2img) is used when only `image` is provided."
class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks, ModularIPAdapterMixin):
block_classes = [StableDiffusionXLIPAdapterStep]
block_names = ["ip_adapter"]
block_trigger_inputs = ["ip_adapter_image"]
@property
def description(self):
return "Run IP Adapter step if `ip_adapter_image` is provided."

View File

@@ -0,0 +1,174 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Tuple, Union, Dict
import PIL
import torch
import numpy as np
from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin
from ...image_processor import PipelineImageInput
from ...pipelines.pipeline_utils import StableDiffusionMixin
from ...pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from ...utils import logging
from ..modular_pipeline import ModularLoader
from ..modular_pipeline_utils import InputParam, OutputParam
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# YiYi TODO: move to a different file? stable_diffusion_xl_module should have its own folder?
# YiYi Notes: model specific components:
## (1) it should inherit from ModularLoader
## (2) acts like a container that holds components and configs
## (3) define default config (related to components), e.g. default_sample_size, vae_scale_factor, num_channels_unet, num_channels_latents
## (4) inherit from model-specic loader class (e.g. StableDiffusionXLLoraLoaderMixin)
## (5) how to use together with Components_manager?
class StableDiffusionXLModularLoader(
ModularLoader,
StableDiffusionMixin,
TextualInversionLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
ModularIPAdapterMixin,
):
@property
def default_sample_size(self):
default_sample_size = 128
if hasattr(self, "unet") and self.unet is not None:
default_sample_size = self.unet.config.sample_size
return default_sample_size
@property
def vae_scale_factor(self):
vae_scale_factor = 8
if hasattr(self, "vae") and self.vae is not None:
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
return vae_scale_factor
@property
def num_channels_unet(self):
num_channels_unet = 4
if hasattr(self, "unet") and self.unet is not None:
num_channels_unet = self.unet.config.in_channels
return num_channels_unet
@property
def num_channels_latents(self):
num_channels_latents = 4
if hasattr(self, "vae") and self.vae is not None:
num_channels_latents = self.vae.config.latent_channels
return num_channels_latents
# YiYi Notes: not used yet, maintain a list of schema that can be used across all pipeline blocks
SDXL_INPUTS_SCHEMA = {
"prompt": InputParam("prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"),
"prompt_2": InputParam("prompt_2", type_hint=Union[str, List[str]], description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2"),
"negative_prompt": InputParam("negative_prompt", type_hint=Union[str, List[str]], description="The prompt or prompts not to guide the image generation"),
"negative_prompt_2": InputParam("negative_prompt_2", type_hint=Union[str, List[str]], description="The negative prompt or prompts for text_encoder_2"),
"cross_attention_kwargs": InputParam("cross_attention_kwargs", type_hint=Optional[dict], description="Kwargs dictionary passed to the AttentionProcessor"),
"clip_skip": InputParam("clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"),
"image": InputParam("image", type_hint=PipelineImageInput, required=True, description="The image(s) to modify for img2img or inpainting"),
"mask_image": InputParam("mask_image", type_hint=PipelineImageInput, required=True, description="Mask image for inpainting, white pixels will be repainted"),
"generator": InputParam("generator", type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], description="Generator(s) for deterministic generation"),
"height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"),
"width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"),
"num_images_per_prompt": InputParam("num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"),
"num_inference_steps": InputParam("num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"),
"timesteps": InputParam("timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"),
"sigmas": InputParam("sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"),
"denoising_end": InputParam("denoising_end", type_hint=Optional[float], description="Fraction of denoising process to complete before termination"),
# YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999
"strength": InputParam("strength", type_hint=float, default=0.3, description="How much to transform the reference image"),
"denoising_start": InputParam("denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"),
"latents": InputParam("latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"),
"padding_mask_crop": InputParam("padding_mask_crop", type_hint=Optional[Tuple[int, int]], description="Size of margin in crop for image and mask"),
"original_size": InputParam("original_size", type_hint=Optional[Tuple[int, int]], description="Original size of the image for SDXL's micro-conditioning"),
"target_size": InputParam("target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"),
"negative_original_size": InputParam("negative_original_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on image resolution"),
"negative_target_size": InputParam("negative_target_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on target resolution"),
"crops_coords_top_left": InputParam("crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Top-left coordinates for SDXL's micro-conditioning"),
"negative_crops_coords_top_left": InputParam("negative_crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Negative conditioning crop coordinates"),
"aesthetic_score": InputParam("aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"),
"negative_aesthetic_score": InputParam("negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"),
"eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"),
"output_type": InputParam("output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"),
"ip_adapter_image": InputParam("ip_adapter_image", type_hint=PipelineImageInput, required=True, description="Image(s) to be used as IP adapter"),
"control_image": InputParam("control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"),
"control_guidance_start": InputParam("control_guidance_start", type_hint=Union[float, List[float]], default=0.0, description="When ControlNet starts applying"),
"control_guidance_end": InputParam("control_guidance_end", type_hint=Union[float, List[float]], default=1.0, description="When ControlNet stops applying"),
"controlnet_conditioning_scale": InputParam("controlnet_conditioning_scale", type_hint=Union[float, List[float]], default=1.0, description="Scale factor for ControlNet outputs"),
"guess_mode": InputParam("guess_mode", type_hint=bool, default=False, description="Enables ControlNet encoder to recognize input without prompts"),
"control_mode": InputParam("control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet")
}
SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
"prompt_embeds": InputParam("prompt_embeds", type_hint=torch.Tensor, required=True, description="Text embeddings used to guide image generation"),
"negative_prompt_embeds": InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"),
"pooled_prompt_embeds": InputParam("pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"),
"negative_pooled_prompt_embeds": InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"),
"batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"),
"dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
"preprocess_kwargs": InputParam("preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"),
"latents": InputParam("latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"),
"timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
"num_inference_steps": InputParam("num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"),
"latent_timestep": InputParam("latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"),
"image_latents": InputParam("image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"),
"mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"),
"masked_image_latents": InputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"),
"add_time_ids": InputParam("add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"),
"negative_add_time_ids": InputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"),
"timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
"noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
"crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
"ip_adapter_embeds": InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"),
"negative_ip_adapter_embeds": InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"),
"images": InputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], required=True, description="Generated images")
}
SDXL_INTERMEDIATE_OUTPUTS_SCHEMA = {
"prompt_embeds": OutputParam("prompt_embeds", type_hint=torch.Tensor, description="Text embeddings used to guide image generation"),
"negative_prompt_embeds": OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"),
"pooled_prompt_embeds": OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="Pooled text embeddings"),
"negative_pooled_prompt_embeds": OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"),
"batch_size": OutputParam("batch_size", type_hint=int, description="Number of prompts"),
"dtype": OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
"image_latents": OutputParam("image_latents", type_hint=torch.Tensor, description="Latents representing reference image"),
"mask": OutputParam("mask", type_hint=torch.Tensor, description="Mask for inpainting"),
"masked_image_latents": OutputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"),
"crops_coords": OutputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
"timesteps": OutputParam("timesteps", type_hint=torch.Tensor, description="Timesteps for inference"),
"num_inference_steps": OutputParam("num_inference_steps", type_hint=int, description="Number of denoising steps"),
"latent_timestep": OutputParam("latent_timestep", type_hint=torch.Tensor, description="Initial noise level timestep"),
"add_time_ids": OutputParam("add_time_ids", type_hint=torch.Tensor, description="Time ids for conditioning"),
"negative_add_time_ids": OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"),
"timestep_cond": OutputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
"latents": OutputParam("latents", type_hint=torch.Tensor, description="Denoised latents"),
"noise": OutputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
"ip_adapter_embeds": OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"),
"negative_ip_adapter_embeds": OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"),
"images": OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="Generated images")
}
SDXL_OUTPUTS_SCHEMA = {
"images": OutputParam("images", type_hint=Union[Tuple[Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]]], StableDiffusionXLPipelineOutput], description="The final generated images")
}

View File

@@ -0,0 +1,128 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
# Import all the necessary block classes
from .denoise import (
StableDiffusionXLAutoDenoiseStep,
StableDiffusionXLDenoiseStep,
StableDiffusionXLControlNetDenoiseStep
)
from .before_denoise import (
StableDiffusionXLAutoBeforeDenoiseStep,
StableDiffusionXLInputStep,
StableDiffusionXLSetTimestepsStep,
StableDiffusionXLPrepareLatentsStep,
StableDiffusionXLPrepareAdditionalConditioningStep,
StableDiffusionXLImg2ImgSetTimestepsStep,
StableDiffusionXLImg2ImgPrepareLatentsStep,
StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep,
StableDiffusionXLInpaintPrepareLatentsStep,
StableDiffusionXLControlNetInputStep,
StableDiffusionXLControlNetUnionInputStep
)
from .encoders import (
StableDiffusionXLTextEncoderStep,
StableDiffusionXLAutoIPAdapterStep,
StableDiffusionXLAutoVaeEncoderStep,
StableDiffusionXLVaeEncoderStep,
StableDiffusionXLInpaintVaeEncoderStep,
StableDiffusionXLIPAdapterStep
)
from .decoders import (
StableDiffusionXLDecodeStep,
StableDiffusionXLInpaintDecodeStep,
StableDiffusionXLAutoDecodeStep
)
# YiYi notes: comment out for now, work on this later
# block mapping
TEXT2IMAGE_BLOCKS = OrderedDict([
("text_encoder", StableDiffusionXLTextEncoderStep),
("ip_adapter", StableDiffusionXLAutoIPAdapterStep),
("input", StableDiffusionXLInputStep),
("set_timesteps", StableDiffusionXLSetTimestepsStep),
("prepare_latents", StableDiffusionXLPrepareLatentsStep),
("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep),
("denoise", StableDiffusionXLDenoiseStep),
("decode", StableDiffusionXLDecodeStep)
])
IMAGE2IMAGE_BLOCKS = OrderedDict([
("text_encoder", StableDiffusionXLTextEncoderStep),
("ip_adapter", StableDiffusionXLAutoIPAdapterStep),
("image_encoder", StableDiffusionXLVaeEncoderStep),
("input", StableDiffusionXLInputStep),
("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep),
("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
("denoise", StableDiffusionXLDenoiseStep),
("decode", StableDiffusionXLDecodeStep)
])
INPAINT_BLOCKS = OrderedDict([
("text_encoder", StableDiffusionXLTextEncoderStep),
("ip_adapter", StableDiffusionXLAutoIPAdapterStep),
("image_encoder", StableDiffusionXLInpaintVaeEncoderStep),
("input", StableDiffusionXLInputStep),
("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep),
("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
("denoise", StableDiffusionXLDenoiseStep),
("decode", StableDiffusionXLInpaintDecodeStep)
])
CONTROLNET_BLOCKS = OrderedDict([
("controlnet_input", StableDiffusionXLControlNetInputStep),
("denoise", StableDiffusionXLControlNetDenoiseStep),
])
CONTROLNET_UNION_BLOCKS = OrderedDict([
("controlnet_input", StableDiffusionXLControlNetUnionInputStep),
("denoise", StableDiffusionXLControlNetDenoiseStep),
])
IP_ADAPTER_BLOCKS = OrderedDict([
("ip_adapter", StableDiffusionXLIPAdapterStep),
])
AUTO_BLOCKS = OrderedDict([
("text_encoder", StableDiffusionXLTextEncoderStep),
("ip_adapter", StableDiffusionXLAutoIPAdapterStep),
("image_encoder", StableDiffusionXLAutoVaeEncoderStep),
("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep),
("denoise", StableDiffusionXLAutoDenoiseStep),
("decode", StableDiffusionXLAutoDecodeStep)
])
AUTO_CORE_BLOCKS = OrderedDict([
("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep),
("denoise", StableDiffusionXLAutoDenoiseStep),
])
SDXL_SUPPORTED_BLOCKS = {
"text2img": TEXT2IMAGE_BLOCKS,
"img2img": IMAGE2IMAGE_BLOCKS,
"inpaint": INPAINT_BLOCKS,
"controlnet": CONTROLNET_BLOCKS,
"controlnet_union": CONTROLNET_UNION_BLOCKS,
"ip_adapter": IP_ADAPTER_BLOCKS,
"auto": AUTO_BLOCKS
}

View File

@@ -0,0 +1,43 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Tuple, Union, Dict
from ...utils import logging
from ..modular_pipeline import SequentialPipelineBlocks
from .denoise import StableDiffusionXLAutoDenoiseStep
from .before_denoise import StableDiffusionXLAutoBeforeDenoiseStep
from .decoders import StableDiffusionXLAutoDecodeStep
from .encoders import StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class StableDiffusionXLAutoPipeline(SequentialPipelineBlocks):
block_classes = [StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep, StableDiffusionXLAutoBeforeDenoiseStep, StableDiffusionXLAutoDenoiseStep, StableDiffusionXLAutoDecodeStep]
block_names = ["text_encoder", "ip_adapter", "image_encoder", "before_denoise", "denoise", "decoder"]
@property
def description(self):
return "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.\n" + \
"- for image-to-image generation, you need to provide either `image` or `image_latents`\n" + \
"- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n" + \
"- to run the controlnet workflow, you need to provide `control_image`\n" + \
"- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n" + \
"- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n" + \
"- for text-to-image generation, all you need to provide is `prompt`"

View File

@@ -47,7 +47,6 @@ else:
"AutoPipelineForInpainting",
"AutoPipelineForText2Image",
]
_import_structure["modular_pipeline"] = ["ModularLoader"]
_import_structure["consistency_models"] = ["ConsistencyModelPipeline"]
_import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"]
_import_structure["ddim"] = ["DDIMPipeline"]
@@ -330,8 +329,6 @@ else:
"StableDiffusionXLInpaintPipeline",
"StableDiffusionXLInstructPix2PixPipeline",
"StableDiffusionXLPipeline",
"StableDiffusionXLModularLoader",
"StableDiffusionXLAutoPipeline",
]
)
_import_structure["stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"]
@@ -481,7 +478,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .deprecated import KarrasVePipeline, LDMPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline
from .dit import DiTPipeline
from .latent_diffusion import LDMSuperResolutionPipeline
from .modular_pipeline import ModularLoader
from .pipeline_utils import (
AudioPipelineOutput,
DiffusionPipeline,
@@ -706,9 +702,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLInstructPix2PixPipeline,
StableDiffusionXLModularLoader,
StableDiffusionXLPipeline,
StableDiffusionXLAutoPipeline,
)
from .stable_video_diffusion import StableVideoDiffusionPipeline
from .t2i_adapter import (

View File

@@ -29,18 +29,6 @@ else:
_import_structure["pipeline_stable_diffusion_xl_img2img"] = ["StableDiffusionXLImg2ImgPipeline"]
_import_structure["pipeline_stable_diffusion_xl_inpaint"] = ["StableDiffusionXLInpaintPipeline"]
_import_structure["pipeline_stable_diffusion_xl_instruct_pix2pix"] = ["StableDiffusionXLInstructPix2PixPipeline"]
_import_structure["pipeline_stable_diffusion_xl_modular"] = [
"StableDiffusionXLControlNetDenoiseStep",
"StableDiffusionXLDecodeLatentsStep",
"StableDiffusionXLDenoiseStep",
"StableDiffusionXLInputStep",
"StableDiffusionXLModularLoader",
"StableDiffusionXLPrepareAdditionalConditioningStep",
"StableDiffusionXLPrepareLatentsStep",
"StableDiffusionXLSetTimestepsStep",
"StableDiffusionXLTextEncoderStep",
"StableDiffusionXLAutoPipeline",
]
if is_transformers_available() and is_flax_available():
from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
@@ -60,18 +48,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipeline
from .pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline
from .pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline
from .pipeline_stable_diffusion_xl_modular import (
StableDiffusionXLControlNetDenoiseStep,
StableDiffusionXLDecodeLatentsStep,
StableDiffusionXLDenoiseStep,
StableDiffusionXLInputStep,
StableDiffusionXLModularLoader,
StableDiffusionXLPrepareAdditionalConditioningStep,
StableDiffusionXLPrepareLatentsStep,
StableDiffusionXLSetTimestepsStep,
StableDiffusionXLTextEncoderStep,
StableDiffusionXLAutoPipeline,
)
try:
if not (is_transformers_available() and is_flax_available()):