From d143851309c7eed3ddb3af54fc56943452cf79d5 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 12 Apr 2025 11:46:25 +0200 Subject: [PATCH] move methods to blocks --- src/diffusers/pipelines/components_manager.py | 1 - src/diffusers/pipelines/modular_pipeline.py | 423 ++++++------------ .../pipeline_stable_diffusion_xl_modular.py | 310 ++++++------- 3 files changed, 287 insertions(+), 447 deletions(-) diff --git a/src/diffusers/pipelines/components_manager.py b/src/diffusers/pipelines/components_manager.py index 6d7665e292..8c14321ccf 100644 --- a/src/diffusers/pipelines/components_manager.py +++ b/src/diffusers/pipelines/components_manager.py @@ -346,7 +346,6 @@ class ComponentsManager: results.update(result) else: results[name] = result - logger.info(f"Getting multiple components: {list(results.keys())}") return results else: diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index 954b78d417..785f38cdbf 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -170,7 +170,7 @@ class InputParam: @dataclass class OutputParam: name: str - type_hint: Any + type_hint: Any = None description: str = "" def __repr__(self): @@ -440,63 +440,31 @@ class PipelineBlock: desc.extend(f" {line}" for line in desc_lines[1:]) desc = '\n'.join(desc) + '\n' - # Components section + # Components section - focus only on expected components expected_components = getattr(self, "expected_components", []) - expected_component_names = {comp.name for comp in expected_components} if expected_components else set() - loaded_components = set(self.components.keys()) - all_components = sorted(expected_component_names | loaded_components) + expected_components_str_list = [] - main_components = [] - auxiliary_components = [] - for k in all_components: - # Get component spec if available - component_spec = next((comp for comp in expected_components if comp.name == k), None) + for component_spec in expected_components: + component_str = f" - {component_spec.name} ({component_spec.type_hint})" - if k in loaded_components: - component_type = type(self.components[k]).__name__ - component_str = f" - {k}={component_type}" - - # Add expected type info if available - if component_spec and component_spec.class_name: - expected_type = component_spec.class_name - if isinstance(expected_type, (list, tuple)): - expected_type = expected_type[1] # Get class name from [module, class_name] - if expected_type != component_type: - component_str += f" (expected: {expected_type})" - else: - # Component not loaded but expected - if component_spec: - expected_type = component_spec.class_name - if isinstance(expected_type, (list, tuple)): - expected_type = expected_type[1] # Get class name from [module, class_name] - component_str = f" - {k} (expected: {expected_type})" - - # Add repo info if available - if component_spec.default_repo: - repo_info = component_spec.default_repo - if component_spec.subfolder: - repo_info += f", subfolder={component_spec.subfolder}" - component_str += f" [{repo_info}]" + # Add repo info if available + if component_spec.default_repo: + if isinstance(component_spec.default_repo, list) and len(component_spec.default_repo) == 2: + repo_info = component_spec.default_repo[0] + subfolder = component_spec.default_repo[1] + if subfolder: + repo_info += f", subfolder={subfolder}" else: - component_str = f" - {k}" + repo_info = component_spec.default_repo + component_str += f" [{repo_info}]" - if k in getattr(self, 
"auxiliary_components", []): - auxiliary_components.append(component_str) - else: - main_components.append(component_str) + expected_components_str_list.append(component_str) - components = "Components:\n" + "\n".join(main_components) - if auxiliary_components: - components += "\n Auxiliaries:\n" + "\n".join(auxiliary_components) + components = "Components:\n" + "\n".join(expected_components_str_list) - # Configs section - expected_configs = set(getattr(self, "expected_configs", [])) - loaded_configs = set(self.configs.keys()) - all_configs = sorted(expected_configs | loaded_configs) - configs = "Configs:\n" + "\n".join( - f" - {k}={self.configs[k]}" if k in loaded_configs else f" - {k}" - for k in all_configs - ) + # Configs section - focus only on expected configs + expected_configs = getattr(self, "expected_configs", []) + configs = "Configs:\n" + "\n".join(f" - {k}" for k in sorted(expected_configs)) # Inputs section inputs_str = format_inputs_short(self.inputs) @@ -672,35 +640,6 @@ class AutoPipelineBlocks: expected_configs.append(config) return expected_configs - # YiYi TODO: address the case where multiple blocks have the same component/auxiliary/config; give out warning etc - @property - def components(self): - # Combine components from all blocks - components = {} - for block_name, block in self.blocks.items(): - for key, value in block.components.items(): - # Only update if: - # 1. Key doesn't exist yet in components, OR - # 2. New value is not None - if key not in components or value is not None: - components[key] = value - return components - - @property - def auxiliaries(self): - # Combine auxiliaries from all blocks - auxiliaries = {} - for block_name, block in self.blocks.items(): - auxiliaries.update(block.auxiliaries) - return auxiliaries - - @property - def configs(self): - # Combine configs from all blocks - configs = {} - for block_name, block in self.blocks.items(): - configs.update(block.configs) - return configs @property def required_inputs(self) -> List[str]: @@ -855,62 +794,34 @@ class AutoPipelineBlocks: desc.extend(f" {line}" for line in desc_lines[1:]) desc = '\n'.join(desc) + '\n' - # Components section + # Components section - focus only on expected components expected_components = getattr(self, "expected_components", []) - expected_component_names = {comp.name for comp in expected_components} if expected_components else set() - loaded_components = set(self.components.keys()) - all_components = sorted(expected_component_names | loaded_components) - - # Auxiliaries section - auxiliaries_str = " Auxiliaries:\n" + "\n".join( - f" - {k}={type(v).__name__}" for k, v in self.auxiliaries.items() - ) - main_components = [] - for k in all_components: - # Get component spec if available - component_spec = next((comp for comp in expected_components if comp.name == k), None) + expected_components_str_list = [] + + for component_spec in expected_components: - if k in loaded_components: - component_type = type(self.components[k]).__name__ - component_str = f" - {k}={component_type}" - - # Add expected type info if available - if component_spec and component_spec.class_name: - expected_type = component_spec.class_name - if isinstance(expected_type, (list, tuple)): - expected_type = expected_type[1] # Get class name from [module, class_name] - if expected_type != component_type: - component_str += f" (expected: {expected_type})" - else: - # Component not loaded but expected - if component_spec: - expected_type = component_spec.class_name - if isinstance(expected_type, (list, 
tuple)): - expected_type = expected_type[1] # Get class name from [module, class_name] - component_str = f" - {k} (expected: {expected_type})" - - # Add repo info if available - if component_spec.default_repo: - repo_info = component_spec.default_repo - if component_spec.subfolder: - repo_info += f", subfolder={component_spec.subfolder}" - component_str += f" [{repo_info}]" + component_str = f" - {component_spec.name} ({component_spec.type_hint.__name__})" + + # Add repo info if available + if component_spec.default_repo: + if isinstance(component_spec.default_repo, list) and len(component_spec.default_repo) == 2: + repo_info = component_spec.default_repo[0] + subfolder = component_spec.default_repo[1] + if subfolder: + repo_info += f", subfolder={subfolder}" else: - component_str = f" - {k}" + repo_info = component_spec.default_repo + component_str += f" [{repo_info}]" + expected_components_str_list.append(component_str) - main_components.append(component_str) + components_str = " Components:\n" + "\n".join(expected_components_str_list) - components = "Components:\n" + "\n".join(main_components) - - # Configs section - expected_configs = set(getattr(self, "expected_configs", [])) - loaded_configs = set(self.configs.keys()) - all_configs = sorted(expected_configs | loaded_configs) - configs_str = " Configs:\n" + "\n".join( - f" - {k}={v}" if k in loaded_configs else f" - {k}" for k, v in self.configs.items() - ) + # Configs section - focus only on expected configs + expected_configs = getattr(self, "expected_configs", []) + configs_str = " Configs:\n" + "\n".join(f" - {config.name}" for config in sorted(expected_configs, key=lambda x: x.name)) + # Blocks section blocks_str = " Blocks:\n" for i, (name, block) in enumerate(self.blocks.items()): # Get trigger input for this block @@ -955,6 +866,7 @@ class AutoPipelineBlocks: blocks_str += f"{indented_intermediates}\n" blocks_str += "\n" + # Inputs and outputs section inputs_str = format_inputs_short(self.inputs) inputs_str = " Inputs:\n " + inputs_str outputs = [out.name for out in self.outputs] @@ -970,7 +882,6 @@ class AutoPipelineBlocks: f"{header}\n" f"{desc}" f"{components_str}\n" - f"{auxiliaries_str}\n" f"{configs_str}\n" f"{blocks_str}\n" f"{inputs_str}\n" @@ -1037,35 +948,6 @@ class SequentialPipelineBlocks: blocks[block_name] = block_cls() self.blocks = blocks - # YiYi TODO: address the case where multiple blocks have the same component/auxiliary/config; give out warning etc - @property - def components(self): - # Combine components from all blocks - components = {} - for block_name, block in self.blocks.items(): - for key, value in block.components.items(): - # Only update if: - # 1. Key doesn't exist yet in components, OR - # 2. 
New value is not None - if key not in components or value is not None: - components[key] = value - return components - - @property - def auxiliaries(self): - # Combine auxiliaries from all blocks - auxiliaries = {} - for block_name, block in self.blocks.items(): - auxiliaries.update(block.auxiliaries) - return auxiliaries - - @property - def configs(self): - # Combine configs from all blocks - configs = {} - for block_name, block in self.blocks.items(): - configs.update(block.configs) - return configs @property def required_inputs(self) -> List[str]: @@ -1284,63 +1166,34 @@ class SequentialPipelineBlocks: desc.extend(f" {line}" for line in desc_lines[1:]) desc = '\n'.join(desc) + '\n' - # Components section + # Components section - focus only on expected components expected_components = getattr(self, "expected_components", []) - expected_component_names = {comp.name for comp in expected_components} if expected_components else set() - loaded_components = set(self.components.keys()) - all_components = sorted(expected_component_names | loaded_components) - - # Auxiliaries section - auxiliaries_str = " Auxiliaries:\n" + "\n".join( - f" - {k}={type(v).__name__}" for k, v in self.auxiliaries.items() - ) - - main_components = [] - for k in all_components: - # Get component spec if available - component_spec = next((comp for comp in expected_components if comp.name == k), None) + expected_components_str_list = [] + + for component_spec in expected_components: - if k in loaded_components: - component_type = type(self.components[k]).__name__ - component_str = f" - {k}={component_type}" - - # Add expected type info if available - if component_spec and component_spec.class_name: - expected_type = component_spec.class_name - if isinstance(expected_type, (list, tuple)): - expected_type = expected_type[1] # Get class name from [module, class_name] - if expected_type != component_type: - component_str += f" (expected: {expected_type})" - else: - # Component not loaded but expected - if component_spec: - expected_type = component_spec.class_name - if isinstance(expected_type, (list, tuple)): - expected_type = expected_type[1] # Get class name from [module, class_name] - component_str = f" - {k} (expected: {expected_type})" - - # Add repo info if available - if component_spec.default_repo: - repo_info = component_spec.default_repo - if component_spec.subfolder: - repo_info += f", subfolder={component_spec.subfolder}" - component_str += f" [{repo_info}]" + component_str = f" - {component_spec.name} ({component_spec.type_hint.__name__})" + + # Add repo info if available + if component_spec.default_repo: + if isinstance(component_spec.default_repo, list) and len(component_spec.default_repo) == 2: + repo_info = component_spec.default_repo[0] + subfolder = component_spec.default_repo[1] + if subfolder: + repo_info += f", subfolder={subfolder}" else: - component_str = f" - {k}" + repo_info = component_spec.default_repo + component_str += f" [{repo_info}]" + expected_components_str_list.append(component_str) - main_components.append(component_str) + components_str = " Components:\n" + "\n".join(expected_components_str_list) - components = "Components:\n" + "\n".join(main_components) - - # Configs section - expected_configs = set(getattr(self, "expected_configs", [])) - loaded_configs = set(self.configs.keys()) - all_configs = sorted(expected_configs | loaded_configs) - configs_str = " Configs:\n" + "\n".join( - f" - {k}={self.configs[k]}" if k in loaded_configs else f" - {k}" for k in all_configs - ) + # Configs 
section - focus only on expected configs + expected_configs = getattr(self, "expected_configs", []) + configs_str = " Configs:\n" + "\n".join(f" - {config.name}" for config in sorted(expected_configs, key=lambda x: x.name)) + # Blocks section blocks_str = " Blocks:\n" for i, (name, block) in enumerate(self.blocks.items()): # Get trigger input for this block @@ -1385,6 +1238,7 @@ class SequentialPipelineBlocks: blocks_str += f"{indented_intermediates}\n" blocks_str += "\n" + # Inputs and outputs section inputs_str = format_inputs_short(self.inputs) inputs_str = " Inputs:\n " + inputs_str outputs = [out.name for out in self.outputs] @@ -1400,7 +1254,6 @@ class SequentialPipelineBlocks: f"{header}\n" f"{desc}" f"{components_str}\n" - f"{auxiliaries_str}\n" f"{configs_str}\n" f"{blocks_str}\n" f"{inputs_str}\n" @@ -1408,6 +1261,7 @@ class SequentialPipelineBlocks: f")" ) + @property def doc(self): return make_doc_string(self.inputs, self.intermediates_inputs, self.outputs, self.description) @@ -1424,16 +1278,17 @@ class ModularPipeline(ConfigMixin): def __init__(self, block): self.pipeline_block = block - # add default components from pipeline_block (e.g. guider) - for key, value in block.components.items(): - setattr(self, key, value) + for component_spec in self.expected_components: + if component_spec.obj is not None: + setattr(self, component_spec.name, component_spec.obj) + else: + setattr(self, component_spec.name, None) + + default_configs = {} + for config_spec in self.expected_configs: + default_configs[config_spec.name] = config_spec.default + self.register_to_config(**default_configs) - # add default configs from pipeline_block (e.g. force_zeros_for_empty_prompt) - self.register_to_config(**block.configs) - - # add default auxiliaries from pipeline_block (e.g. image_processor) - for key, value in block.auxiliaries.items(): - setattr(self, key, value) @classmethod def from_block(cls, block): @@ -1508,9 +1363,9 @@ class ModularPipeline(ConfigMixin): @property def components(self): components = {} - for name in self.expected_components: - if hasattr(self, name): - components[name] = getattr(self, name) + for component_spec in self.expected_components: + if hasattr(self, component_spec.name): + components[component_spec.name] = getattr(self, component_spec.name) return components # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline.progress_bar @@ -1596,32 +1451,32 @@ class ModularPipeline(ConfigMixin): kwargs (dict): Keyword arguments to update the states. 
""" - for component_name in self.expected_components: - if component_name in kwargs: - if hasattr(self, component_name) and getattr(self, component_name) is not None: - current_component = getattr(self, component_name) - new_component = kwargs[component_name] + for component in self.expected_components: + if component.name in kwargs: + if hasattr(self, component.name) and getattr(self, component.name) is not None: + current_component = getattr(self, component.name) + new_component = kwargs[component.name] if not isinstance(new_component, current_component.__class__): logger.info( - f"Overwriting existing component '{component_name}' " + f"Overwriting existing component '{component.name}' " f"(type: {current_component.__class__.__name__}) " f"with type: {new_component.__class__.__name__})" ) elif isinstance(current_component, torch.nn.Module): if id(current_component) != id(new_component): logger.info( - f"Overwriting existing component '{component_name}' " + f"Overwriting existing component '{component.name}' " f"(type: {type(current_component).__name__}) " f"with new value (type: {type(new_component).__name__})" ) - setattr(self, component_name, kwargs.pop(component_name)) + setattr(self, component.name, kwargs.pop(component.name)) configs_to_add = {} - for config_name in self.expected_configs: - if config_name in kwargs: - configs_to_add[config_name] = kwargs.pop(config_name) + for config in self.expected_configs: + if config.name in kwargs: + configs_to_add[config.name] = kwargs.pop(config.name) self.register_to_config(**configs_to_add) @property @@ -1631,64 +1486,64 @@ class ModularPipeline(ConfigMixin): params[input_param.name] = input_param.default return params - def __repr__(self): - output = "ModularPipeline:\n" - output += "==============================\n\n" + # def __repr__(self): + # output = "ModularPipeline:\n" + # output += "==============================\n\n" - block = self.pipeline_block + # block = self.pipeline_block - # List the pipeline block structure first - output += "Pipeline Block:\n" - output += "--------------\n" - if hasattr(block, "blocks"): - output += f"{block.__class__.__name__}\n" - base_class = block.__class__.__bases__[0].__name__ - output += f" (Class: {base_class})\n" if base_class != "object" else "\n" - for sub_block_name, sub_block in block.blocks.items(): - if hasattr(block, "block_trigger_inputs"): - trigger_input = block.block_to_trigger_map[sub_block_name] - trigger_info = f" [trigger: {trigger_input}]" if trigger_input is not None else " [default]" - output += f" • {sub_block_name} ({sub_block.__class__.__name__}){trigger_info}\n" - else: - output += f" • {sub_block_name} ({sub_block.__class__.__name__})\n" - else: - output += f"{block.__class__.__name__}\n" - output += "\n" + # # List the pipeline block structure first + # output += "Pipeline Block:\n" + # output += "--------------\n" + # if hasattr(block, "blocks"): + # output += f"{block.__class__.__name__}\n" + # base_class = block.__class__.__bases__[0].__name__ + # output += f" (Class: {base_class})\n" if base_class != "object" else "\n" + # for sub_block_name, sub_block in block.blocks.items(): + # if hasattr(block, "block_trigger_inputs"): + # trigger_input = block.block_to_trigger_map[sub_block_name] + # trigger_info = f" [trigger: {trigger_input}]" if trigger_input is not None else " [default]" + # output += f" • {sub_block_name} ({sub_block.__class__.__name__}){trigger_info}\n" + # else: + # output += f" • {sub_block_name} ({sub_block.__class__.__name__})\n" + # else: + # output += 
f"{block.__class__.__name__}\n" + # output += "\n" - # List the components registered in the pipeline - output += "Registered Components:\n" - output += "----------------------\n" - for name, component in self.components.items(): - output += f"{name}: {type(component).__name__}" - if hasattr(component, "dtype") and hasattr(component, "device"): - output += f" (dtype={component.dtype}, device={component.device})" - output += "\n" - output += "\n" + # # List the components registered in the pipeline + # output += "Registered Components:\n" + # output += "----------------------\n" + # for name, component in self.components.items(): + # output += f"{name}: {type(component).__name__}" + # if hasattr(component, "dtype") and hasattr(component, "device"): + # output += f" (dtype={component.dtype}, device={component.device})" + # output += "\n" + # output += "\n" - # List the configs registered in the pipeline - output += "Registered Configs:\n" - output += "------------------\n" - for name, config in self.config.items(): - output += f"{name}: {config!r}\n" - output += "\n" + # # List the configs registered in the pipeline + # output += "Registered Configs:\n" + # output += "------------------\n" + # for name, config in self.config.items(): + # output += f"{name}: {config!r}\n" + # output += "\n" - # Add auto blocks section - if hasattr(block, "trigger_inputs") and block.trigger_inputs: - output += "------------------\n" - output += "This pipeline contains blocks that are selected at runtime based on inputs.\n\n" - output += f"Trigger Inputs: {block.trigger_inputs}\n" - # Get first trigger input as example - example_input = next(t for t in block.trigger_inputs if t is not None) - output += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n" - output += "Check `.doc` of returned object for more information.\n\n" + # # Add auto blocks section + # if hasattr(block, "trigger_inputs") and block.trigger_inputs: + # output += "------------------\n" + # output += "This pipeline contains blocks that are selected at runtime based on inputs.\n\n" + # output += f"Trigger Inputs: {block.trigger_inputs}\n" + # # Get first trigger input as example + # example_input = next(t for t in block.trigger_inputs if t is not None) + # output += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. 
`get_execution_blocks('{example_input}')`).\n" + # output += "Check `.doc` of returned object for more information.\n\n" - # List the call parameters - full_doc = self.pipeline_block.doc - if "------------------------" in full_doc: - full_doc = full_doc.split("------------------------")[0].rstrip() - output += full_doc + # # List the call parameters + # full_doc = self.pipeline_block.doc + # if "------------------------" in full_doc: + # full_doc = full_doc.split("------------------------")[0].rstrip() + # output += full_doc - return output + # return output # YiYi TODO: try to unify the to method with the one in DiffusionPipeline # Modified from diffusers.pipelines.pipeline_utils.DiffusionPipeline.to diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 23ea96b8e8..8e71093089 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -22,7 +22,7 @@ from collections import OrderedDict from ...guider import CFGGuider from ...image_processor import VaeImageProcessor, PipelineImageInput from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin -from ...models import ControlNetModel, ImageProjection +from ...models import ControlNetModel, ImageProjection, UNet2DConditionModel, AutoencoderKL, ControlNetUnionModel from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor from ...models.lora import adjust_lora_scale_text_encoder from ...utils import ( @@ -211,7 +211,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock, ModularIPAdapterMixin): ] # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self -> components - def encode_image(components, image, device, num_images_per_prompt, output_hidden_states=None): + def encode_image(self, components, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(components.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): @@ -237,7 +237,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock, ModularIPAdapterMixin): # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds def prepare_ip_adapter_image_embeds( - components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): image_embeds = [] if do_classifier_free_guidance: @@ -288,7 +288,8 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock, ModularIPAdapterMixin): data.do_classifier_free_guidance = data.guidance_scale > 1.0 data.device = pipeline._execution_device - data.ip_adapter_embeds = pipeline.prepare_ip_adapter_image_embeds( + data.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds( + pipeline, ip_adapter_image=data.ip_adapter_image, ip_adapter_image_embeds=None, device=data.device, @@ -358,8 +359,9 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock): elif data.prompt_2 is not None and (not isinstance(data.prompt_2, str) and not isinstance(data.prompt_2, list)): raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(data.prompt_2)}") - # Copied from 
diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt with self -> components def encode_prompt( + self, components, prompt: str, prompt_2: Optional[str] = None, @@ -496,7 +498,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock): prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) # get unconditional embeddings for classifier free guidance - zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + zero_out_negative_prompt = negative_prompt is None and components.config.force_zeros_for_empty_prompt if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: negative_prompt_embeds = torch.zeros_like(prompt_embeds) negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) @@ -614,6 +616,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock): data.pooled_prompt_embeds, data.negative_pooled_prompt_embeds, ) = self.encode_prompt( + pipeline, data.prompt, data.prompt_2, data.device, @@ -670,40 +673,40 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock): def intermediates_outputs(self) -> List[OutputParam]: return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation")] - # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components # YiYi TODO: update the _encode_vae_image so that we can use #Copied from - def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator): latents_mean = latents_std = None - if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: - latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: - latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: + latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: + latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) dtype = image.dtype - if self.vae.config.force_upcast: + if components.vae.config.force_upcast: image = image.float() - self.vae.to(dtype=torch.float32) + components.vae.to(dtype=torch.float32) if isinstance(generator, list): image_latents = [ - retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0]) ] image_latents = torch.cat(image_latents, dim=0) else: - image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + image_latents = retrieve_latents(components.vae.encode(image), generator=generator) - if self.vae.config.force_upcast: - self.vae.to(dtype) + if components.vae.config.force_upcast: + components.vae.to(dtype) 
image_latents = image_latents.to(dtype) if latents_mean is not None and latents_std is not None: latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) latents_std = latents_std.to(device=image_latents.device, dtype=dtype) - image_latents = (image_latents - latents_mean) * self.vae.config.scaling_factor / latents_std + image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std else: - image_latents = self.vae.config.scaling_factor * image_latents + image_latents = components.vae.config.scaling_factor * image_latents return image_latents @@ -729,7 +732,7 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock): ) - data.image_latents = self._encode_vae_image(image=data.image, generator=data.generator) + data.image_latents = self._encode_vae_image(pipeline, image=data.image, generator=data.generator) self.add_block_state(state, data) @@ -776,32 +779,32 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting process (only for inpainting-specific unet)"), OutputParam("crops_coords", type_hint=Optional[Tuple[int, int]], description="The crop coordinates to use for the preprocess/postprocess of the image and mask")] - # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components # YiYi TODO: update the _encode_vae_image so that we can use #Copied from - def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator): latents_mean = latents_std = None - if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: - latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: - latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: + latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: + latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) dtype = image.dtype - if self.vae.config.force_upcast: + if components.vae.config.force_upcast: image = image.float() - self.vae.to(dtype=torch.float32) + components.vae.to(dtype=torch.float32) if isinstance(generator, list): image_latents = [ - retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0]) ] image_latents = torch.cat(image_latents, dim=0) else: - image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + image_latents = retrieve_latents(components.vae.encode(image), generator=generator) - if self.vae.config.force_upcast: - self.vae.to(dtype) + if components.vae.config.force_upcast: + components.vae.to(dtype) image_latents = image_latents.to(dtype) if latents_mean is not None and latents_std is not None: @@ -809,20 +812,20 @@ class 
StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): latents_std = latents_std.to(device=image_latents.device, dtype=dtype) - image_latents = (image_latents - latents_mean) * self.vae.config.scaling_factor / latents_std + image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std else: - image_latents = self.vae.config.scaling_factor * image_latents + image_latents = components.vae.config.scaling_factor * image_latents return image_latents # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents # do not accept do_classifier_free_guidance def prepare_mask_latents( - self, mask, masked_image, batch_size, height, width, dtype, device, generator + self, components, mask, masked_image, batch_size, height, width, dtype, device, generator ): # resize the mask to latents shape as we concatenate the mask to the latents # we do that before converting to dtype to avoid breaking in case we're using cpu_offload # and half precision mask = torch.nn.functional.interpolate( - mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor) ) mask = mask.to(device=device, dtype=dtype) @@ -844,7 +847,7 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): if masked_image is not None: if masked_image_latents is None: masked_image = masked_image.to(device=device, dtype=dtype) - masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator) if masked_image_latents.shape[0] < batch_size: if not batch_size % masked_image_latents.shape[0] == 0: @@ -887,10 +890,11 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): data.batch_size = data.image.shape[0] data.image = data.image.to(device=data.device, dtype=data.dtype) - data.image_latents = self._encode_vae_image(image=data.image, generator=data.generator) + data.image_latents = self._encode_vae_image(pipeline, image=data.image, generator=data.generator) # 7. 
Prepare mask latent variables data.mask, data.masked_image_latents = self.prepare_mask_latents( + pipeline, data.mask, data.masked_image, data.batch_size, @@ -1067,16 +1071,16 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): OutputParam("latent_timestep", type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image generation") ] - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps - def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps with self -> components + def get_timesteps(self, components, num_inference_steps, strength, device, denoising_start=None): # get the original timestep using init_timestep if denoising_start is None: init_timestep = min(int(num_inference_steps * strength), num_inference_steps) t_start = max(num_inference_steps - init_timestep, 0) - timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] - if hasattr(self.scheduler, "set_begin_index"): - self.scheduler.set_begin_index(t_start * self.scheduler.order) + timesteps = components.scheduler.timesteps[t_start * components.scheduler.order :] + if hasattr(components.scheduler, "set_begin_index"): + components.scheduler.set_begin_index(t_start * components.scheduler.order) return timesteps, num_inference_steps - t_start @@ -1085,13 +1089,13 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): # that is, strength is determined by the denoising_start instead. discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (denoising_start * self.scheduler.config.num_train_timesteps) + components.scheduler.config.num_train_timesteps + - (denoising_start * components.scheduler.config.num_train_timesteps) ) ) - num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item() - if self.scheduler.order == 2 and num_inference_steps % 2 == 0: + num_inference_steps = (components.scheduler.timesteps < discrete_timestep_cutoff).sum().item() + if components.scheduler.order == 2 and num_inference_steps % 2 == 0: # if the scheduler is a 2nd order scheduler we might have to do +1 # because `num_inference_steps` might be even given that every timestep # (except the highest one) is duplicated. 
If `num_inference_steps` is even it would @@ -1101,10 +1105,10 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): num_inference_steps = num_inference_steps + 1 # because t_n+1 >= t_n, we slice the timesteps starting from the end - t_start = len(self.scheduler.timesteps) - num_inference_steps - timesteps = self.scheduler.timesteps[t_start:] - if hasattr(self.scheduler, "set_begin_index"): - self.scheduler.set_begin_index(t_start) + t_start = len(components.scheduler.timesteps) - num_inference_steps + timesteps = components.scheduler.timesteps[t_start:] + if hasattr(components.scheduler, "set_begin_index"): + components.scheduler.set_begin_index(t_start) return timesteps, num_inference_steps @@ -1123,6 +1127,7 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): return isinstance(dnv, float) and 0 < dnv < 1 data.timesteps, data.num_inference_steps = self.get_timesteps( + pipeline, data.num_inference_steps, data.strength, data.device, @@ -1281,9 +1286,10 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting generation (only for inpainting-specific unet)"), OutputParam("noise", type_hint=torch.Tensor, description="The noise added to the image latents, used for inpainting generation")] - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents with self -> components def prepare_latents_inpaint( self, + components, batch_size, num_channels_latents, height, @@ -1302,8 +1308,8 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): shape = ( batch_size, num_channels_latents, - int(height) // self.vae_scale_factor, - int(width) // self.vae_scale_factor, + int(height) // components.vae_scale_factor, + int(width) // components.vae_scale_factor, ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -1322,18 +1328,18 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) elif return_image_latents or (latents is None and not is_strength_max): image = image.to(device=device, dtype=dtype) - image_latents = self._encode_vae_image(image=image, generator=generator) + image_latents = self._encode_vae_image(components, image=image, generator=generator) image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) if latents is None and add_noise: noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # if strength is 1. 
then initialise the latents to noise, else initial to image + noise - latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) + latents = noise if is_strength_max else components.scheduler.add_noise(image_latents, noise, timestep) # if pure noise then scale the initial latents by the Scheduler's init sigma - latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + latents = latents * components.scheduler.init_noise_sigma if is_strength_max else latents elif add_noise: noise = latents.to(device) - latents = noise * self.scheduler.init_noise_sigma + latents = noise * components.scheduler.init_noise_sigma else: noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) latents = image_latents.to(device) @@ -1351,13 +1357,13 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents # do not accept do_classifier_free_guidance def prepare_mask_latents( - self, mask, masked_image, batch_size, height, width, dtype, device, generator + self, components, mask, masked_image, batch_size, height, width, dtype, device, generator ): # resize the mask to latents shape as we concatenate the mask to the latents # we do that before converting to dtype to avoid breaking in case we're using cpu_offload # and half precision mask = torch.nn.functional.interpolate( - mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor) ) mask = mask.to(device=device, dtype=dtype) @@ -1379,7 +1385,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): if masked_image is not None: if masked_image_latents is None: masked_image = masked_image.to(device=device, dtype=dtype) - masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator) if masked_image_latents.shape[0] < batch_size: if not batch_size % masked_image_latents.shape[0] == 0: @@ -1418,6 +1424,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): data.width = data.image_latents.shape[-1] * pipeline.vae_scale_factor data.latents, data.noise = self.prepare_latents_inpaint( + pipeline, data.batch_size * data.num_images_per_prompt, pipeline.num_channels_latents, data.height, @@ -1436,6 +1443,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): # 7. 
Prepare mask latent variables data.mask, data.masked_image_latents = self.prepare_mask_latents( + pipeline, data.mask, data.masked_image_latents, data.batch_size * data.num_images_per_prompt, @@ -1488,10 +1496,10 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): def intermediates_outputs(self) -> List[OutputParam]: return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process")] - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents with self -> components # YiYi TODO: refactor using _encode_vae_image def prepare_latents_img2img( - self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True + self, components, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True ): if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): raise ValueError( @@ -1499,8 +1507,8 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): ) # Offload text encoder if `enable_model_cpu_offload` was enabled - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.text_encoder_2.to("cpu") + if hasattr(components, "final_offload_hook") and components.final_offload_hook is not None: + components.text_encoder_2.to("cpu") torch.cuda.empty_cache() image = image.to(device=device, dtype=dtype) @@ -1512,14 +1520,14 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): else: latents_mean = latents_std = None - if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: - latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: - latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: + latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: + latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) # make sure the VAE is in float32 mode, as it overflows in float16 - if self.vae.config.force_upcast: + if components.vae.config.force_upcast: image = image.float() - self.vae.to(dtype=torch.float32) + components.vae.to(dtype=torch.float32) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -1536,23 +1544,23 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): ) init_latents = [ - retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size) ] init_latents = torch.cat(init_latents, dim=0) else: - init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + init_latents = retrieve_latents(components.vae.encode(image), generator=generator) - if self.vae.config.force_upcast: - self.vae.to(dtype) + if components.vae.config.force_upcast: + components.vae.to(dtype) init_latents = init_latents.to(dtype) if latents_mean is not None and latents_std is not None: latents_mean = 
latents_mean.to(device=device, dtype=dtype) latents_std = latents_std.to(device=device, dtype=dtype) - init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std + init_latents = (init_latents - latents_mean) * components.vae.config.scaling_factor / latents_std else: - init_latents = self.vae.config.scaling_factor * init_latents + init_latents = components.vae.config.scaling_factor * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: # expand init_latents for batch_size @@ -1569,7 +1577,7 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): shape = init_latents.shape noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # get latents - init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + init_latents = components.scheduler.add_noise(init_latents, noise, timestep) latents = init_latents @@ -1584,6 +1592,7 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): data.add_noise = True if data.denoising_start is None else False if data.latents is None: data.latents = self.prepare_latents_img2img( + pipeline, data.image_latents, data.latent_timestep, data.batch_size, @@ -1663,13 +1672,13 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock): f"`height` and `width` have to be divisible by {pipeline.vae_scale_factor} but are {data.height} and {data.width}." ) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with self -> components + def prepare_latents(self, components, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = ( batch_size, num_channels_latents, - int(height) // self.vae_scale_factor, - int(width) // self.vae_scale_factor, + int(height) // components.vae_scale_factor, + int(width) // components.vae_scale_factor, ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -1683,7 +1692,7 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock): latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma + latents = latents * components.scheduler.init_noise_sigma return latents @@ -1702,6 +1711,7 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock): data.width = data.width or pipeline.default_sample_size * pipeline.vae_scale_factor data.num_channels_latents = pipeline.num_channels_latents data.latents = self.prepare_latents( + pipeline, data.batch_size * data.num_images_per_prompt, data.num_channels_latents, data.height, @@ -1762,6 +1772,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids with self -> components def _get_add_time_ids_img2img( + self, components, original_size, crops_coords_top_left, @@ -1864,7 +1875,8 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): if data.negative_target_size is None: data.negative_target_size = data.target_size - data.add_time_ids, data.negative_add_time_ids = pipeline._get_add_time_ids_img2img( + 
data.add_time_ids, data.negative_add_time_ids = self._get_add_time_ids_img2img( + pipeline, data.original_size, data.crops_coords_top_left, data.target_size, @@ -1946,57 +1958,24 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="The negative time ids to condition the denoising process"), OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids with self -> components - def _get_add_time_ids_img2img( - components, - original_size, - crops_coords_top_left, - target_size, - aesthetic_score, - negative_aesthetic_score, - negative_original_size, - negative_crops_coords_top_left, - negative_target_size, - dtype, - text_encoder_projection_dim=None, + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids with self -> components + def _get_add_time_ids( + self, components, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None ): - if components.config.requires_aesthetics_score: - add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) - add_neg_time_ids = list( - negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) - ) - else: - add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) + add_time_ids = list(original_size + crops_coords_top_left + target_size) passed_add_embed_dim = ( components.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) expected_add_embed_dim = components.unet.add_embedding.linear_1.in_features - if ( - expected_add_embed_dim > passed_add_embed_dim - and (expected_add_embed_dim - passed_add_embed_dim) == components.unet.config.addition_time_embed_dim - ): - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." - ) - elif ( - expected_add_embed_dim < passed_add_embed_dim - and (passed_add_embed_dim - expected_add_embed_dim) == components.unet.config.addition_time_embed_dim - ): - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." - ) - elif expected_add_embed_dim != passed_add_embed_dim: + if expected_add_embed_dim != passed_add_embed_dim: raise ValueError( f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
) add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) - - return add_time_ids, add_neg_time_ids + return add_time_ids # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding def get_guidance_scale_embedding( @@ -2043,7 +2022,8 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): data.text_encoder_projection_dim = int(data.pooled_prompt_embeds.shape[-1]) - data.add_time_ids = pipeline._get_add_time_ids( + data.add_time_ids = self._get_add_time_ids( + pipeline, data.original_size, data.crops_coords_top_left, data.target_size, @@ -2051,7 +2031,8 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): text_encoder_projection_dim=data.text_encoder_projection_dim, ) if data.negative_original_size is not None and data.negative_target_size is not None: - data.negative_add_time_ids = pipeline._get_add_time_ids( + data.negative_add_time_ids = self._get_add_time_ids( + pipeline, data.negative_original_size, data.negative_crops_coords_top_left, data.negative_target_size, @@ -2087,7 +2068,7 @@ class StableDiffusionXLDenoiseStep(PipelineBlock): @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("guider", CFGGuider), + ComponentSpec("guider", CFGGuider, obj=CFGGuider()), ComponentSpec("scheduler", KarrasDiffusionSchedulers), ComponentSpec("unet", UNet2DConditionModel), ] @@ -2231,20 +2212,20 @@ class StableDiffusionXLDenoiseStep(PipelineBlock): " `pipeline.unet` or your `mask_image` or `image` input." ) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components + def prepare_extra_step_kwargs(self, components, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs @@ -2297,7 +2278,7 @@ class StableDiffusionXLDenoiseStep(PipelineBlock): data.added_cond_kwargs["image_embeds"] = data.ip_adapter_embeds # Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline - data.extra_step_kwargs = self.prepare_extra_step_kwargs(data.generator, data.eta) + data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: @@ -2360,12 +2341,12 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("guider", CFGGuider), + ComponentSpec("guider", CFGGuider, obj=CFGGuider()), ComponentSpec("scheduler", KarrasDiffusionSchedulers), ComponentSpec("unet", UNet2DConditionModel), ComponentSpec("controlnet", ControlNetModel), ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)), - ComponentSpec("controlnet_guider", CFGGuider), + ComponentSpec("controlnet_guider", CFGGuider, obj=CFGGuider()), ] @property @@ -2519,6 +2500,7 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): # 2. add crops_coords and resize_mode to preprocess() def prepare_control_image( self, + components, image, width, height, @@ -2529,9 +2511,9 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): crops_coords=None, ): if crops_coords is not None: - image = self.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) + image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) else: - image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) image_batch_size = image.shape[0] if image_batch_size == 1: @@ -2546,20 +2528,20 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): return image - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components + def prepare_extra_step_kwargs(self, components, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
@@ -2546,20 +2528,20 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):

         return image

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
-    def prepare_extra_step_kwargs(self, generator, eta):
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components
+    def prepare_extra_step_kwargs(self, components, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
         # and should be between [0, 1]

-        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys())
         extra_step_kwargs = {}
         if accepts_eta:
             extra_step_kwargs["eta"] = eta

         # check if the scheduler accepts generator
-        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys())
         if accepts_generator:
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs
@@ -2616,6 +2598,7 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
         # control_image
         if isinstance(controlnet, ControlNetModel):
             data.control_image = self.prepare_control_image(
+                pipeline,
                 image=data.control_image,
                 width=data.width,
                 height=data.height,
@@ -2630,6 +2613,7 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):

             for control_image_ in data.control_image:
                 control_image = self.prepare_control_image(
+                    pipeline,
                     image=control_image_,
                     width=data.width,
                     height=data.height,
@@ -2712,7 +2696,7 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
             data.control_image = pipeline.controlnet_guider.prepare_input(data.control_image, data.control_image)

         # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
-        data.extra_step_kwargs = self.prepare_extra_step_kwargs(data.generator, data.eta)
+        data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta)
         data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0)

         # (5) Denoise loop
@@ -2808,8 +2792,8 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
             ComponentSpec("unet", UNet2DConditionModel),
             ComponentSpec("controlnet", ControlNetUnionModel),
             ComponentSpec("scheduler", KarrasDiffusionSchedulers),
-            ComponentSpec("guider", CFGGuider),
-            ComponentSpec("controlnet_guider", CFGGuider),
+            ComponentSpec("guider", CFGGuider, obj=CFGGuider()),
+            ComponentSpec("controlnet_guider", CFGGuider, obj=CFGGuider()),
             ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)),
         ]
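The obj= additions above give cheap, stateless components a ready-made default instance while heavyweight models stay load-on-demand. A toy re-implementation of the idea (illustrative only, not the ComponentSpec API):

from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class Spec:  # toy stand-in for ComponentSpec
    name: str
    type_hint: Any
    obj: Optional[Any] = None  # pre-built default, when one is cheap to construct


def resolve(spec: Spec, loaded: dict) -> Any:
    # Prefer an explicitly loaded component; otherwise fall back to the spec's
    # default object, so guiders and image processors work out of the box.
    return loaded.get(spec.name, spec.obj)


guider_spec = Spec("guider", object, obj=object())
unet_spec = Spec("unet", object)  # heavyweight model: no default, must be loaded
assert resolve(guider_spec, {}) is guider_spec.obj
assert resolve(unet_spec, {}) is None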
@@ -2965,6 +2949,7 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
     # 2. add crops_coords and resize_mode to preprocess()
     def prepare_control_image(
         self,
+        components,
         image,
         width,
         height,
@@ -2975,9 +2960,9 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
         crops_coords=None,
     ):
         if crops_coords is not None:
-            image = self.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32)
+            image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32)
         else:
-            image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+            image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)

         image_batch_size = image.shape[0]
         if image_batch_size == 1:
@@ -2992,20 +2977,20 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):

         return image

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
-    def prepare_extra_step_kwargs(self, generator, eta):
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components
+    def prepare_extra_step_kwargs(self, components, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
         # and should be between [0, 1]

-        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys())
         extra_step_kwargs = {}
         if accepts_eta:
             extra_step_kwargs["eta"] = eta

         # check if the scheduler accepts generator
-        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys())
         if accepts_generator:
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs
@@ -3062,6 +3047,7 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
         # prepare control_image
         for idx, _ in enumerate(data.control_image):
             data.control_image[idx] = self.prepare_control_image(
+                pipeline,
                 image=data.control_image[idx],
                 width=data.width,
                 height=data.height,
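Every change in these hunks follows one calling convention: the method stays on the block, and the pipeline that owns the components is passed in explicitly, so a block instance holds no model state of its own. A minimal sketch with hypothetical classes:

class Block:  # hypothetical block: holds logic, no models
    def prepare(self, components, eta):
        # reads state off the passed-in pipeline, never off self
        return {"eta": eta, "order": components.scheduler_order}

    def __call__(self, pipeline):
        return self.prepare(pipeline, eta=0.0)  # pipeline plays the components role


class Pipeline:  # hypothetical pipeline: owns the components
    scheduler_order = 1


assert Block()(Pipeline()) == {"eta": 0.0, "order": 1}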
@@ -3149,7 +3135,7 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
             data.control_type = pipeline.controlnet_guider.prepare_input(data.control_type, data.control_type)

         # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
-        data.extra_step_kwargs = self.prepare_extra_step_kwargs(data.generator, data.eta)
+        data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta)
         data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0)

@@ -3266,12 +3252,12 @@ class StableDiffusionXLDecodeLatentsStep(PipelineBlock):
     def intermediates_outputs(self) -> List[str]:
         return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array")]

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
-    def upcast_vae(self):
-        dtype = self.vae.dtype
-        self.vae.to(dtype=torch.float32)
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae with self -> components
+    def upcast_vae(self, components):
+        dtype = components.vae.dtype
+        components.vae.to(dtype=torch.float32)
         use_torch_2_0_or_xformers = isinstance(
-            self.vae.decoder.mid_block.attentions[0].processor,
+            components.vae.decoder.mid_block.attentions[0].processor,
             (
                 AttnProcessor2_0,
                 XFormersAttnProcessor,
@@ -3280,9 +3266,9 @@ class StableDiffusionXLDecodeLatentsStep(PipelineBlock):
         # if xformers or torch_2_0 is used attention block does not need
         # to be in float32 which can save lots of memory
         if use_torch_2_0_or_xformers:
-            self.vae.post_quant_conv.to(dtype)
-            self.vae.decoder.conv_in.to(dtype)
-            self.vae.decoder.mid_block.to(dtype)
+            components.vae.post_quant_conv.to(dtype)
+            components.vae.decoder.conv_in.to(dtype)
+            components.vae.decoder.mid_block.to(dtype)

     @torch.no_grad()
     def __call__(self, pipeline, state: PipelineState) -> PipelineState:
@@ -3293,7 +3279,7 @@ class StableDiffusionXLDecodeLatentsStep(PipelineBlock):

         data.needs_upcasting = pipeline.vae.dtype == torch.float16 and pipeline.vae.config.force_upcast
         if data.needs_upcasting:
-            self.upcast_vae()
+            self.upcast_vae(pipeline)
             data.latents = data.latents.to(next(iter(pipeline.vae.post_quant_conv.parameters())).dtype)
         elif data.latents.dtype != pipeline.vae.dtype:
             if torch.backends.mps.is_available():
@@ -3672,7 +3658,7 @@ class StableDiffusionXLModularPipeline(


 # YiYi Notes: not used yet, maintain a list of schema that can be used across all pipeline blocks
-sdxl_inputs_map = {
+SDXL_INPUTS_SCHEMA = {
     "prompt": InputParam("prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"),
     "prompt_2": InputParam("prompt_2", type_hint=Union[str, List[str]], description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2"),
     "negative_prompt": InputParam("negative_prompt", type_hint=Union[str, List[str]], description="The prompt or prompts not to guide the image generation"),
@@ -3718,7 +3704,7 @@
 }


-sdxl_intermediate_inputs_map = {
+SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
     "prompt_embeds": InputParam("prompt_embeds", type_hint=torch.Tensor, required=True, description="Text embeddings used to guide image generation"),
     "negative_prompt_embeds": InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"),
     "pooled_prompt_embeds": InputParam("pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"),
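Per the YiYi note these schema constants are not wired up yet; the intended use appears to be single-source parameter definitions shared across blocks. A sketch of that, with a toy InputParam stand-in and a hypothetical helper:

from dataclasses import dataclass
from typing import Any


@dataclass
class Param:  # toy stand-in for InputParam
    name: str
    type_hint: Any = None
    required: bool = False
    description: str = ""


SCHEMA = {
    "prompt": Param("prompt", str, description="The prompt or prompts to guide the image generation"),
    "prompt_embeds": Param("prompt_embeds", required=True),
}


def block_inputs(*names):
    # Blocks declare inputs by name and inherit one canonical definition,
    # so type hints and descriptions cannot drift between blocks.
    return [SCHEMA[name] for name in names]


assert block_inputs("prompt", "prompt_embeds")[1].required is True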
@@ -3744,7 +3730,7 @@
 }


-sdxl_intermediate_outputs_map = {
+SDXL_INTERMEDIATE_OUTPUTS_SCHEMA = {
     "prompt_embeds": OutputParam("prompt_embeds", type_hint=torch.Tensor, description="Text embeddings used to guide image generation"),
     "negative_prompt_embeds": OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"),
     "pooled_prompt_embeds": OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="Pooled text embeddings"),
@@ -3769,6 +3755,6 @@
 }


-sdxl_outputs_map = {
+SDXL_OUTPUTS_SCHEMA = {
     "images": OutputParam("images", type_hint=Union[Tuple[Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]]], StableDiffusionXLPipelineOutput], description="The final generated images")
 }
\ No newline at end of file
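One follow-on the module-level schemas would enable is validating pipeline state up front against the declared required inputs. A hypothetical sketch (missing_required is not part of this patch):

from dataclasses import dataclass


@dataclass
class Param:  # toy stand-in for the schema entries above
    name: str
    required: bool = False


INTERMEDIATE_INPUTS = {
    "prompt_embeds": Param("prompt_embeds", required=True),
    "negative_prompt_embeds": Param("negative_prompt_embeds"),
}


def missing_required(schema: dict, state: dict) -> list:
    # Names the schema marks required that the pipeline state does not carry yet.
    return [name for name, param in schema.items() if param.required and name not in state]


assert missing_required(INTERMEDIATE_INPUTS, {"negative_prompt_embeds": None}) == ["prompt_embeds"]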