attempt to break ModularPipeline base into ComponentsState and a pipeline block mixin
@@ -1021,37 +1021,111 @@ class SequentialPipelineBlocks:
            expected_configs=self.expected_configs
        )


-class ModularPipeline(ConfigMixin):
+class ModularPipelineMixin:
    """
    Mixin for all PipelineBlocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks
    """

    def __init__(self):
        self.components_manager = None
        self.components_manager_prefix = ""
        self.components_state = None

    # YiYi TODO: not sure this is the best method name
    def compile(self, components_manager: ComponentsManager, label: Optional[str] = None):
        self.components_manager = components_manager
        self.components_manager_prefix = "" if label is None else f"{label}_"
        self.components_state = ComponentsState(self.expected_components, self.expected_configs)

        components_to_add = self.components_manager.get(f"{self.components_manager_prefix}*")
        self.components_state.update_states(self.expected_components, self.expected_configs, **components_to_add)

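As context for how compile() is meant to be used: components are registered in a ComponentsManager under a label prefix, and compile() pulls every matching entry into the block's ComponentsState. A minimal sketch, assuming the manager exposes an add(name, component) counterpart to the glob-style get() used above (the add signature and the component names are illustrative):

manager = ComponentsManager()
manager.add("sdxl_unet", unet)            # assumed add(name, obj) API
manager.add("sdxl_scheduler", scheduler)

blocks.compile(manager, label="sdxl")     # fetches all "sdxl_*" entries
# blocks.components_state now holds unet and scheduler for run()
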
    @property
    def default_call_parameters(self) -> Dict[str, Any]:
        params = {}
        for input_param in self.inputs:
            params[input_param.name] = input_param.default
        return params

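A sketch of what this property yields, assuming the input specs expose name and default fields as the loop implies (the InputParam name is an illustrative stand-in):

# inputs = [InputParam("prompt", default=None), InputParam("guidance_scale", default=5.0)]
# block.default_call_parameters  ->  {"prompt": None, "guidance_scale": 5.0}
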
    def run(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs):
        """
        Run one or more blocks in sequence; optionally, you can pass a previous pipeline state.
        """
        if state is None:
            state = PipelineState()

        # Make a copy of the input kwargs
        input_params = kwargs.copy()

        default_params = self.default_call_parameters

        # Add inputs to state, using defaults if not provided in the kwargs or the state.
        # If the same input is already in the state, it is overridden when provided in the kwargs.
        intermediates_inputs = [inp.name for inp in self.intermediates_inputs]
        for name, default in default_params.items():
            if name in input_params:
                if name not in intermediates_inputs:
                    state.add_input(name, input_params.pop(name))
                else:
                    state.add_input(name, input_params[name])
            elif name not in state.inputs:
                state.add_input(name, default)

        for name in intermediates_inputs:
            if name in input_params:
                state.add_intermediate(name, input_params.pop(name))

        # Warn about unexpected inputs
        if len(input_params) > 0:
            logger.warning(f"Unexpected inputs {list(input_params.keys())} provided. These inputs will be ignored.")

        # Run the pipeline
        with torch.no_grad():
            try:
                pipeline, state = self(self, state)
            except Exception:
                error_msg = f"Error in block ({self.__class__.__name__}):\n"
                logger.error(error_msg)
                raise

        if output is None:
            return state
        elif isinstance(output, str):
            return state.get_intermediate(output)
        elif isinstance(output, (list, tuple)):
            return state.get_intermediates(output)
        else:
            raise ValueError(f"Output '{output}' is not a valid output type")

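A hedged sketch of driving run(); the intermediate names are illustrative and must match whatever the underlying blocks actually register via add_intermediate:

state = blocks.run(prompt="a cat", guidance_scale=5.0)          # full PipelineState back
images = blocks.run(prompt="a cat", output="images")            # a single intermediate
latents, images = blocks.run(prompt="a cat", output=["latents", "images"])
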
class ComponentsState(ConfigMixin):
    """
    Base class for all Modular pipelines.
    """

    config_name = "model_index.json"
    _exclude_from_cpu_offload = []

-    def __init__(self, block):
-        self.pipeline_block = block
+    def __init__(self, component_specs, config_specs):

-        for component_spec in self.expected_components:
+        for component_spec in component_specs:
            if component_spec.obj is not None:
                setattr(self, component_spec.name, component_spec.obj)
            else:
                setattr(self, component_spec.name, None)

        default_configs = {}
-        for config_spec in self.expected_configs:
+        for config_spec in config_specs:
            default_configs[config_spec.name] = config_spec.default
        self.register_to_config(**default_configs)

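To make the spec shape concrete, a minimal sketch of the inputs this constructor accepts; the ComponentSpec and ConfigSpec dataclass names are assumptions, and only the name/obj/default fields are implied by the code above:

from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class ComponentSpec:               # name assumed
    name: str
    obj: Optional[Any] = None      # a preloaded module, or None to fill in later

@dataclass
class ConfigSpec:                  # name assumed
    name: str
    default: Any = None

state = ComponentsState(
    component_specs=[ComponentSpec("unet", unet), ComponentSpec("vae")],
    config_specs=[ConfigSpec("force_zeros_for_empty_prompt", True)],
)
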
    @classmethod
    def from_block(cls, block):
        modular_pipeline_class_name = MODULAR_PIPELINE_MAPPING[block.model_name]
        modular_pipeline_class = _get_pipeline_class(cls, class_name=modular_pipeline_class_name)

        return modular_pipeline_class(block)

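A sketch of the resolution from_block performs (the mapping entry is illustrative):

# MODULAR_PIPELINE_MAPPING maps a block's model_name to a pipeline class name, e.g.
# {"stable-diffusion-xl": "StableDiffusionXLComponentStates"}
components = ComponentsState.from_block(sdxl_blocks)
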
    @property
    def device(self) -> torch.device:
        r"""
@@ -1089,10 +1163,7 @@ class ModularPipeline(ConfigMixin):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def get_execution_blocks(self, *trigger_inputs):
        return self.pipeline_block.get_execution_blocks(*trigger_inputs)

    @property
    def dtype(self) -> torch.dtype:
        r"""
@@ -1107,13 +1178,6 @@ class ModularPipeline(ConfigMixin):
        return torch.float32

    @property
    def expected_components(self):
        return self.pipeline_block.expected_components

    @property
    def expected_configs(self):
        return self.pipeline_block.expected_configs

    @property
    def components(self):
@@ -1123,80 +1187,7 @@ class ModularPipeline(ConfigMixin):
            components[component_spec.name] = getattr(self, component_spec.name)
        return components

    # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline.progress_bar
    def progress_bar(self, iterable=None, total=None):
        if not hasattr(self, "_progress_bar_config"):
            self._progress_bar_config = {}
        elif not isinstance(self._progress_bar_config, dict):
            raise ValueError(
                f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
            )

        if iterable is not None:
            return tqdm(iterable, **self._progress_bar_config)
        elif total is not None:
            return tqdm(total=total, **self._progress_bar_config)
        else:
            raise ValueError("Either `total` or `iterable` has to be defined.")

    # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline.set_progress_bar_config
    def set_progress_bar_config(self, **kwargs):
        self._progress_bar_config = kwargs

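These are the stock DiffusionPipeline helpers; any tqdm keyword argument passes straight through, for example:

pipe.set_progress_bar_config(disable=True)                    # silence tqdm entirely
pipe.set_progress_bar_config(desc="denoising", leave=False)   # relabel the bar
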
    def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs):
        """
        Run one or more blocks in sequence; optionally, you can pass a previous pipeline state.
        """
        if state is None:
            state = PipelineState()

        # Make a copy of the input kwargs
        input_params = kwargs.copy()

        default_params = self.default_call_parameters

        # Add inputs to state, using defaults if not provided in the kwargs or the state.
        # If the same input is already in the state, it is overridden when provided in the kwargs.
        intermediates_inputs = [inp.name for inp in self.pipeline_block.intermediates_inputs]
        for name, default in default_params.items():
            if name in input_params:
                if name not in intermediates_inputs:
                    state.add_input(name, input_params.pop(name))
                else:
                    state.add_input(name, input_params[name])
            elif name not in state.inputs:
                state.add_input(name, default)

        for name in intermediates_inputs:
            if name in input_params:
                state.add_intermediate(name, input_params.pop(name))

        # Warn about unexpected inputs
        if len(input_params) > 0:
            logger.warning(f"Unexpected inputs {list(input_params.keys())} provided. These inputs will be ignored.")

        # Run the pipeline
        with torch.no_grad():
            try:
                pipeline, state = self.pipeline_block(self, state)
            except Exception:
                error_msg = f"Error in block ({self.pipeline_block.__class__.__name__}):\n"
                logger.error(error_msg)
                raise

        if output is None:
            return state
        elif isinstance(output, str):
            return state.get_intermediate(output)
        elif isinstance(output, (list, tuple)):
            return state.get_intermediates(output)
        else:
            raise ValueError(f"Output '{output}' is not a valid output type")

-    def update_states(self, **kwargs):
+    def update_states(self, expected_components, expected_configs, **kwargs):
        """
        Update components and configs after instance creation. Auxiliaries (e.g. image_processor) should be defined
        for each pipeline block and do not need to be updated by users. Logs if existing non-None components are being
@@ -1206,7 +1197,7 @@ class ModularPipeline(ConfigMixin):
            kwargs (dict): Keyword arguments to update the states.
        """

-        for component in self.expected_components:
+        for component in expected_components:
            if component.name in kwargs:
                if hasattr(self, component.name) and getattr(self, component.name) is not None:
                    current_component = getattr(self, component.name)
@@ -1226,163 +1217,14 @@ class ModularPipeline(ConfigMixin):
                        f"with new value (type: {type(new_component).__name__})"
                    )

-                setattr(self, component.name, kwargs.pop(component.name))
+                setattr(self.components_state, component.name, kwargs.pop(component.name))

        configs_to_add = {}
-        for config in self.expected_configs:
+        for config in expected_configs:
            if config.name in kwargs:
                configs_to_add[config.name] = kwargs.pop(config.name)
        self.register_to_config(**configs_to_add)

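A sketch of swapping a component after creation (the scheduler swap is illustrative; replacing an existing non-None component triggers the logging path above):

components.update_states(
    expected_components,
    expected_configs,
    scheduler=EulerDiscreteScheduler.from_config(old_scheduler.config),
)
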
    @property
    def default_call_parameters(self) -> Dict[str, Any]:
        params = {}
        for input_param in self.pipeline_block.inputs:
            params[input_param.name] = input_param.default
        return params

    # YiYi TODO: try to unify the to method with the one in DiffusionPipeline
    # Modified from diffusers.pipelines.pipeline_utils.DiffusionPipeline.to
    # YiYi TODO: should support the to method
    def to(self, *args, **kwargs):
        r"""
        Performs pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the
        arguments of `self.to(*args, **kwargs)`.

        <Tip>

        If the pipeline already has the correct torch.dtype and torch.device, it is returned as is. Otherwise,
        the returned pipeline is a copy of self with the desired torch.dtype and torch.device.

        </Tip>

        Here are the ways to call `to`:

        - `to(dtype, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified
          [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)
        - `to(device, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified
          [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device)
        - `to(device=None, dtype=None, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with
          the specified [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) and
          [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)

        Arguments:
            dtype (`torch.dtype`, *optional*):
                Returns a pipeline with the specified
                [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)
            device (`torch.device`, *optional*):
                Returns a pipeline with the specified
                [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device)
            silence_dtype_warnings (`bool`, *optional*, defaults to `False`):
                Whether to omit warnings if the target `dtype` is not compatible with the target `device`.

        Returns:
            [`DiffusionPipeline`]: The pipeline converted to the specified `dtype` and/or `device`.
        """
        dtype = kwargs.pop("dtype", None)
        device = kwargs.pop("device", None)
        silence_dtype_warnings = kwargs.pop("silence_dtype_warnings", False)

        dtype_arg = None
        device_arg = None
        if len(args) == 1:
            if isinstance(args[0], torch.dtype):
                dtype_arg = args[0]
            else:
                device_arg = torch.device(args[0]) if args[0] is not None else None
        elif len(args) == 2:
            if isinstance(args[0], torch.dtype):
                raise ValueError(
                    "When passing two arguments, make sure the first corresponds to `device` and the second to `dtype`."
                )
            device_arg = torch.device(args[0]) if args[0] is not None else None
            dtype_arg = args[1]
        elif len(args) > 2:
            raise ValueError("Please make sure to pass at most two arguments (`device` and `dtype`) to `.to(...)`")

        if dtype is not None and dtype_arg is not None:
            raise ValueError(
                "You have passed `dtype` both as an argument and as a keyword argument. Please only pass one of the two."
            )

        dtype = dtype or dtype_arg

        if device is not None and device_arg is not None:
            raise ValueError(
                "You have passed `device` both as an argument and as a keyword argument. Please only pass one of the two."
            )

        device = device or device_arg

        # Throw a warning if the pipeline is in "offloaded" mode but the user tries to manually move it to the GPU.
        def module_is_sequentially_offloaded(module):
            if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
                return False

            return hasattr(module, "_hf_hook") and (
                isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook)
                or hasattr(module._hf_hook, "hooks")
                and isinstance(module._hf_hook.hooks[0], accelerate.hooks.AlignDevicesHook)
            )

        def module_is_offloaded(module):
            if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"):
                return False

            return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload)

        # .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer
        pipeline_is_sequentially_offloaded = any(
            module_is_sequentially_offloaded(module) for _, module in self.components.items()
        )
        if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda":
            raise ValueError(
                "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
            )

        is_pipeline_device_mapped = (
            hasattr(self, "hf_device_map") and self.hf_device_map is not None and len(self.hf_device_map) > 1
        )
        if is_pipeline_device_mapped:
            raise ValueError(
                "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` first and then call `to()`."
            )

        # Display a warning in this case (the operation succeeds but the benefits are lost)
        pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
        if pipeline_is_offloaded and device and torch.device(device).type == "cuda":
            logger.warning(
                f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
            )

        modules = [m for m in self.components.values() if isinstance(m, torch.nn.Module)]

        is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
        for module in modules:
            is_loaded_in_8bit = hasattr(module, "is_loaded_in_8bit") and module.is_loaded_in_8bit

            if is_loaded_in_8bit and dtype is not None:
                logger.warning(
                    f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {dtype} is not yet supported. Module is still in 8bit precision."
                )

            if is_loaded_in_8bit and device is not None:
                logger.warning(
                    f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {device} via `.to()` is not yet supported. Module is still on {module.device}."
                )
            else:
                module.to(device, dtype)

            if (
                module.dtype == torch.float16
                and str(device) in ["cpu"]
                and not silence_dtype_warnings
                and not is_offloaded
            ):
                logger.warning(
                    "Pipelines loaded with `dtype=torch.float16` cannot run with `cpu` device. It"
                    " is not recommended to move them to `cpu` as running them will fail. Please make"
                    " sure to use an accelerator to run the pipeline in inference, due to the lack of"
                    " support for `float16` operations on this device in PyTorch. Please, remove the"
                    " `torch_dtype=torch.float16` argument, or use another device for inference."
                )
-        return self
+        pass

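For reference, the call patterns the docstring above describes, mirroring DiffusionPipeline.to:

pipe.to(torch.float16)                        # dtype only
pipe.to("cuda")                               # device only
pipe.to("cuda", torch.float16)                # device first, then dtype
pipe.to(device="cuda", dtype=torch.float16)
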
@@ -3571,7 +3571,7 @@ SDXL_SUPPORTED_BLOCKS = {


# YiYi TODO: rename to components etc. and not inherit from ModularPipeline
-class StableDiffusionXLModularPipeline(
+class StableDiffusionXLComponentStates(
    ModularPipeline,
    StableDiffusionMixin,
    TextualInversionLoaderMixin,