
move methods to blocks

yiyixuxu
2025-04-12 11:46:25 +02:00
parent 9ad1470d48
commit d143851309
3 changed files with 287 additions and 447 deletions
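The pattern applied throughout this commit: methods that previously lived on the pipeline (and resolved models through `self`) move onto the blocks and take an explicit `components` argument, so a block owns its logic while borrowing loaded models from the pipeline. A minimal sketch of the before/after shape (class and method names here are illustrative, not the actual diffusers API):

    # Before: the method lives on the pipeline and reads models off `self`.
    class Pipeline:
        def encode(self, image):
            return self.vae.encode(image)

    # After: the method lives on the block; the pipeline object is passed in
    # as `components`, so the block carries the logic but borrows the models.
    class VaeEncoderBlock:
        def encode(self, components, image):
            return components.vae.encode(image)

        def __call__(self, pipeline, state):
            state.image_latents = self.encode(pipeline, state.image)
            return state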

View File

@@ -346,7 +346,6 @@ class ComponentsManager:
results.update(result)
else:
results[name] = result
logger.info(f"Getting multiple components: {list(results.keys())}")
return results
else:

View File

@@ -170,7 +170,7 @@ class InputParam:
@dataclass
class OutputParam:
name: str
type_hint: Any
type_hint: Any = None
description: str = ""
def __repr__(self):
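With `type_hint` now defaulting to `None`, an output can be declared by name alone. A small usage sketch (parameter values illustrative):

    import torch
    from dataclasses import dataclass
    from typing import Any

    @dataclass
    class OutputParam:
        name: str
        type_hint: Any = None
        description: str = ""

    OutputParam("latents")  # valid now; type_hint was previously required
    OutputParam("latents", type_hint=torch.Tensor, description="initial latents")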
@@ -440,63 +440,31 @@ class PipelineBlock:
desc.extend(f" {line}" for line in desc_lines[1:])
desc = '\n'.join(desc) + '\n'
# Components section
# Components section - focus only on expected components
expected_components = getattr(self, "expected_components", [])
expected_component_names = {comp.name for comp in expected_components} if expected_components else set()
loaded_components = set(self.components.keys())
all_components = sorted(expected_component_names | loaded_components)
expected_components_str_list = []
main_components = []
auxiliary_components = []
for k in all_components:
# Get component spec if available
component_spec = next((comp for comp in expected_components if comp.name == k), None)
for component_spec in expected_components:
component_str = f" - {component_spec.name} ({component_spec.type_hint})"
if k in loaded_components:
component_type = type(self.components[k]).__name__
component_str = f" - {k}={component_type}"
# Add expected type info if available
if component_spec and component_spec.class_name:
expected_type = component_spec.class_name
if isinstance(expected_type, (list, tuple)):
expected_type = expected_type[1] # Get class name from [module, class_name]
if expected_type != component_type:
component_str += f" (expected: {expected_type})"
else:
# Component not loaded but expected
if component_spec:
expected_type = component_spec.class_name
if isinstance(expected_type, (list, tuple)):
expected_type = expected_type[1] # Get class name from [module, class_name]
component_str = f" - {k} (expected: {expected_type})"
# Add repo info if available
if component_spec.default_repo:
repo_info = component_spec.default_repo
if component_spec.subfolder:
repo_info += f", subfolder={component_spec.subfolder}"
component_str += f" [{repo_info}]"
# Add repo info if available
if component_spec.default_repo:
if isinstance(component_spec.default_repo, list) and len(component_spec.default_repo) == 2:
repo_info = component_spec.default_repo[0]
subfolder = component_spec.default_repo[1]
if subfolder:
repo_info += f", subfolder={subfolder}"
else:
component_str = f" - {k}"
repo_info = component_spec.default_repo
component_str += f" [{repo_info}]"
if k in getattr(self, "auxiliary_components", []):
auxiliary_components.append(component_str)
else:
main_components.append(component_str)
expected_components_str_list.append(component_str)
components = "Components:\n" + "\n".join(main_components)
if auxiliary_components:
components += "\n Auxiliaries:\n" + "\n".join(auxiliary_components)
components = "Components:\n" + "\n".join(expected_components_str_list)
# Configs section
expected_configs = set(getattr(self, "expected_configs", []))
loaded_configs = set(self.configs.keys())
all_configs = sorted(expected_configs | loaded_configs)
configs = "Configs:\n" + "\n".join(
f" - {k}={self.configs[k]}" if k in loaded_configs else f" - {k}"
for k in all_configs
)
# Configs section - focus only on expected configs
expected_configs = getattr(self, "expected_configs", [])
configs = "Configs:\n" + "\n".join(f" - {k}" for k in sorted(expected_configs))
# Inputs section
inputs_str = format_inputs_short(self.inputs)
@@ -672,35 +640,6 @@ class AutoPipelineBlocks:
expected_configs.append(config)
return expected_configs
# YiYi TODO: address the case where multiple blocks have the same component/auxiliary/config; give out warning etc
@property
def components(self):
# Combine components from all blocks
components = {}
for block_name, block in self.blocks.items():
for key, value in block.components.items():
# Only update if:
# 1. Key doesn't exist yet in components, OR
# 2. New value is not None
if key not in components or value is not None:
components[key] = value
return components
@property
def auxiliaries(self):
# Combine auxiliaries from all blocks
auxiliaries = {}
for block_name, block in self.blocks.items():
auxiliaries.update(block.auxiliaries)
return auxiliaries
@property
def configs(self):
# Combine configs from all blocks
configs = {}
for block_name, block in self.blocks.items():
configs.update(block.configs)
return configs
@property
def required_inputs(self) -> List[str]:
@@ -855,62 +794,34 @@ class AutoPipelineBlocks:
desc.extend(f" {line}" for line in desc_lines[1:])
desc = '\n'.join(desc) + '\n'
# Components section
# Components section - focus only on expected components
expected_components = getattr(self, "expected_components", [])
expected_component_names = {comp.name for comp in expected_components} if expected_components else set()
loaded_components = set(self.components.keys())
all_components = sorted(expected_component_names | loaded_components)
# Auxiliaries section
auxiliaries_str = " Auxiliaries:\n" + "\n".join(
f" - {k}={type(v).__name__}" for k, v in self.auxiliaries.items()
)
main_components = []
for k in all_components:
# Get component spec if available
component_spec = next((comp for comp in expected_components if comp.name == k), None)
expected_components_str_list = []
for component_spec in expected_components:
if k in loaded_components:
component_type = type(self.components[k]).__name__
component_str = f" - {k}={component_type}"
# Add expected type info if available
if component_spec and component_spec.class_name:
expected_type = component_spec.class_name
if isinstance(expected_type, (list, tuple)):
expected_type = expected_type[1] # Get class name from [module, class_name]
if expected_type != component_type:
component_str += f" (expected: {expected_type})"
else:
# Component not loaded but expected
if component_spec:
expected_type = component_spec.class_name
if isinstance(expected_type, (list, tuple)):
expected_type = expected_type[1] # Get class name from [module, class_name]
component_str = f" - {k} (expected: {expected_type})"
# Add repo info if available
if component_spec.default_repo:
repo_info = component_spec.default_repo
if component_spec.subfolder:
repo_info += f", subfolder={component_spec.subfolder}"
component_str += f" [{repo_info}]"
component_str = f" - {component_spec.name} ({component_spec.type_hint.__name__})"
# Add repo info if available
if component_spec.default_repo:
if isinstance(component_spec.default_repo, list) and len(component_spec.default_repo) == 2:
repo_info = component_spec.default_repo[0]
subfolder = component_spec.default_repo[1]
if subfolder:
repo_info += f", subfolder={subfolder}"
else:
component_str = f" - {k}"
repo_info = component_spec.default_repo
component_str += f" [{repo_info}]"
expected_components_str_list.append(component_str)
main_components.append(component_str)
components_str = " Components:\n" + "\n".join(expected_components_str_list)
components = "Components:\n" + "\n".join(main_components)
# Configs section
expected_configs = set(getattr(self, "expected_configs", []))
loaded_configs = set(self.configs.keys())
all_configs = sorted(expected_configs | loaded_configs)
configs_str = " Configs:\n" + "\n".join(
f" - {k}={v}" if k in loaded_configs else f" - {k}" for k, v in self.configs.items()
)
# Configs section - focus only on expected configs
expected_configs = getattr(self, "expected_configs", [])
configs_str = " Configs:\n" + "\n".join(f" - {config.name}" for config in sorted(expected_configs, key=lambda x: x.name))
# Blocks section
blocks_str = " Blocks:\n"
for i, (name, block) in enumerate(self.blocks.items()):
# Get trigger input for this block
@@ -955,6 +866,7 @@ class AutoPipelineBlocks:
blocks_str += f"{indented_intermediates}\n"
blocks_str += "\n"
# Inputs and outputs section
inputs_str = format_inputs_short(self.inputs)
inputs_str = " Inputs:\n " + inputs_str
outputs = [out.name for out in self.outputs]
@@ -970,7 +882,6 @@ class AutoPipelineBlocks:
f"{header}\n"
f"{desc}"
f"{components_str}\n"
f"{auxiliaries_str}\n"
f"{configs_str}\n"
f"{blocks_str}\n"
f"{inputs_str}\n"
@@ -1037,35 +948,6 @@ class SequentialPipelineBlocks:
blocks[block_name] = block_cls()
self.blocks = blocks
# YiYi TODO: address the case where multiple blocks have the same component/auxiliary/config; give out warning etc
@property
def components(self):
# Combine components from all blocks
components = {}
for block_name, block in self.blocks.items():
for key, value in block.components.items():
# Only update if:
# 1. Key doesn't exist yet in components, OR
# 2. New value is not None
if key not in components or value is not None:
components[key] = value
return components
@property
def auxiliaries(self):
# Combine auxiliaries from all blocks
auxiliaries = {}
for block_name, block in self.blocks.items():
auxiliaries.update(block.auxiliaries)
return auxiliaries
@property
def configs(self):
# Combine configs from all blocks
configs = {}
for block_name, block in self.blocks.items():
configs.update(block.configs)
return configs
@property
def required_inputs(self) -> List[str]:
@@ -1284,63 +1166,34 @@ class SequentialPipelineBlocks:
desc.extend(f" {line}" for line in desc_lines[1:])
desc = '\n'.join(desc) + '\n'
# Components section
# Components section - focus only on expected components
expected_components = getattr(self, "expected_components", [])
expected_component_names = {comp.name for comp in expected_components} if expected_components else set()
loaded_components = set(self.components.keys())
all_components = sorted(expected_component_names | loaded_components)
# Auxiliaries section
auxiliaries_str = " Auxiliaries:\n" + "\n".join(
f" - {k}={type(v).__name__}" for k, v in self.auxiliaries.items()
)
main_components = []
for k in all_components:
# Get component spec if available
component_spec = next((comp for comp in expected_components if comp.name == k), None)
expected_components_str_list = []
for component_spec in expected_components:
if k in loaded_components:
component_type = type(self.components[k]).__name__
component_str = f" - {k}={component_type}"
# Add expected type info if available
if component_spec and component_spec.class_name:
expected_type = component_spec.class_name
if isinstance(expected_type, (list, tuple)):
expected_type = expected_type[1] # Get class name from [module, class_name]
if expected_type != component_type:
component_str += f" (expected: {expected_type})"
else:
# Component not loaded but expected
if component_spec:
expected_type = component_spec.class_name
if isinstance(expected_type, (list, tuple)):
expected_type = expected_type[1] # Get class name from [module, class_name]
component_str = f" - {k} (expected: {expected_type})"
# Add repo info if available
if component_spec.default_repo:
repo_info = component_spec.default_repo
if component_spec.subfolder:
repo_info += f", subfolder={component_spec.subfolder}"
component_str += f" [{repo_info}]"
component_str = f" - {component_spec.name} ({component_spec.type_hint.__name__})"
# Add repo info if available
if component_spec.default_repo:
if isinstance(component_spec.default_repo, list) and len(component_spec.default_repo) == 2:
repo_info = component_spec.default_repo[0]
subfolder = component_spec.default_repo[1]
if subfolder:
repo_info += f", subfolder={subfolder}"
else:
component_str = f" - {k}"
repo_info = component_spec.default_repo
component_str += f" [{repo_info}]"
expected_components_str_list.append(component_str)
main_components.append(component_str)
components_str = " Components:\n" + "\n".join(expected_components_str_list)
components = "Components:\n" + "\n".join(main_components)
# Configs section
expected_configs = set(getattr(self, "expected_configs", []))
loaded_configs = set(self.configs.keys())
all_configs = sorted(expected_configs | loaded_configs)
configs_str = " Configs:\n" + "\n".join(
f" - {k}={self.configs[k]}" if k in loaded_configs else f" - {k}" for k in all_configs
)
# Configs section - focus only on expected configs
expected_configs = getattr(self, "expected_configs", [])
configs_str = " Configs:\n" + "\n".join(f" - {config.name}" for config in sorted(expected_configs, key=lambda x: x.name))
# Blocks section
blocks_str = " Blocks:\n"
for i, (name, block) in enumerate(self.blocks.items()):
# Get trigger input for this block
@@ -1385,6 +1238,7 @@ class SequentialPipelineBlocks:
blocks_str += f"{indented_intermediates}\n"
blocks_str += "\n"
# Inputs and outputs section
inputs_str = format_inputs_short(self.inputs)
inputs_str = " Inputs:\n " + inputs_str
outputs = [out.name for out in self.outputs]
@@ -1400,7 +1254,6 @@ class SequentialPipelineBlocks:
f"{header}\n"
f"{desc}"
f"{components_str}\n"
f"{auxiliaries_str}\n"
f"{configs_str}\n"
f"{blocks_str}\n"
f"{inputs_str}\n"
@@ -1408,6 +1261,7 @@ class SequentialPipelineBlocks:
f")"
)
@property
def doc(self):
return make_doc_string(self.inputs, self.intermediates_inputs, self.outputs, self.description)
@@ -1424,16 +1278,17 @@ class ModularPipeline(ConfigMixin):
def __init__(self, block):
self.pipeline_block = block
# add default components from pipeline_block (e.g. guider)
for key, value in block.components.items():
setattr(self, key, value)
for component_spec in self.expected_components:
if component_spec.obj is not None:
setattr(self, component_spec.name, component_spec.obj)
else:
setattr(self, component_spec.name, None)
default_configs = {}
for config_spec in self.expected_configs:
default_configs[config_spec.name] = config_spec.default
self.register_to_config(**default_configs)
# add default configs from pipeline_block (e.g. force_zeros_for_empty_prompt)
self.register_to_config(**block.configs)
# add default auxiliaries from pipeline_block (e.g. image_processor)
for key, value in block.auxiliaries.items():
setattr(self, key, value)
@classmethod
def from_block(cls, block):
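Initialization is now spec-driven: a ComponentSpec may carry a default instance in `obj`, and a ConfigSpec carries a `default` value that gets registered to the config. A minimal sketch of the idea, with hypothetical stand-ins for the real spec classes:

    from dataclasses import dataclass
    from typing import Any, Optional

    @dataclass
    class ComponentSpec:               # hypothetical stand-in
        name: str
        type_hint: Any = None
        obj: Optional[Any] = None      # ready-to-use default instance, if any

    @dataclass
    class ConfigSpec:                  # hypothetical stand-in
        name: str
        default: Any = None

    class Pipe:
        pass

    pipe = Pipe()
    for spec in [ComponentSpec("guider", obj=object()), ComponentSpec("unet")]:
        # the attribute always exists; None when no default object is given
        setattr(pipe, spec.name, spec.obj)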
@@ -1508,9 +1363,9 @@ class ModularPipeline(ConfigMixin):
@property
def components(self):
components = {}
for name in self.expected_components:
if hasattr(self, name):
components[name] = getattr(self, name)
for component_spec in self.expected_components:
if hasattr(self, component_spec.name):
components[component_spec.name] = getattr(self, component_spec.name)
return components
# Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline.progress_bar
@@ -1596,32 +1451,32 @@ class ModularPipeline(ConfigMixin):
kwargs (dict): Keyword arguments to update the states.
"""
for component_name in self.expected_components:
if component_name in kwargs:
if hasattr(self, component_name) and getattr(self, component_name) is not None:
current_component = getattr(self, component_name)
new_component = kwargs[component_name]
for component in self.expected_components:
if component.name in kwargs:
if hasattr(self, component.name) and getattr(self, component.name) is not None:
current_component = getattr(self, component.name)
new_component = kwargs[component.name]
if not isinstance(new_component, current_component.__class__):
logger.info(
f"Overwriting existing component '{component_name}' "
f"Overwriting existing component '{component.name}' "
f"(type: {current_component.__class__.__name__}) "
f"with type: {new_component.__class__.__name__})"
)
elif isinstance(current_component, torch.nn.Module):
if id(current_component) != id(new_component):
logger.info(
f"Overwriting existing component '{component_name}' "
f"Overwriting existing component '{component.name}' "
f"(type: {type(current_component).__name__}) "
f"with new value (type: {type(new_component).__name__})"
)
setattr(self, component_name, kwargs.pop(component_name))
setattr(self, component.name, kwargs.pop(component.name))
configs_to_add = {}
for config_name in self.expected_configs:
if config_name in kwargs:
configs_to_add[config_name] = kwargs.pop(config_name)
for config in self.expected_configs:
if config.name in kwargs:
configs_to_add[config.name] = kwargs.pop(config.name)
self.register_to_config(**configs_to_add)
@property
@@ -1631,64 +1486,64 @@ class ModularPipeline(ConfigMixin):
params[input_param.name] = input_param.default
return params
def __repr__(self):
output = "ModularPipeline:\n"
output += "==============================\n\n"
# def __repr__(self):
# output = "ModularPipeline:\n"
# output += "==============================\n\n"
block = self.pipeline_block
# block = self.pipeline_block
# List the pipeline block structure first
output += "Pipeline Block:\n"
output += "--------------\n"
if hasattr(block, "blocks"):
output += f"{block.__class__.__name__}\n"
base_class = block.__class__.__bases__[0].__name__
output += f" (Class: {base_class})\n" if base_class != "object" else "\n"
for sub_block_name, sub_block in block.blocks.items():
if hasattr(block, "block_trigger_inputs"):
trigger_input = block.block_to_trigger_map[sub_block_name]
trigger_info = f" [trigger: {trigger_input}]" if trigger_input is not None else " [default]"
output += f"{sub_block_name} ({sub_block.__class__.__name__}){trigger_info}\n"
else:
output += f"{sub_block_name} ({sub_block.__class__.__name__})\n"
else:
output += f"{block.__class__.__name__}\n"
output += "\n"
# # List the pipeline block structure first
# output += "Pipeline Block:\n"
# output += "--------------\n"
# if hasattr(block, "blocks"):
# output += f"{block.__class__.__name__}\n"
# base_class = block.__class__.__bases__[0].__name__
# output += f" (Class: {base_class})\n" if base_class != "object" else "\n"
# for sub_block_name, sub_block in block.blocks.items():
# if hasattr(block, "block_trigger_inputs"):
# trigger_input = block.block_to_trigger_map[sub_block_name]
# trigger_info = f" [trigger: {trigger_input}]" if trigger_input is not None else " [default]"
# output += f" • {sub_block_name} ({sub_block.__class__.__name__}){trigger_info}\n"
# else:
# output += f" • {sub_block_name} ({sub_block.__class__.__name__})\n"
# else:
# output += f"{block.__class__.__name__}\n"
# output += "\n"
# List the components registered in the pipeline
output += "Registered Components:\n"
output += "----------------------\n"
for name, component in self.components.items():
output += f"{name}: {type(component).__name__}"
if hasattr(component, "dtype") and hasattr(component, "device"):
output += f" (dtype={component.dtype}, device={component.device})"
output += "\n"
output += "\n"
# # List the components registered in the pipeline
# output += "Registered Components:\n"
# output += "----------------------\n"
# for name, component in self.components.items():
# output += f"{name}: {type(component).__name__}"
# if hasattr(component, "dtype") and hasattr(component, "device"):
# output += f" (dtype={component.dtype}, device={component.device})"
# output += "\n"
# output += "\n"
# List the configs registered in the pipeline
output += "Registered Configs:\n"
output += "------------------\n"
for name, config in self.config.items():
output += f"{name}: {config!r}\n"
output += "\n"
# # List the configs registered in the pipeline
# output += "Registered Configs:\n"
# output += "------------------\n"
# for name, config in self.config.items():
# output += f"{name}: {config!r}\n"
# output += "\n"
# Add auto blocks section
if hasattr(block, "trigger_inputs") and block.trigger_inputs:
output += "------------------\n"
output += "This pipeline contains blocks that are selected at runtime based on inputs.\n\n"
output += f"Trigger Inputs: {block.trigger_inputs}\n"
# Get first trigger input as example
example_input = next(t for t in block.trigger_inputs if t is not None)
output += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n"
output += "Check `.doc` of returned object for more information.\n\n"
# # Add auto blocks section
# if hasattr(block, "trigger_inputs") and block.trigger_inputs:
# output += "------------------\n"
# output += "This pipeline contains blocks that are selected at runtime based on inputs.\n\n"
# output += f"Trigger Inputs: {block.trigger_inputs}\n"
# # Get first trigger input as example
# example_input = next(t for t in block.trigger_inputs if t is not None)
# output += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n"
# output += "Check `.doc` of returned object for more information.\n\n"
# List the call parameters
full_doc = self.pipeline_block.doc
if "------------------------" in full_doc:
full_doc = full_doc.split("------------------------")[0].rstrip()
output += full_doc
# # List the call parameters
# full_doc = self.pipeline_block.doc
# if "------------------------" in full_doc:
# full_doc = full_doc.split("------------------------")[0].rstrip()
# output += full_doc
return output
# return output
# YiYi TODO: try to unify the to method with the one in DiffusionPipeline
# Modified from diffusers.pipelines.pipeline_utils.DiffusionPipeline.to

View File

@@ -22,7 +22,7 @@ from collections import OrderedDict
from ...guider import CFGGuider
from ...image_processor import VaeImageProcessor, PipelineImageInput
from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin
from ...models import ControlNetModel, ImageProjection
from ...models import ControlNetModel, ImageProjection, UNet2DConditionModel, AutoencoderKL, ControlNetUnionModel
from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
from ...models.lora import adjust_lora_scale_text_encoder
from ...utils import (
@@ -211,7 +211,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock, ModularIPAdapterMixin):
]
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self -> components
def encode_image(components, image, device, num_images_per_prompt, output_hidden_states=None):
def encode_image(self, components, image, device, num_images_per_prompt, output_hidden_states=None):
dtype = next(components.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
@@ -237,7 +237,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock, ModularIPAdapterMixin):
# modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
image_embeds = []
if do_classifier_free_guidance:
@@ -288,7 +288,8 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock, ModularIPAdapterMixin):
data.do_classifier_free_guidance = data.guidance_scale > 1.0
data.device = pipeline._execution_device
data.ip_adapter_embeds = pipeline.prepare_ip_adapter_image_embeds(
data.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds(
pipeline,
ip_adapter_image=data.ip_adapter_image,
ip_adapter_image_embeds=None,
device=data.device,
@@ -358,8 +359,9 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
elif data.prompt_2 is not None and (not isinstance(data.prompt_2, str) and not isinstance(data.prompt_2, list)):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(data.prompt_2)}")
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt with self -> components
def encode_prompt(
self,
components,
prompt: str,
prompt_2: Optional[str] = None,
@@ -496,7 +498,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
# get unconditional embeddings for classifier free guidance
zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
zero_out_negative_prompt = negative_prompt is None and components.config.force_zeros_for_empty_prompt
if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
@@ -614,6 +616,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
data.pooled_prompt_embeds,
data.negative_pooled_prompt_embeds,
) = self.encode_prompt(
pipeline,
data.prompt,
data.prompt_2,
data.device,
@@ -670,40 +673,40 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
def intermediates_outputs(self) -> List[OutputParam]:
return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation")]
# Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image
# Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
# YiYi TODO: update the _encode_vae_image so that we can use # Copied from
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator):
latents_mean = latents_std = None
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)
dtype = image.dtype
if self.vae.config.force_upcast:
if components.vae.config.force_upcast:
image = image.float()
self.vae.to(dtype=torch.float32)
components.vae.to(dtype=torch.float32)
if isinstance(generator, list):
image_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
image_latents = retrieve_latents(components.vae.encode(image), generator=generator)
if self.vae.config.force_upcast:
self.vae.to(dtype)
if components.vae.config.force_upcast:
components.vae.to(dtype)
image_latents = image_latents.to(dtype)
if latents_mean is not None and latents_std is not None:
latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype)
latents_std = latents_std.to(device=image_latents.device, dtype=dtype)
image_latents = (image_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std
else:
image_latents = self.vae.config.scaling_factor * image_latents
image_latents = components.vae.config.scaling_factor * image_latents
return image_latents
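The normalization at the end mirrors how these VAEs map encoder outputs into the scheduler's latent space: when the VAE config provides `latents_mean`/`latents_std`, the latents are standardized and then scaled; otherwise only `scaling_factor` is applied. A worked sketch with placeholder statistics (0.13025 is the usual SDXL VAE scaling factor):

    import torch

    scaling_factor = 0.13025
    z = torch.randn(1, 4, 64, 64)            # raw VAE encoder sample
    latents_mean = torch.zeros(1, 4, 1, 1)   # placeholder statistics
    latents_std = torch.ones(1, 4, 1, 1)

    # with mean/std in the config: standardize, then scale
    latents = (z - latents_mean) * scaling_factor / latents_std
    # without them: plain scaling
    latents_plain = scaling_factor * z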
@@ -729,7 +732,7 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
)
data.image_latents = self._encode_vae_image(image=data.image, generator=data.generator)
data.image_latents = self._encode_vae_image(pipeline, image=data.image, generator=data.generator)
self.add_block_state(state, data)
@@ -776,32 +779,32 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting process (only for inpainting-specifid unet)"),
OutputParam("crops_coords", type_hint=Optional[Tuple[int, int]], description="The crop coordinates to use for the preprocess/postprocess of the image and mask")]
# Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image
# Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
# YiYi TODO: update the _encode_vae_image so that we can use # Copied from
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator):
latents_mean = latents_std = None
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)
dtype = image.dtype
if self.vae.config.force_upcast:
if components.vae.config.force_upcast:
image = image.float()
self.vae.to(dtype=torch.float32)
components.vae.to(dtype=torch.float32)
if isinstance(generator, list):
image_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
image_latents = retrieve_latents(components.vae.encode(image), generator=generator)
if self.vae.config.force_upcast:
self.vae.to(dtype)
if components.vae.config.force_upcast:
components.vae.to(dtype)
image_latents = image_latents.to(dtype)
if latents_mean is not None and latents_std is not None:
@@ -809,20 +812,20 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
latents_std = latents_std.to(device=image_latents.device, dtype=dtype)
image_latents = (image_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
else:
image_latents = self.vae.config.scaling_factor * image_latents
image_latents = components.vae.config.scaling_factor * image_latents
return image_latents
# modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents
# do not accept do_classifier_free_guidance
def prepare_mask_latents(
self, mask, masked_image, batch_size, height, width, dtype, device, generator
self, components, mask, masked_image, batch_size, height, width, dtype, device, generator
):
# resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
mask = torch.nn.functional.interpolate(
mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor)
)
mask = mask.to(device=device, dtype=dtype)
@@ -844,7 +847,7 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
if masked_image is not None:
if masked_image_latents is None:
masked_image = masked_image.to(device=device, dtype=dtype)
masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator)
if masked_image_latents.shape[0] < batch_size:
if not batch_size % masked_image_latents.shape[0] == 0:
@@ -887,10 +890,11 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
data.batch_size = data.image.shape[0]
data.image = data.image.to(device=data.device, dtype=data.dtype)
data.image_latents = self._encode_vae_image(image=data.image, generator=data.generator)
data.image_latents = self._encode_vae_image(pipeline, image=data.image, generator=data.generator)
# 7. Prepare mask latent variables
data.mask, data.masked_image_latents = self.prepare_mask_latents(
pipeline,
data.mask,
data.masked_image,
data.batch_size,
@@ -1067,16 +1071,16 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
OutputParam("latent_timestep", type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image generation")
]
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps with self -> components
def get_timesteps(self, components, num_inference_steps, strength, device, denoising_start=None):
# get the original timestep using init_timestep
if denoising_start is None:
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
timesteps = components.scheduler.timesteps[t_start * components.scheduler.order :]
if hasattr(components.scheduler, "set_begin_index"):
components.scheduler.set_begin_index(t_start * components.scheduler.order)
return timesteps, num_inference_steps - t_start
@@ -1085,13 +1089,13 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
# that is, strength is determined by the denoising_start instead.
discrete_timestep_cutoff = int(
round(
self.scheduler.config.num_train_timesteps
- (denoising_start * self.scheduler.config.num_train_timesteps)
components.scheduler.config.num_train_timesteps
- (denoising_start * components.scheduler.config.num_train_timesteps)
)
)
num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item()
if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
num_inference_steps = (components.scheduler.timesteps < discrete_timestep_cutoff).sum().item()
if components.scheduler.order == 2 and num_inference_steps % 2 == 0:
# if the scheduler is a 2nd order scheduler we might have to do +1
# because `num_inference_steps` might be even given that every timestep
# (except the highest one) is duplicated. If `num_inference_steps` is even it would
@@ -1101,10 +1105,10 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
num_inference_steps = num_inference_steps + 1
# because t_n+1 >= t_n, we slice the timesteps starting from the end
t_start = len(self.scheduler.timesteps) - num_inference_steps
timesteps = self.scheduler.timesteps[t_start:]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start)
t_start = len(components.scheduler.timesteps) - num_inference_steps
timesteps = components.scheduler.timesteps[t_start:]
if hasattr(components.scheduler, "set_begin_index"):
components.scheduler.set_begin_index(t_start)
return timesteps, num_inference_steps
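For intuition, the strength branch with `num_inference_steps=50` and `strength=0.3`:

    num_inference_steps, strength = 50, 0.3
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 15
    t_start = max(num_inference_steps - init_timestep, 0)                          # 35
    # scheduler.timesteps[t_start * order:] keeps only the final 15 steps,
    # and the method returns num_inference_steps - t_start == 15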
@@ -1123,6 +1127,7 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
return isinstance(dnv, float) and 0 < dnv < 1
data.timesteps, data.num_inference_steps = self.get_timesteps(
pipeline,
data.num_inference_steps,
data.strength,
data.device,
@@ -1281,9 +1286,10 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting generation (only for inpainting-specific unet)"),
OutputParam("noise", type_hint=torch.Tensor, description="The noise added to the image latents, used for inpainting generation")]
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents with self -> components
def prepare_latents_inpaint(
self,
components,
batch_size,
num_channels_latents,
height,
@@ -1302,8 +1308,8 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
int(height) // components.vae_scale_factor,
int(width) // components.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
@@ -1322,18 +1328,18 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
elif return_image_latents or (latents is None and not is_strength_max):
image = image.to(device=device, dtype=dtype)
image_latents = self._encode_vae_image(image=image, generator=generator)
image_latents = self._encode_vae_image(components, image=image, generator=generator)
image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
if latents is None and add_noise:
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
# if strength is 1. then initialise the latents to noise, else initial to image + noise
latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
latents = noise if is_strength_max else components.scheduler.add_noise(image_latents, noise, timestep)
# if pure noise then scale the initial latents by the Scheduler's init sigma
latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
latents = latents * components.scheduler.init_noise_sigma if is_strength_max else latents
elif add_noise:
noise = latents.to(device)
latents = noise * self.scheduler.init_noise_sigma
latents = noise * components.scheduler.init_noise_sigma
else:
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = image_latents.to(device)
@@ -1351,13 +1357,13 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
# modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents
# do not accept do_classifier_free_guidance
def prepare_mask_latents(
self, mask, masked_image, batch_size, height, width, dtype, device, generator
self, components, mask, masked_image, batch_size, height, width, dtype, device, generator
):
# resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
mask = torch.nn.functional.interpolate(
mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor)
)
mask = mask.to(device=device, dtype=dtype)
@@ -1379,7 +1385,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
if masked_image is not None:
if masked_image_latents is None:
masked_image = masked_image.to(device=device, dtype=dtype)
masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator)
if masked_image_latents.shape[0] < batch_size:
if not batch_size % masked_image_latents.shape[0] == 0:
@@ -1418,6 +1424,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
data.width = data.image_latents.shape[-1] * pipeline.vae_scale_factor
data.latents, data.noise = self.prepare_latents_inpaint(
pipeline,
data.batch_size * data.num_images_per_prompt,
pipeline.num_channels_latents,
data.height,
@@ -1436,6 +1443,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
# 7. Prepare mask latent variables
data.mask, data.masked_image_latents = self.prepare_mask_latents(
pipeline,
data.mask,
data.masked_image_latents,
data.batch_size * data.num_images_per_prompt,
@@ -1488,10 +1496,10 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
def intermediates_outputs(self) -> List[OutputParam]:
return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process")]
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents with self -> components
# YiYi TODO: refactor using _encode_vae_image
def prepare_latents_img2img(
self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True
self, components, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True
):
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
raise ValueError(
@@ -1499,8 +1507,8 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
)
# Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu")
if hasattr(components, "final_offload_hook") and components.final_offload_hook is not None:
components.text_encoder_2.to("cpu")
torch.cuda.empty_cache()
image = image.to(device=device, dtype=dtype)
@@ -1512,14 +1520,14 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
else:
latents_mean = latents_std = None
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)
# make sure the VAE is in float32 mode, as it overflows in float16
if self.vae.config.force_upcast:
if components.vae.config.force_upcast:
image = image.float()
self.vae.to(dtype=torch.float32)
components.vae.to(dtype=torch.float32)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
@@ -1536,23 +1544,23 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
)
init_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(batch_size)
]
init_latents = torch.cat(init_latents, dim=0)
else:
init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
init_latents = retrieve_latents(components.vae.encode(image), generator=generator)
if self.vae.config.force_upcast:
self.vae.to(dtype)
if components.vae.config.force_upcast:
components.vae.to(dtype)
init_latents = init_latents.to(dtype)
if latents_mean is not None and latents_std is not None:
latents_mean = latents_mean.to(device=device, dtype=dtype)
latents_std = latents_std.to(device=device, dtype=dtype)
init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
init_latents = (init_latents - latents_mean) * components.vae.config.scaling_factor / latents_std
else:
init_latents = self.vae.config.scaling_factor * init_latents
init_latents = components.vae.config.scaling_factor * init_latents
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
# expand init_latents for batch_size
@@ -1569,7 +1577,7 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
shape = init_latents.shape
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
# get latents
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
init_latents = components.scheduler.add_noise(init_latents, noise, timestep)
latents = init_latents
@@ -1584,6 +1592,7 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
data.add_noise = True if data.denoising_start is None else False
if data.latents is None:
data.latents = self.prepare_latents_img2img(
pipeline,
data.image_latents,
data.latent_timestep,
data.batch_size,
@@ -1663,13 +1672,13 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
f"`height` and `width` have to be divisible by {pipeline.vae_scale_factor} but are {data.height} and {data.width}."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with self -> components
def prepare_latents(self, components, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
int(height) // components.vae_scale_factor,
int(width) // components.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
@@ -1683,7 +1692,7 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
latents = latents * components.scheduler.init_noise_sigma
return latents
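For reference, the shape arithmetic for a default SDXL generation: `vae_scale_factor` is 8, so 1024x1024 pixels map to 128x128 latents with 4 channels (values below are the standard SDXL defaults; `init_noise_sigma` depends on the scheduler):

    import torch

    batch_size, num_channels_latents, vae_scale_factor = 1, 4, 8
    height = width = 1024
    shape = (batch_size, num_channels_latents,
             height // vae_scale_factor, width // vae_scale_factor)
    latents = torch.randn(shape)   # (1, 4, 128, 128)
    init_noise_sigma = 1.0         # scheduler-dependent scale
    latents = latents * init_noise_sigma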
@@ -1702,6 +1711,7 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
data.width = data.width or pipeline.default_sample_size * pipeline.vae_scale_factor
data.num_channels_latents = pipeline.num_channels_latents
data.latents = self.prepare_latents(
pipeline,
data.batch_size * data.num_images_per_prompt,
data.num_channels_latents,
data.height,
@@ -1762,6 +1772,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids with self -> components
def _get_add_time_ids_img2img(
self,
components,
original_size,
crops_coords_top_left,
@@ -1864,7 +1875,8 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
if data.negative_target_size is None:
data.negative_target_size = data.target_size
data.add_time_ids, data.negative_add_time_ids = pipeline._get_add_time_ids_img2img(
data.add_time_ids, data.negative_add_time_ids = self._get_add_time_ids_img2img(
pipeline,
data.original_size,
data.crops_coords_top_left,
data.target_size,
@@ -1946,57 +1958,24 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="The negative time ids to condition the denoising process"),
OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")]
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids with self -> components
def _get_add_time_ids_img2img(
components,
original_size,
crops_coords_top_left,
target_size,
aesthetic_score,
negative_aesthetic_score,
negative_original_size,
negative_crops_coords_top_left,
negative_target_size,
dtype,
text_encoder_projection_dim=None,
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids with self -> components
def _get_add_time_ids(
self, components, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
):
if components.config.requires_aesthetics_score:
add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
add_neg_time_ids = list(
negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
)
else:
add_time_ids = list(original_size + crops_coords_top_left + target_size)
add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)
add_time_ids = list(original_size + crops_coords_top_left + target_size)
passed_add_embed_dim = (
components.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
)
expected_add_embed_dim = components.unet.add_embedding.linear_1.in_features
if (
expected_add_embed_dim > passed_add_embed_dim
and (expected_add_embed_dim - passed_add_embed_dim) == components.unet.config.addition_time_embed_dim
):
raise ValueError(
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
)
elif (
expected_add_embed_dim < passed_add_embed_dim
and (passed_add_embed_dim - expected_add_embed_dim) == components.unet.config.addition_time_embed_dim
):
raise ValueError(
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
)
elif expected_add_embed_dim != passed_add_embed_dim:
if expected_add_embed_dim != passed_add_embed_dim:
raise ValueError(
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
)
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
return add_time_ids, add_neg_time_ids
return add_time_ids
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(
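The dimension check can be sanity-checked against the standard SDXL UNet: `addition_time_embed_dim` is 256, the six time-id entries (original_size, crops_coords_top_left, target_size; two ints each) give 256 * 6 = 1536, and adding `text_encoder_2`'s projection dim of 1280 yields 2816, which matches `unet.add_embedding.linear_1.in_features`:

    addition_time_embed_dim = 256        # standard SDXL UNet config
    text_encoder_projection_dim = 1280   # standard SDXL text_encoder_2
    add_time_ids = list((1024, 1024) + (0, 0) + (1024, 1024))  # 6 entries
    passed_add_embed_dim = (addition_time_embed_dim * len(add_time_ids)
                            + text_encoder_projection_dim)
    assert passed_add_embed_dim == 2816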
@@ -2043,7 +2022,8 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
data.text_encoder_projection_dim = int(data.pooled_prompt_embeds.shape[-1])
data.add_time_ids = pipeline._get_add_time_ids(
data.add_time_ids = self._get_add_time_ids(
pipeline,
data.original_size,
data.crops_coords_top_left,
data.target_size,
@@ -2051,7 +2031,8 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
text_encoder_projection_dim=data.text_encoder_projection_dim,
)
if data.negative_original_size is not None and data.negative_target_size is not None:
data.negative_add_time_ids = pipeline._get_add_time_ids(
data.negative_add_time_ids = self._get_add_time_ids(
pipeline,
data.negative_original_size,
data.negative_crops_coords_top_left,
data.negative_target_size,
@@ -2087,7 +2068,7 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("guider", CFGGuider),
ComponentSpec("guider", CFGGuider, obj=CFGGuider()),
ComponentSpec("scheduler", KarrasDiffusionSchedulers),
ComponentSpec("unet", UNet2DConditionModel),
]
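Passing `obj=CFGGuider()` gives the spec a ready default instance, so the guider is usable out of the box while `scheduler` and `unet` still need to be loaded; this pairs with the new `ModularPipeline.__init__`, which does `setattr(self, spec.name, spec.obj)` and falls back to `None`. Reusing the names from the hunk above:

    # a spec with `obj` resolves to a usable default; one without resolves to None
    guider_spec = ComponentSpec("guider", CFGGuider, obj=CFGGuider())
    unet_spec = ComponentSpec("unet", UNet2DConditionModel)   # loaded later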
@@ -2231,20 +2212,20 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
" `pipeline.unet` or your `mask_image` or `image` input."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components
def prepare_extra_step_kwargs(self, components, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
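The signature inspection generalizes to any scheduler: only kwargs that the scheduler's `step` actually accepts are forwarded. A standalone sketch of the idiom with a toy step function:

    import inspect

    def step(model_output, timestep, sample, eta=0.0):  # toy scheduler step
        return sample

    params = set(inspect.signature(step).parameters)
    extra_step_kwargs = {}
    if "eta" in params:            # accepted here
        extra_step_kwargs["eta"] = 0.0
    if "generator" in params:      # not accepted here, so skipped
        extra_step_kwargs["generator"] = None
    # extra_step_kwargs == {"eta": 0.0}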
@@ -2297,7 +2278,7 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
data.added_cond_kwargs["image_embeds"] = data.ip_adapter_embeds
# Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
data.extra_step_kwargs = self.prepare_extra_step_kwargs(data.generator, data.eta)
data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta)
data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0)
with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar:
@@ -2360,12 +2341,12 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("guider", CFGGuider),
ComponentSpec("guider", CFGGuider, obj=CFGGuider()),
ComponentSpec("scheduler", KarrasDiffusionSchedulers),
ComponentSpec("unet", UNet2DConditionModel),
ComponentSpec("controlnet", ControlNetModel),
ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)),
ComponentSpec("controlnet_guider", CFGGuider),
ComponentSpec("controlnet_guider", CFGGuider, obj=CFGGuider()),
]
@property
@@ -2519,6 +2500,7 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
# 2. add crops_coords and resize_mode to preprocess()
def prepare_control_image(
self,
components,
image,
width,
height,
@@ -2529,9 +2511,9 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
crops_coords=None,
):
if crops_coords is not None:
image = self.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32)
image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32)
else:
image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
image_batch_size = image.shape[0]
if image_batch_size == 1:
@@ -2546,20 +2528,20 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components
def prepare_extra_step_kwargs(self, components, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
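The recurring self -> components rewrite in these hunks is the same refactor throughout: helpers stop reading loaded models off the block instance and instead receive a shared components object, so one block can serve any pipeline that supplies its own scheduler, unet, or vae. In miniature, with stand-in classes:

class StandInScheduler:
    def scale_model_input(self, sample, timestep):
        return sample  # identity, for illustration only

class Components:
    # stand-in for the pipeline object each block helper now receives
    def __init__(self, scheduler):
        self.scheduler = scheduler

class DenoiseBlock:
    # before: helpers read self.scheduler, tying the block to loaded state
    # after: helpers take `components`, so the block itself stays stateless
    def scale_input(self, components, sample, timestep):
        return components.scheduler.scale_model_input(sample, timestep)

block = DenoiseBlock()
out = block.scale_input(Components(StandInScheduler()), sample=[1.0], timestep=0)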
@@ -2616,6 +2598,7 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
# control_image
if isinstance(controlnet, ControlNetModel):
data.control_image = self.prepare_control_image(
pipeline,
image=data.control_image,
width=data.width,
height=data.height,
@@ -2630,6 +2613,7 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
for control_image_ in data.control_image:
control_image = self.prepare_control_image(
pipeline,
image=control_image_,
width=data.width,
height=data.height,
@@ -2712,7 +2696,7 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
data.control_image = pipeline.controlnet_guider.prepare_input(data.control_image, data.control_image)
# (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
data.extra_step_kwargs = self.prepare_extra_step_kwargs(data.generator, data.eta)
data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta)
data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0)
# (5) Denoise loop
@@ -2808,8 +2792,8 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
ComponentSpec("unet", UNet2DConditionModel),
ComponentSpec("controlnet", ControlNetUnionModel),
ComponentSpec("scheduler", KarrasDiffusionSchedulers),
ComponentSpec("guider", CFGGuider),
ComponentSpec("controlnet_guider", CFGGuider),
ComponentSpec("guider", CFGGuider, obj=CFGGuider()),
ComponentSpec("controlnet_guider", CFGGuider, obj=CFGGuider()),
ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)),
]
@@ -2965,6 +2949,7 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
# 2. add crops_coords and resize_mode to preprocess()
def prepare_control_image(
self,
components,
image,
width,
height,
@@ -2975,9 +2960,9 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
crops_coords=None,
):
if crops_coords is not None:
image = self.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32)
image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32)
else:
image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
image_batch_size = image.shape[0]
if image_batch_size == 1:
@@ -2992,20 +2977,20 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components
def prepare_extra_step_kwargs(self, components, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
@@ -3062,6 +3047,7 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
# prepare control_image
for idx, _ in enumerate(data.control_image):
data.control_image[idx] = self.prepare_control_image(
pipeline,
image=data.control_image[idx],
width=data.width,
height=data.height,
@@ -3149,7 +3135,7 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
data.control_type = pipeline.controlnet_guider.prepare_input(data.control_type, data.control_type)
# (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
data.extra_step_kwargs = self.prepare_extra_step_kwargs(data.generator, data.eta)
data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta)
data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0)
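The warm-up count compensates for higher-order schedulers, which emit roughly order-times as many timesteps as inference steps; the progress bar skips that surplus before its first update. This reading is an interpretation of the formula, illustrated with hypothetical numbers:

def num_warmup_steps(num_timesteps: int, num_inference_steps: int, scheduler_order: int) -> int:
    return max(num_timesteps - num_inference_steps * scheduler_order, 0)

# first-order scheduler: one timestep per step, nothing to skip
print(num_warmup_steps(20, 20, 1))  # 0
# second-order scheduler (hypothetical 41 timesteps for 20 steps): skip the surplus
print(num_warmup_steps(41, 20, 2))  # 1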
@@ -3266,12 +3252,12 @@ class StableDiffusionXLDecodeLatentsStep(PipelineBlock):
def intermediates_outputs(self) -> List[OutputParam]:
return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]], description="The generated images, which can be PIL.Image.Image, torch.Tensor, or numpy array objects")]
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
def upcast_vae(self):
dtype = self.vae.dtype
self.vae.to(dtype=torch.float32)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae with self -> components
def upcast_vae(self, components):
dtype = components.vae.dtype
components.vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = isinstance(
self.vae.decoder.mid_block.attentions[0].processor,
components.vae.decoder.mid_block.attentions[0].processor,
(
AttnProcessor2_0,
XFormersAttnProcessor,
@@ -3280,9 +3266,9 @@ class StableDiffusionXLDecodeLatentsStep(PipelineBlock):
# if xformers or torch_2_0 is used attention block does not need
# to be in float32 which can save lots of memory
if use_torch_2_0_or_xformers:
self.vae.post_quant_conv.to(dtype)
self.vae.decoder.conv_in.to(dtype)
self.vae.decoder.mid_block.to(dtype)
components.vae.post_quant_conv.to(dtype)
components.vae.decoder.conv_in.to(dtype)
components.vae.decoder.mid_block.to(dtype)
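upcast_vae widens the VAE to float32 so fp16 decoding does not overflow, then moves the listed submodules back to the original dtype when a memory-efficient attention processor makes fp32 attention unnecessary. The dtype round-trip, sketched with a toy module rather than the real AutoencoderKL:

import torch
import torch.nn as nn

class TinyVae(nn.Module):
    # toy stand-in exposing the kind of submodules upcast_vae touches
    def __init__(self):
        super().__init__()
        self.post_quant_conv = nn.Conv2d(4, 4, 1)
        self.decoder_conv_in = nn.Conv2d(4, 4, 1)

def upcast_for_decode(vae: TinyVae, memory_efficient_attention: bool) -> TinyVae:
    dtype = next(vae.parameters()).dtype
    vae.to(dtype=torch.float32)  # decode in fp32 to avoid fp16 overflow
    if memory_efficient_attention:
        # attention no longer needs fp32, so return these to the cheaper dtype
        vae.post_quant_conv.to(dtype)
        vae.decoder_conv_in.to(dtype)
    return vae

vae = upcast_for_decode(TinyVae().to(torch.float16), memory_efficient_attention=True)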
@torch.no_grad()
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
@@ -3293,7 +3279,7 @@ class StableDiffusionXLDecodeLatentsStep(PipelineBlock):
data.needs_upcasting = pipeline.vae.dtype == torch.float16 and pipeline.vae.config.force_upcast
if data.needs_upcasting:
self.upcast_vae()
self.upcast_vae(pipeline)
data.latents = data.latents.to(next(iter(pipeline.vae.post_quant_conv.parameters())).dtype)
elif data.latents.dtype != pipeline.vae.dtype:
if torch.backends.mps.is_available():
@@ -3672,7 +3658,7 @@ class StableDiffusionXLModularPipeline(
# YiYi Notes: not used yet; maintains a set of schemas that can be shared across all pipeline blocks
sdxl_inputs_map = {
SDXL_INPUTS_SCHEMA = {
"prompt": InputParam("prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"),
"prompt_2": InputParam("prompt_2", type_hint=Union[str, List[str]], description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2"),
"negative_prompt": InputParam("negative_prompt", type_hint=Union[str, List[str]], description="The prompt or prompts not to guide the image generation"),
@@ -3718,7 +3704,7 @@ sdxl_inputs_map = {
}
sdxl_intermediate_inputs_map = {
SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
"prompt_embeds": InputParam("prompt_embeds", type_hint=torch.Tensor, required=True, description="Text embeddings used to guide image generation"),
"negative_prompt_embeds": InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"),
"pooled_prompt_embeds": InputParam("pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"),
@@ -3744,7 +3730,7 @@ sdxl_intermediate_inputs_map = {
}
sdxl_intermediate_outputs_map = {
SDXL_INTERMEDIATE_OUTPUTS_SCHEMA = {
"prompt_embeds": OutputParam("prompt_embeds", type_hint=torch.Tensor, description="Text embeddings used to guide image generation"),
"negative_prompt_embeds": OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"),
"pooled_prompt_embeds": OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="Pooled text embeddings"),
@@ -3769,6 +3755,6 @@ sdxl_intermediate_outputs_map = {
}
sdxl_outputs_map = {
SDXL_OUTPUTS_SCHEMA = {
"images": OutputParam("images", type_hint=Union[Tuple[Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]]], StableDiffusionXLPipelineOutput], description="The final generated images")
}
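Renaming these maps to upper-case constants marks them as shared, read-only schema tables: blocks can declare their inputs and outputs by key instead of re-typing each parameter. Roughly, with a simplified stand-in for InputParam:

from dataclasses import dataclass
from typing import Any, List

@dataclass
class MiniInputParam:
    # simplified stand-in for InputParam
    name: str
    type_hint: Any = None
    required: bool = False
    description: str = ""

MINI_INPUTS_SCHEMA = {
    "prompt": MiniInputParam("prompt", str, description="The prompt to guide generation"),
    "eta": MiniInputParam("eta", float, description="DDIM eta, in [0, 1]"),
}

def declared_inputs(names: List[str]) -> List[MiniInputParam]:
    # a block picks its inputs out of the shared table by key
    return [MINI_INPUTS_SCHEMA[n] for n in names]

print([p.name for p in declared_inputs(["prompt", "eta"])])  # ['prompt', 'eta']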