refactor based on dhruv's feedback
@@ -23,7 +23,7 @@ except OptionalDependencyNotAvailable:
    _dummy_objects.update(get_objects_from_module(dummy_pt_objects))
else:
    _import_structure["modular_pipeline"] = [
-        "ModularPipelineMixin",
+        "ModularPipelineBlocks",
        "PipelineBlock",
        "AutoPipelineBlocks",
        "SequentialPipelineBlocks",
@@ -53,7 +53,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        BlockState,
        LoopSequentialPipelineBlocks,
        ModularLoader,
-        ModularPipelineMixin,
+        ModularPipelineBlocks,
        PipelineBlock,
        PipelineState,
        SequentialPipelineBlocks,
@@ -243,8 +243,7 @@ class BlockState:
        return f"BlockState(\n{attributes}\n)"


-class ModularPipelineMixin(ConfigMixin):
+class ModularPipelineBlocks(ConfigMixin):
    """
    Mixin for all PipelineBlocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks
    """
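After this rename, every block type shares the same concrete base class. Summarising the class-level renames that appear in the later hunks (a sketch, not part of the diff):

# Shared base (a ConfigMixin) after the rename:
#   ModularPipelineBlocks
#   ├── PipelineBlock
#   ├── AutoPipelineBlocks
#   ├── SequentialPipelineBlocks
#   └── LoopSequentialPipelineBlocks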
@@ -305,13 +304,10 @@ class ModularPipelineMixin(ConfigMixin):
        }

        return block_cls(**block_kwargs)

-    def setup_loader(self, modular_repo: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None):
+    def init_pipeline(self, modular_repo: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None):
        """
        create a ModularLoader, optionally accept modular_repo to load from hub.
        """

        # Import components loader (it is model-specific class)
        loader_class_name = MODULAR_LOADER_MAPPING.get(self.model_name, ModularLoader.__name__)
        diffusers_module = importlib.import_module("diffusers")
        loader_class = getattr(diffusers_module, loader_class_name)
@@ -322,98 +318,12 @@ class ModularPipelineMixin(ConfigMixin):
        # Create the loader with the updated specs
        specs = component_specs + config_specs

-        self.loader = loader_class(specs, modular_repo=modular_repo, component_manager=component_manager, collection=collection)
+        loader = loader_class(specs=specs, modular_repo=modular_repo, component_manager=component_manager, collection=collection)
+        modular_pipeline = ModularPipeline(blocks=self, loader=loader)
+        return modular_pipeline


-    @property
-    def default_call_parameters(self) -> Dict[str, Any]:
-        params = {}
-        for input_param in self.inputs:
-            params[input_param.name] = input_param.default
-        return params
-
-    def run(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs):
-        """
-        Run one or more blocks in sequence, optionally you can pass a previous pipeline state.
-        """
-        if state is None:
-            state = PipelineState()
-
-        if not hasattr(self, "loader"):
-            logger.info("Loader is not set, please call `setup_loader()` if you need to load checkpoints for your pipeline.")
-            self.loader = None
-
-        # Make a copy of the input kwargs
-        passed_kwargs = kwargs.copy()
-
-
-        # Add inputs to state, using defaults if not provided in the kwargs or the state
-        # if same input already in the state, will override it if provided in the kwargs
-
-        intermediates_inputs = [inp.name for inp in self.intermediates_inputs]
-        for expected_input_param in self.inputs:
-            name = expected_input_param.name
-            default = expected_input_param.default
-            kwargs_type = expected_input_param.kwargs_type
-            if name in passed_kwargs:
-                if name not in intermediates_inputs:
-                    state.add_input(name, passed_kwargs.pop(name), kwargs_type)
-                else:
-                    state.add_input(name, passed_kwargs[name], kwargs_type)
-            elif name not in state.inputs:
-                state.add_input(name, default, kwargs_type)
-
-        for expected_intermediate_param in self.intermediates_inputs:
-            name = expected_intermediate_param.name
-            kwargs_type = expected_intermediate_param.kwargs_type
-            if name in passed_kwargs:
-                state.add_intermediate(name, passed_kwargs.pop(name), kwargs_type)
-
-        # Warn about unexpected inputs
-        if len(passed_kwargs) > 0:
-            warnings.warn(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.")
-        # Run the pipeline
-        with torch.no_grad():
-            try:
-                pipeline, state = self(self.loader, state)
-            except Exception:
-                error_msg = f"Error in block: ({self.__class__.__name__}):\n"
-                logger.error(error_msg)
-                raise
-
-        if output is None:
-            return state
-
-
-        elif isinstance(output, str):
-            return state.get_intermediate(output)
-
-        elif isinstance(output, (list, tuple)):
-            return state.get_intermediates(output)
-        else:
-            raise ValueError(f"Output '{output}' is not a valid output type")
-
-    @torch.compiler.disable
-    def progress_bar(self, iterable=None, total=None):
-        if not hasattr(self, "_progress_bar_config"):
-            self._progress_bar_config = {}
-        elif not isinstance(self._progress_bar_config, dict):
-            raise ValueError(
-                f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
-            )
-
-        if iterable is not None:
-            return tqdm(iterable, **self._progress_bar_config)
-        elif total is not None:
-            return tqdm(total=total, **self._progress_bar_config)
-        else:
-            raise ValueError("Either `total` or `iterable` has to be defined.")
-
-    def set_progress_bar_config(self, **kwargs):
-        self._progress_bar_config = kwargs


-class PipelineBlock(ModularPipelineMixin):
+class PipelineBlock(ModularPipelineBlocks):

    model_name = None
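The net effect of this hunk is that blocks no longer execute themselves: `init_pipeline()` builds a `ModularLoader` and pairs it with the blocks in a new `ModularPipeline` object, which now owns `default_call_parameters`, the run loop, and the progress-bar helpers. A rough sketch of the resulting call flow; the blocks class, repo id, input name, and output name are placeholders, not from the diff:

import torch

blocks = MyBlocks()  # any ModularPipelineBlocks subclass (hypothetical)
pipe = blocks.init_pipeline(modular_repo="user/placeholder-modular-repo")  # -> ModularPipeline

pipe.load_components(torch_dtype=torch.float16)  # forwarded to ModularLoader.load(...)
state = pipe(prompt="a photo of a cat")          # no `output`: returns the full PipelineState
image = pipe(prompt="a photo of a cat", output="images")  # a single named intermediate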
@@ -680,7 +590,7 @@ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) ->
    return list(combined_dict.values())


-class AutoPipelineBlocks(ModularPipelineMixin):
+class AutoPipelineBlocks(ModularPipelineBlocks):
    """
    A class that automatically selects a block to run based on the inputs.
@@ -969,7 +879,8 @@ class AutoPipelineBlocks(ModularPipelineMixin):
            expected_configs=self.expected_configs
        )

-class SequentialPipelineBlocks(ModularPipelineMixin):
+
+class SequentialPipelineBlocks(ModularPipelineBlocks):
    """
    A class that combines multiple pipeline block classes into one. When called, it will call each block in sequence.
    """
@@ -1009,15 +920,24 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
        """Creates a SequentialPipelineBlocks instance from a dictionary of blocks.

        Args:
-            blocks_dict: Dictionary mapping block names to block instances
+            blocks_dict: Dictionary mapping block names to block classes or instances

        Returns:
            A new SequentialPipelineBlocks instance
        """
        instance = cls()
-        instance.block_classes = [block.__class__ for block in blocks_dict.values()]
-        instance.block_names = list(blocks_dict.keys())
-        instance.blocks = blocks_dict
+
+        # Create instances if classes are provided
+        blocks = {}
+        for name, block in blocks_dict.items():
+            if inspect.isclass(block):
+                blocks[name] = block()
+            else:
+                blocks[name] = block
+
+        instance.block_classes = [block.__class__ for block in blocks.values()]
+        instance.block_names = list(blocks.keys())
+        instance.blocks = blocks
        return instance

    def __init__(self):
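With the change above, `from_blocks_dict` accepts block classes as well as instances; anything that passes `inspect.isclass` is instantiated with no arguments. For example (both step classes are hypothetical):

blocks = SequentialPipelineBlocks.from_blocks_dict({
    "text_encoder": TextEncoderStep,  # a class: instantiated as TextEncoderStep()
    "denoise": DenoiseStep(),         # an instance: used as-is
})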
@@ -1330,7 +1250,7 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
        )

#YiYi TODO: __repr__
-class LoopSequentialPipelineBlocks(ModularPipelineMixin):
+class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
    """
    A class that combines multiple pipeline block classes into a For Loop. When called, it will call each block in sequence.
    """
@@ -1694,7 +1614,24 @@ class LoopSequentialPipelineBlocks(ModularPipelineMixin):

        return result

+    @torch.compiler.disable
+    def progress_bar(self, iterable=None, total=None):
+        if not hasattr(self, "_progress_bar_config"):
+            self._progress_bar_config = {}
+        elif not isinstance(self._progress_bar_config, dict):
+            raise ValueError(
+                f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
+            )
+
+        if iterable is not None:
+            return tqdm(iterable, **self._progress_bar_config)
+        elif total is not None:
+            return tqdm(total=total, **self._progress_bar_config)
+        else:
+            raise ValueError("Either `total` or `iterable` has to be defined.")
+
+    def set_progress_bar_config(self, **kwargs):
+        self._progress_bar_config = kwargs


# YiYi TODO:
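The helpers added here are the same progress-bar utilities `DiffusionPipeline` exposes: `set_progress_bar_config` stores tqdm keyword arguments, and `progress_bar` wraps either an iterable or a bare total. In a loop block that might look like this (`loop_blocks`, `timesteps`, and `num_steps` are placeholders):

loop_blocks.set_progress_bar_config(desc="denoising", disable=False)  # any tqdm kwargs
for t in loop_blocks.progress_bar(timesteps):  # or loop_blocks.progress_bar(total=num_steps)
    ...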
@@ -1889,19 +1826,6 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
                return torch.device(module._hf_hook.execution_device)
        return self.device

-    @property
-    def device(self) -> torch.device:
-        r"""
-        Returns:
-            `torch.device`: The torch device on which the pipeline is located.
-        """
-
-        modules = [m for m in self.components.values() if isinstance(m, torch.nn.Module)]
-
-        for module in modules:
-            return module.device
-
-        return torch.device("cpu")

    @property
    def dtype(self) -> torch.dtype:
@@ -2197,4 +2121,105 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
                    name=name,
                    type_hint=type_hint,
                    **spec_dict,
                )
            )
+
+
+class ModularPipeline:
+    """
+    Base class for all Modular pipelines.
+
+    Args:
+        blocks: ModularPipelineBlocks, the blocks to be used in the pipeline
+        loader: ModularLoader, the loader to be used in the pipeline
+    """
+
+    def __init__(self, blocks: ModularPipelineBlocks, loader: ModularLoader):
+        self.blocks = blocks
+        self.loader = loader
+
+
+    @property
+    def default_call_parameters(self) -> Dict[str, Any]:
+        params = {}
+        for input_param in self.blocks.inputs:
+            params[input_param.name] = input_param.default
+        return params
+
+    def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs):
+        """
+        Run one or more blocks in sequence, optionally you can pass a previous pipeline state.
+        """
+        if state is None:
+            state = PipelineState()
+
+
+        # Make a copy of the input kwargs
+        passed_kwargs = kwargs.copy()
+
+
+        # Add inputs to state, using defaults if not provided in the kwargs or the state
+        # if same input already in the state, will override it if provided in the kwargs
+
+        intermediates_inputs = [inp.name for inp in self.blocks.intermediates_inputs]
+        for expected_input_param in self.blocks.inputs:
+            name = expected_input_param.name
+            default = expected_input_param.default
+            kwargs_type = expected_input_param.kwargs_type
+            if name in passed_kwargs:
+                if name not in intermediates_inputs:
+                    state.add_input(name, passed_kwargs.pop(name), kwargs_type)
+                else:
+                    state.add_input(name, passed_kwargs[name], kwargs_type)
+            elif name not in state.inputs:
+                state.add_input(name, default, kwargs_type)
+
+        for expected_intermediate_param in self.blocks.intermediates_inputs:
+            name = expected_intermediate_param.name
+            kwargs_type = expected_intermediate_param.kwargs_type
+            if name in passed_kwargs:
+                state.add_intermediate(name, passed_kwargs.pop(name), kwargs_type)
+
+        # Warn about unexpected inputs
+        if len(passed_kwargs) > 0:
+            warnings.warn(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.")
+        # Run the pipeline
+        with torch.no_grad():
+            try:
+                pipeline, state = self.blocks(self.loader, state)
+            except Exception:
+                error_msg = f"Error in block: ({self.blocks.__class__.__name__}):\n"
+                logger.error(error_msg)
+                raise
+
+        if output is None:
+            return state
+
+
+        elif isinstance(output, str):
+            return state.get_intermediate(output)
+
+        elif isinstance(output, (list, tuple)):
+            return state.get_intermediates(output)
+        else:
+            raise ValueError(f"Output '{output}' is not a valid output type")
+
+
+    def load_components(self, component_names: Optional[List[str]] = None, **kwargs):
+        self.loader.load(component_names=component_names, **kwargs)
+
+    def update_components(self, **kwargs):
+        self.loader.update(**kwargs)
+
+    def from_pretrained(self, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+        loader = ModularLoader.from_pretrained(pretrained_model_name_or_path, **kwargs)
+        blocks = ModularPipelineBlocks.from_pretrained(pretrained_model_name_or_path, **kwargs)
+        return ModularPipeline(blocks=blocks, loader=loader)
+
+    def save_pretrained(self, save_directory: Optional[Union[str, os.PathLike]] = None, push_to_hub: bool = False, **kwargs):
+        self.blocks.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
+        self.loader.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
+
+
+    @property
+    def doc(self):
+        return self.blocks.doc
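`ModularPipeline` itself is a thin wrapper: `__call__` carries over the run loop removed from the mixin (inputs and intermediates are pushed into a `PipelineState`, then the blocks are invoked under `torch.no_grad()`), while component loading, updating, and saving are delegated to the loader and the blocks. A usage sketch; the `unet` component name and the `"latents"` output are illustrative:

pipe = ModularPipeline(blocks=blocks, loader=loader)

pipe.load_components()                      # -> loader.load(...)
pipe.update_components(unet=new_unet)       # -> loader.update(...); `new_unet` is a placeholder
pipe.save_pretrained("./my-modular-pipe")   # saves both the blocks and the loader

state = pipe()                              # full PipelineState when `output` is None
latents = pipe(output="latents")            # a single intermediate by name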
@@ -1,5 +1,5 @@
from ..configuration_utils import ConfigMixin
-from .modular_pipeline import SequentialPipelineBlocks, ModularPipelineMixin
+from .modular_pipeline import SequentialPipelineBlocks, ModularPipelineBlocks
from .modular_pipeline_utils import InputParam, OutputParam
from ..image_processor import PipelineImageInput
from pathlib import Path
@@ -202,7 +202,7 @@ class ModularNode(ConfigMixin):
        trust_remote_code: Optional[bool] = None,
        **kwargs,
    ):
-        blocks = ModularPipelineMixin.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs)
+        blocks = ModularPipelineBlocks.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs)
        return cls(blocks, **kwargs)

    def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs):