diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py
index 1733ad6d4e..92cb50a8b4 100644
--- a/src/diffusers/pipelines/modular_pipeline.py
+++ b/src/diffusers/pipelines/modular_pipeline.py
@@ -184,6 +184,23 @@ class BlockState:
         for key, value in kwargs.items():
             setattr(self, key, value)
 
+    def __getitem__(self, key: str):
+        # allows block_state["foo"]
+        return getattr(self, key, None)
+
+    def __setitem__(self, key: str, value: Any):
+        # allows block_state["foo"] = "bar"
+        setattr(self, key, value)
+
+    def as_dict(self):
+        """
+        Convert BlockState to a dictionary.
+
+        Returns:
+            Dict[str, Any]: Dictionary containing all attributes of the BlockState
+        """
+        return {key: value for key, value in self.__dict__.items()}
+
     def __repr__(self):
         def format_value(v):
             # Handle tensors directly
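
A quick sketch of what the new mapping-style access enables (illustrative usage only, not part of the patch; `BlockState.__init__` stores its kwargs as attributes, as the context lines above show):

    import torch

    state = BlockState(guidance_scale=7.5)
    state["latents"] = torch.zeros(1, 4, 64, 64)  # __setitem__ forwards to setattr
    scale = state["guidance_scale"]               # __getitem__ forwards to getattr, returns 7.5
    missing = state["not_set"]                    # returns None rather than raising KeyError
    snapshot = state.as_dict()                    # plain dict of every attribute currently set

Note that `__getitem__` deliberately returns None for unknown keys (the explicit `getattr` default) instead of raising, which differs from normal dict semantics.
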
@@ -523,8 +540,12 @@ def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> Li
 
     for block_name, inputs in named_input_lists:
         for input_param in inputs:
-            if input_param.name in combined_dict:
-                current_param = combined_dict[input_param.name]
+            if input_param.name is None and input_param.kwargs_type is not None:
+                input_name = "*_" + input_param.kwargs_type
+            else:
+                input_name = input_param.name
+            if input_name in combined_dict:
+                current_param = combined_dict[input_name]
                 if (current_param.default is not None and
                     input_param.default is not None and
                     current_param.default != input_param.default):
@@ -557,7 +578,7 @@ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) ->
 
     for block_name, outputs in named_output_lists:
         for output_param in outputs:
-            if output_param.name not in combined_dict:
+            if (output_param.name not in combined_dict) or (combined_dict[output_param.name].kwargs_type is None and output_param.kwargs_type is not None):
                 combined_dict[output_param.name] = output_param
 
     return list(combined_dict.values())
@@ -919,6 +940,9 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
     # YiYi TODO: add test for this
     @property
     def inputs(self) -> List[Tuple[str, Any]]:
+        return self.get_inputs()
+
+    def get_inputs(self):
         named_inputs = [(name, block.inputs) for name, block in self.blocks.items()]
         combined_inputs = combine_inputs(*named_inputs)
         # mark an input as required only if it is required by any of the blocks
@@ -931,6 +955,9 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
 
     @property
     def intermediates_inputs(self) -> List[str]:
+        return self.get_intermediates_inputs()
+
+    def get_intermediates_inputs(self):
         inputs = []
         outputs = set()
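
A short illustration of the dedup key introduced in `combine_inputs` above (hypothetical block names; `InputParam` objects with `name=None` but a `kwargs_type` act as wildcards, a pattern visible in the controlnet denoise step further down):

    # A wildcard input that collects everything tagged with a kwargs_type.
    # Its name is None, so combine_inputs keys it as "*_controlnet_kwargs"
    # instead of colliding with other unnamed inputs under the key None.
    wildcard = InputParam(kwargs_type="controlnet_kwargs")
    latents = InputParam("latents")

    # Two blocks declaring the same wildcard collapse into a single entry:
    combined = combine_inputs(
        ("block_a", [wildcard, latents]),
        ("block_b", [wildcard]),
    )

`combine_outputs` gets the matching tweak: a duplicate output that carries a `kwargs_type` now replaces an earlier entry that lacks one, so the tag survives combination.
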
@@ -1169,7 +1196,262 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
             expected_configs=self.expected_configs
         )
 
+# YiYi TODO: __repr__
+class LoopSequentialPipelineBlocks(ModularPipelineMixin):
+    """
+    A class that composes multiple pipeline block classes into a for loop. Each call to `loop_step` runs the
+    blocks in sequence; subclasses implement `__call__` to drive the loop.
+    """
+    model_name = None
+    block_classes = []
+    block_names = []
+
+    @property
+    def description(self) -> str:
+        """Description of the block. Must be implemented by subclasses."""
+        raise NotImplementedError("description method must be implemented in subclasses")
+
+    @property
+    def loop_expected_components(self) -> List[ComponentSpec]:
+        return []
+
+    @property
+    def loop_expected_configs(self) -> List[ConfigSpec]:
+        return []
+
+    @property
+    def loop_inputs(self) -> List[InputParam]:
+        """List of input parameters. Must be implemented by subclasses."""
+        return []
+
+    @property
+    def loop_intermediates_inputs(self) -> List[InputParam]:
+        """List of intermediate input parameters. Must be implemented by subclasses."""
+        return []
+
+    @property
+    def loop_intermediates_outputs(self) -> List[OutputParam]:
+        """List of intermediate output parameters. Must be implemented by subclasses."""
+        return []
+
+
+    @property
+    def loop_required_inputs(self) -> List[str]:
+        input_names = []
+        for input_param in self.loop_inputs:
+            if input_param.required:
+                input_names.append(input_param.name)
+        return input_names
+
+    @property
+    def loop_required_intermediates_inputs(self) -> List[str]:
+        input_names = []
+        for input_param in self.loop_intermediates_inputs:
+            if input_param.required:
+                input_names.append(input_param.name)
+        return input_names
+
+    # modified from SequentialPipelineBlocks to include loop_expected_components
+    @property
+    def expected_components(self):
+        expected_components = []
+        for block in self.blocks.values():
+            for component in block.expected_components:
+                if component not in expected_components:
+                    expected_components.append(component)
+        for component in self.loop_expected_components:
+            if component not in expected_components:
+                expected_components.append(component)
+        return expected_components
+
+    # modified from SequentialPipelineBlocks to include loop_expected_configs
+    @property
+    def expected_configs(self):
+        expected_configs = []
+        for block in self.blocks.values():
+            for config in block.expected_configs:
+                if config not in expected_configs:
+                    expected_configs.append(config)
+        for config in self.loop_expected_configs:
+            if config not in expected_configs:
+                expected_configs.append(config)
+        return expected_configs
+
+    # modified from SequentialPipelineBlocks to include loop_inputs
+    def get_inputs(self):
+        named_inputs = [(name, block.inputs) for name, block in self.blocks.items()]
+        named_inputs.append(("loop", self.loop_inputs))
+        combined_inputs = combine_inputs(*named_inputs)
+        # mark an input as required only if it is required by any of the blocks
+        for input_param in combined_inputs:
+            if input_param.name in self.required_inputs:
+                input_param.required = True
+            else:
+                input_param.required = False
+        return combined_inputs
+
+    # Copied from SequentialPipelineBlocks
+    @property
+    def inputs(self):
+        return self.get_inputs()
+
+
+    # modified from SequentialPipelineBlocks to include loop_intermediates_inputs
+    @property
+    def intermediates_inputs(self):
+        intermediates = self.get_intermediates_inputs()
+        intermediate_names = [input.name for input in intermediates]
+        for loop_intermediate_input in self.loop_intermediates_inputs:
+            if loop_intermediate_input.name not in intermediate_names:
+                intermediates.append(loop_intermediate_input)
+        return intermediates
+
+
+    # Copied from SequentialPipelineBlocks
+    def get_intermediates_inputs(self):
+        inputs = []
+        outputs = set()
+
+        # Go through all blocks in order
+        for block in self.blocks.values():
+            # Add inputs that aren't in outputs yet
+            inputs.extend(input_name for input_name in block.intermediates_inputs if input_name.name not in outputs)
+
+            # Only add outputs if the block cannot be skipped
+            should_add_outputs = True
+            if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs:
+                should_add_outputs = False
+
+            if should_add_outputs:
+                # Add this block's outputs
+                block_intermediates_outputs = [out.name for out in block.intermediates_outputs]
+                outputs.update(block_intermediates_outputs)
+        return inputs
+
+
+    # modified from SequentialPipelineBlocks to also include inputs required by the loop
+    @property
+    def required_inputs(self) -> List[str]:
+        # Get the first block from the dictionary
+        first_block = next(iter(self.blocks.values()))
+        required_by_any = set(getattr(first_block, "required_inputs", set()))
+
+        required_by_loop = set(getattr(self, "loop_required_inputs", set()))
+        required_by_any.update(required_by_loop)
+
+        # Union with required inputs from all other blocks
+        for block in list(self.blocks.values())[1:]:
+            block_required = set(getattr(block, "required_inputs", set()))
+            required_by_any.update(block_required)
+
+        return list(required_by_any)
+
+    # modified from SequentialPipelineBlocks to also include intermediate inputs required by the loop
+    @property
+    def required_intermediates_inputs(self) -> List[str]:
+        required_intermediates_inputs = []
+        for input_param in self.intermediates_inputs:
+            if input_param.required:
+                required_intermediates_inputs.append(input_param.name)
+        for input_param in self.loop_intermediates_inputs:
+            if input_param.required:
+                required_intermediates_inputs.append(input_param.name)
+        return required_intermediates_inputs
+
+
+    # YiYi TODO: this needs to be thought about more
+    # modified from SequentialPipelineBlocks to include loop_intermediates_outputs
+    @property
+    def intermediates_outputs(self) -> List[str]:
+        named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()]
+        combined_outputs = combine_outputs(*named_outputs)
+        for output in self.loop_intermediates_outputs:
+            if output.name not in set([output.name for output in combined_outputs]):
+                combined_outputs.append(output)
+        return combined_outputs
+
+    # YiYi TODO: this needs to be thought about more
+    # copied from SequentialPipelineBlocks
+    @property
+    def outputs(self) -> List[str]:
+        return next(reversed(self.blocks.values())).intermediates_outputs
+
+
+    def __init__(self):
+        blocks = OrderedDict()
+        for block_name, block_cls in zip(self.block_names, self.block_classes):
+            blocks[block_name] = block_cls()
+        self.blocks = blocks
+
+    def loop_step(self, components, state: PipelineState, **kwargs):
+
+        for block_name, block in self.blocks.items():
+            try:
+                components, state = block(components, state, **kwargs)
+            except Exception as e:
+                error_msg = (
+                    f"\nError in block: ({block_name}, {block.__class__.__name__})\n"
+                    f"Error details: {str(e)}\n"
+                    f"Traceback:\n{traceback.format_exc()}"
+                )
+                logger.error(error_msg)
+                raise
+        return components, state
+
+    def __call__(self, components, state: PipelineState) -> PipelineState:
+        raise NotImplementedError("`__call__` method needs to be implemented by the subclass")
+
+
+    def get_block_state(self, state: PipelineState) -> BlockState:
+        """Get all inputs and intermediates in one dictionary"""
+        data = {}
+
+        # Check inputs
+        for input_param in self.inputs:
+            if input_param.name:
+                value = state.get_input(input_param.name)
+                if input_param.required and value is None:
+                    raise ValueError(f"Required input '{input_param.name}' is missing")
+                elif value is not None or (value is None and input_param.name not in data):
+                    data[input_param.name] = value
+            elif input_param.kwargs_type:
+                # if kwargs_type is provided, get all inputs with matching kwargs_type
+                if input_param.kwargs_type not in data:
+                    data[input_param.kwargs_type] = {}
+                inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type)
+                if inputs_kwargs:
+                    for k, v in inputs_kwargs.items():
+                        if v is not None:
+                            data[k] = v
+                            data[input_param.kwargs_type][k] = v
+
+        # Check intermediates
+        for input_param in self.intermediates_inputs:
+            if input_param.name:
+                value = state.get_intermediate(input_param.name)
+                if input_param.required and value is None:
+                    raise ValueError(f"Required intermediate input '{input_param.name}' is missing")
+                elif value is not None or (value is None and input_param.name not in data):
+                    data[input_param.name] = value
+            elif input_param.kwargs_type:
+                # if kwargs_type is provided, get all intermediates with matching kwargs_type
+                if input_param.kwargs_type not in data:
+                    data[input_param.kwargs_type] = {}
+                intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type)
+                if intermediates_kwargs:
+                    for k, v in intermediates_kwargs.items():
+                        if v is not None:
+                            if k not in data:
+                                data[k] = v
+                            data[input_param.kwargs_type][k] = v
+        return BlockState(**data)
+
+    def add_block_state(self, state: PipelineState, block_state: BlockState):
+        for output_param in self.intermediates_outputs:
+            if not hasattr(block_state, output_param.name):
+                raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
+            param = getattr(block_state, output_param.name)
+            state.add_intermediate(output_param.name, param, output_param.kwargs_type)
 
 # YiYi TODO:
 # 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess)
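
To make the subclass contract concrete, a minimal sketch (hypothetical names throughout; the real SDXL denoise loops live in pipeline_stable_diffusion_xl_modular_denoise_loop.py, which the next file imports but this diff does not show). `__call__` owns the iteration and delegates each step to `loop_step`, which runs every registered sub-block in order:

    class MyDenoiseLoop(LoopSequentialPipelineBlocks):
        model_name = "stable-diffusion-xl"
        block_classes = [MyLoopBodyBlock]  # hypothetical inner block class
        block_names = ["loop_body"]

        @property
        def description(self) -> str:
            return "Runs the registered sub-blocks once per timestep."

        @property
        def loop_intermediates_inputs(self) -> List[InputParam]:
            return [InputParam("timesteps", required=True)]

        @torch.no_grad()
        def __call__(self, components, state: PipelineState) -> PipelineState:
            block_state = self.get_block_state(state)
            for i, t in enumerate(block_state.timesteps):
                # loop_step runs every sub-block in sequence for this step;
                # extra kwargs (here i and t) are forwarded to each sub-block
                components, state = self.loop_step(components, state, i=i, t=t)
            self.add_block_state(state, block_state)
            return components, state
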
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py
index 119c92e06f..7869e11a9c 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py
@@ -370,10 +370,10 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
     @property
     def intermediates_outputs(self) -> List[OutputParam]:
         return [
-            OutputParam("prompt_embeds", type_hint=torch.Tensor, description="text embeddings used to guide the image generation"),
-            OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="negative text embeddings used to guide the image generation"),
-            OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="pooled text embeddings used to guide the image generation"),
-            OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="negative pooled text embeddings used to guide the image generation"),
+            OutputParam("prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="text embeddings used to guide the image generation"),
+            OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative text embeddings used to guide the image generation"),
+            OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="pooled text embeddings used to guide the image generation"),
+            OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative pooled text embeddings used to guide the image generation"),
         ]
 
     @staticmethod
@@ -982,12 +982,12 @@ class StableDiffusionXLInputStep(PipelineBlock):
         return [
             OutputParam("batch_size", type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"),
             OutputParam("dtype", type_hint=torch.dtype,
description="Data type of model tensor inputs (determined by `prompt_embeds`)"), - OutputParam("prompt_embeds", type_hint=torch.Tensor, description="text embeddings used to guide the image generation"), - OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="negative text embeddings used to guide the image generation"), - OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="pooled text embeddings used to guide the image generation"), - OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="negative pooled text embeddings used to guide the image generation"), - OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="image embeddings for IP-Adapter"), - OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="negative image embeddings for IP-Adapter"), + OutputParam("prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="text embeddings used to guide the image generation"), + OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative text embeddings used to guide the image generation"), + OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="pooled text embeddings used to guide the image generation"), + OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative pooled text embeddings used to guide the image generation"), + OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], kwargs_type="guider_input_fields", description="image embeddings for IP-Adapter"), + OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], kwargs_type="guider_input_fields", description="negative image embeddings for IP-Adapter"), ] def check_inputs(self, components, block_state): @@ -1836,8 +1836,8 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): @property def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("add_time_ids", type_hint=torch.Tensor, description="The time ids to condition the denoising process"), - OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="The negative time ids to condition the denoising process"), + return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"), + OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"), OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids with self -> components @@ -2025,8 +2025,8 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): @property def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("add_time_ids", type_hint=torch.Tensor, description="The time ids to condition the denoising process"), - OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="The negative time ids to condition the denoising process"), + return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"), + 
OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"), OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids with self -> components @@ -2135,264 +2135,265 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): return components, state -class StableDiffusionXLDenoiseStep(PipelineBlock): +from .pipeline_stable_diffusion_xl_modular_denoise_loop import StableDiffusionXLDenoiseStep +# class StableDiffusionXLDenoiseStep(PipelineBlock): - model_name = "stable-diffusion-xl" +# model_name = "stable-diffusion-xl" - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), - ComponentSpec("scheduler", EulerDiscreteScheduler), - ComponentSpec("unet", UNet2DConditionModel), - ] +# @property +# def expected_components(self) -> List[ComponentSpec]: +# return [ +# ComponentSpec( +# "guider", +# ClassifierFreeGuidance, +# config=FrozenDict({"guidance_scale": 7.5}), +# default_creation_method="from_config"), +# ComponentSpec("scheduler", EulerDiscreteScheduler), +# ComponentSpec("unet", UNet2DConditionModel), +# ] - @property - def description(self) -> str: - return ( - "Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process" - ) +# @property +# def description(self) -> str: +# return ( +# "Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process" +# ) - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("cross_attention_kwargs"), - InputParam("generator"), - InputParam("eta", default=0.0), - InputParam("num_images_per_prompt", default=1), - ] +# @property +# def inputs(self) -> List[Tuple[str, Any]]: +# return [ +# InputParam("cross_attention_kwargs"), +# InputParam("generator"), +# InputParam("eta", default=0.0), +# InputParam("num_images_per_prompt", default=1), +# ] - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "pooled_prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_pooled_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative pooled prompt embeddings to use to condition the denoising process. 
Can be generated in text_encoder step. " - ), - InputParam( - "add_time_ids", - required=True, - type_hint=torch.Tensor, - description="The time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." - ), - InputParam( - "negative_add_time_ids", - type_hint=Optional[torch.Tensor], - description="The negative time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." - ), - InputParam( - "prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " - ), - InputParam( - "timestep_cond", - type_hint=Optional[torch.Tensor], - description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." - ), - InputParam( - "mask", - type_hint=Optional[torch.Tensor], - description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "masked_image_latents", - type_hint=Optional[torch.Tensor], - description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "noise", - type_hint=Optional[torch.Tensor], - description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." - ), - InputParam( - "image_latents", - type_hint=Optional[torch.Tensor], - description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "ip_adapter_embeds", - type_hint=Optional[torch.Tensor], - description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." - ), - InputParam( - "negative_ip_adapter_embeds", - type_hint=Optional[torch.Tensor], - description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." - ), - ] +# @property +# def intermediates_inputs(self) -> List[str]: +# return [ +# InputParam( +# "latents", +# required=True, +# type_hint=torch.Tensor, +# description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." +# ), +# InputParam( +# "batch_size", +# required=True, +# type_hint=int, +# description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." +# ), +# InputParam( +# "timesteps", +# required=True, +# type_hint=torch.Tensor, +# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam( +# "num_inference_steps", +# required=True, +# type_hint=int, +# description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." 
+# ), +# InputParam( +# "pooled_prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_pooled_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " +# ), +# InputParam( +# "add_time_ids", +# required=True, +# type_hint=torch.Tensor, +# description="The time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." +# ), +# InputParam( +# "negative_add_time_ids", +# type_hint=Optional[torch.Tensor], +# description="The negative time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." +# ), +# InputParam( +# "prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " +# ), +# InputParam( +# "timestep_cond", +# type_hint=Optional[torch.Tensor], +# description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." +# ), +# InputParam( +# "mask", +# type_hint=Optional[torch.Tensor], +# description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "masked_image_latents", +# type_hint=Optional[torch.Tensor], +# description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "noise", +# type_hint=Optional[torch.Tensor], +# description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." +# ), +# InputParam( +# "image_latents", +# type_hint=Optional[torch.Tensor], +# description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." +# ), +# InputParam( +# "negative_ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." 
+# ), +# ] - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] +# @property +# def intermediates_outputs(self) -> List[OutputParam]: +# return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - @staticmethod - def check_inputs(components, block_state): +# @staticmethod +# def check_inputs(components, block_state): - num_channels_unet = components.unet.config.in_channels - if num_channels_unet == 9: - # default case for runwayml/stable-diffusion-inpainting - if block_state.mask is None or block_state.masked_image_latents is None: - raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") - num_channels_latents = block_state.latents.shape[1] - num_channels_mask = block_state.mask.shape[1] - num_channels_masked_image = block_state.masked_image_latents.shape[1] - if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: - raise ValueError( - f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" - f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `components.unet` or your `mask_image` or `image` input." - ) +# num_channels_unet = components.unet.config.in_channels +# if num_channels_unet == 9: +# # default case for runwayml/stable-diffusion-inpainting +# if block_state.mask is None or block_state.masked_image_latents is None: +# raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") +# num_channels_latents = block_state.latents.shape[1] +# num_channels_mask = block_state.mask.shape[1] +# num_channels_masked_image = block_state.masked_image_latents.shape[1] +# if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: +# raise ValueError( +# f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" +# f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" +# f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" +# f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" +# " `components.unet` or your `mask_image` or `image` input." +# ) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components - @staticmethod - def prepare_extra_step_kwargs(components, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] +# # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components +# @staticmethod +# def prepare_extra_step_kwargs(components, generator, eta): +# # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature +# # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. +# # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 +# # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta +# accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys()) +# extra_step_kwargs = {} +# if accepts_eta: +# extra_step_kwargs["eta"] = eta - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs +# # check if the scheduler accepts generator +# accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys()) +# if accepts_generator: +# extra_step_kwargs["generator"] = generator +# return extra_step_kwargs - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: +# @torch.no_grad() +# def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - self.check_inputs(components, block_state) +# block_state = self.get_block_state(state) +# self.check_inputs(components, block_state) - block_state.num_channels_unet = components.unet.config.in_channels - block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False - if block_state.disable_guidance: - components.guider.disable() - else: - components.guider.enable() +# block_state.num_channels_unet = components.unet.config.in_channels +# block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False +# if block_state.disable_guidance: +# components.guider.disable() +# else: +# components.guider.enable() - # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) - block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) +# # Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline +# block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) +# block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - components.guider.set_input_fields( - prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), - add_time_ids=("add_time_ids", "negative_add_time_ids"), - pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), - ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), - ) +# components.guider.set_input_fields( +# prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), +# add_time_ids=("add_time_ids", "negative_add_time_ids"), +# pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), +# ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), +# ) - with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: - for i, t in enumerate(block_state.timesteps): - components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) - guider_data = components.guider.prepare_inputs(block_state) +# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: +# for i, t in enumerate(block_state.timesteps): +# components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) +# guider_data = components.guider.prepare_inputs(block_state) - block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) +# block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) - # Prepare for inpainting - if block_state.num_channels_unet == 9: - block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) +# # Prepare for inpainting +# if block_state.num_channels_unet == 9: +# block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) - for batch in guider_data: - components.guider.prepare_models(components.unet) +# for batch in guider_data: +# components.guider.prepare_models(components.unet) - # Prepare additional conditionings - batch.added_cond_kwargs = { - "text_embeds": batch.pooled_prompt_embeds, - "time_ids": batch.add_time_ids, - } - if batch.ip_adapter_embeds is not None: - batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds +# # Prepare additional conditionings +# batch.added_cond_kwargs = { +# "text_embeds": batch.pooled_prompt_embeds, +# "time_ids": batch.add_time_ids, +# } +# if batch.ip_adapter_embeds is not None: +# batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds - # Predict the noise residual - batch.noise_pred = components.unet( - block_state.scaled_latents, - t, - encoder_hidden_states=batch.prompt_embeds, - timestep_cond=block_state.timestep_cond, - cross_attention_kwargs=block_state.cross_attention_kwargs, - added_cond_kwargs=batch.added_cond_kwargs, - return_dict=False, - )[0] - components.guider.cleanup_models(components.unet) +# # Predict the noise residual +# batch.noise_pred = components.unet( +# block_state.scaled_latents, +# t, +# encoder_hidden_states=batch.prompt_embeds, +# timestep_cond=block_state.timestep_cond, +# cross_attention_kwargs=block_state.cross_attention_kwargs, +# added_cond_kwargs=batch.added_cond_kwargs, +# return_dict=False, +# )[0] +# 
components.guider.cleanup_models(components.unet) - # Perform guidance - block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_data) +# # Perform guidance +# block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_data) - # Perform scheduler step using the predicted output - block_state.latents_dtype = block_state.latents.dtype - block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] +# # Perform scheduler step using the predicted output +# block_state.latents_dtype = block_state.latents.dtype +# block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] - if block_state.latents.dtype != block_state.latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - block_state.latents = block_state.latents.to(block_state.latents_dtype) +# if block_state.latents.dtype != block_state.latents_dtype: +# if torch.backends.mps.is_available(): +# # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 +# block_state.latents = block_state.latents.to(block_state.latents_dtype) - if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: - block_state.init_latents_proper = block_state.image_latents - if i < len(block_state.timesteps) - 1: - block_state.noise_timestep = block_state.timesteps[i + 1] - block_state.init_latents_proper = components.scheduler.add_noise( - block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) - ) +# if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: +# block_state.init_latents_proper = block_state.image_latents +# if i < len(block_state.timesteps) - 1: +# block_state.noise_timestep = block_state.timesteps[i + 1] +# block_state.init_latents_proper = components.scheduler.add_noise( +# block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) +# ) - block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents +# block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents - if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): - progress_bar.update() +# if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): +# progress_bar.update() - self.add_block_state(state, block_state) +# self.add_block_state(state, block_state) - return components, state +# return components, state class StableDiffusionXLControlNetInputStep(PipelineBlock): @@ -2452,11 +2453,11 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock): @property def intermediates_outputs(self) -> List[OutputParam]: return [ - OutputParam("controlnet_cond", type_hint=torch.Tensor, description="The processed control image", kwargs_type="contronet_kwargs"), + OutputParam("controlnet_cond", type_hint=torch.Tensor, description="The processed control image"), OutputParam("control_guidance_start", type_hint=List[float], description="The 
controlnet guidance start values"), OutputParam("control_guidance_end", type_hint=List[float], description="The controlnet guidance end values"), OutputParam("conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"), - OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used", kwargs_type="controlnet_kwargs"), + OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"), ] @@ -2592,353 +2593,353 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock): return components, state -class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_modular_denoise_loop import StableDiffusionXLControlNetDenoiseStep +# class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): - model_name = "stable-diffusion-xl" +# model_name = "stable-diffusion-xl" - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), - ComponentSpec("scheduler", EulerDiscreteScheduler), - ComponentSpec("unet", UNet2DConditionModel), - ComponentSpec("controlnet", ControlNetModel), - ] +# @property +# def expected_components(self) -> List[ComponentSpec]: +# return [ +# ComponentSpec( +# "guider", +# ClassifierFreeGuidance, +# config=FrozenDict({"guidance_scale": 7.5}), +# default_creation_method="from_config"), +# ComponentSpec("scheduler", EulerDiscreteScheduler), +# ComponentSpec("unet", UNet2DConditionModel), +# ComponentSpec("controlnet", ControlNetModel), +# ] - @property - def description(self) -> str: - return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" +# @property +# def description(self) -> str: +# return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("num_images_per_prompt", default=1), - InputParam("cross_attention_kwargs"), - InputParam("generator", kwargs_type="scheduler_kwargs"), - InputParam("eta", default=0.0, kwargs_type="scheduler_kwargs"), - InputParam("controlnet_conditioning_scale", type_hint=float, default=1.0), # can expect either input or intermediate input, (intermediate input if both are passed) - ] +# @property +# def inputs(self) -> List[Tuple[str, Any]]: +# return [ +# InputParam("num_images_per_prompt", default=1), +# InputParam("cross_attention_kwargs"), +# InputParam("generator"), +# InputParam("eta", default=0.0), +# InputParam("controlnet_conditioning_scale", type_hint=float, default=1.0), # can expect either input or intermediate input, (intermediate input if both are passed) +# ] - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "controlnet_cond", - required=True, - type_hint=torch.Tensor, - description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." - ), - InputParam( - "control_guidance_start", - required=True, - type_hint=float, - description="The control guidance start value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." 
- ), - InputParam( - "control_guidance_end", - required=True, - type_hint=float, - description="The control guidance end value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." - ), - InputParam( - "conditioning_scale", - type_hint=float, - description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." - ), - InputParam( - "guess_mode", - required=True, - type_hint=bool, - description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." - ), - InputParam( - "controlnet_keep", - required=True, - type_hint=List[float], - description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." - ), - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "add_time_ids", - required=True, - type_hint=torch.Tensor, - description="The time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." - ), - InputParam( - "negative_add_time_ids", - type_hint=Optional[torch.Tensor], - description="The negative time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." - ), - InputParam( - "pooled_prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The pooled prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_pooled_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "timestep_cond", - type_hint=Optional[torch.Tensor], - description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" - ), - InputParam( - "mask", - type_hint=Optional[torch.Tensor], - description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "masked_image_latents", - type_hint=Optional[torch.Tensor], - description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "noise", - type_hint=Optional[torch.Tensor], - description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." 
- ), - InputParam( - "image_latents", - type_hint=Optional[torch.Tensor], - description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "crops_coords", - type_hint=Optional[Tuple[int]], - description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." - ), - InputParam( - "ip_adapter_embeds", - type_hint=Optional[torch.Tensor], - description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." - ), - InputParam( - "negative_ip_adapter_embeds", - type_hint=Optional[torch.Tensor], - description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." - ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam(kwargs_type="controlnet_kwargs", description="additional kwargs for controlnet") - ] +# @property +# def intermediates_inputs(self) -> List[str]: +# return [ +# InputParam( +# "controlnet_cond", +# required=True, +# type_hint=torch.Tensor, +# description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "control_guidance_start", +# required=True, +# type_hint=float, +# description="The control guidance start value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "control_guidance_end", +# required=True, +# type_hint=float, +# description="The control guidance end value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "conditioning_scale", +# type_hint=float, +# description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "guess_mode", +# required=True, +# type_hint=bool, +# description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "controlnet_keep", +# required=True, +# type_hint=List[float], +# description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "latents", +# required=True, +# type_hint=torch.Tensor, +# description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." +# ), +# InputParam( +# "batch_size", +# required=True, +# type_hint=int, +# description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." +# ), +# InputParam( +# "timesteps", +# required=True, +# type_hint=torch.Tensor, +# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam( +# "prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." 
+# ), +# InputParam( +# "negative_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "add_time_ids", +# required=True, +# type_hint=torch.Tensor, +# description="The time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." +# ), +# InputParam( +# "negative_add_time_ids", +# type_hint=Optional[torch.Tensor], +# description="The negative time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." +# ), +# InputParam( +# "pooled_prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The pooled prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_pooled_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "timestep_cond", +# type_hint=Optional[torch.Tensor], +# description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" +# ), +# InputParam( +# "mask", +# type_hint=Optional[torch.Tensor], +# description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "masked_image_latents", +# type_hint=Optional[torch.Tensor], +# description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "noise", +# type_hint=Optional[torch.Tensor], +# description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." +# ), +# InputParam( +# "image_latents", +# type_hint=Optional[torch.Tensor], +# description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "crops_coords", +# type_hint=Optional[Tuple[int]], +# description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." +# ), +# InputParam( +# "ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." +# ), +# InputParam( +# "negative_ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." +# ), +# InputParam( +# "num_inference_steps", +# required=True, +# type_hint=int, +# description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." 
+# ), +# InputParam(kwargs_type="controlnet_kwargs", description="additional kwargs for controlnet") +# ] - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] +# @property +# def intermediates_outputs(self) -> List[OutputParam]: +# return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - @staticmethod - def check_inputs(components, block_state): +# @staticmethod +# def check_inputs(components, block_state): - num_channels_unet = components.unet.config.in_channels - if num_channels_unet == 9: - # default case for runwayml/stable-diffusion-inpainting - if block_state.mask is None or block_state.masked_image_latents is None: - raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") - num_channels_latents = block_state.latents.shape[1] - num_channels_mask = block_state.mask.shape[1] - num_channels_masked_image = block_state.masked_image_latents.shape[1] - if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: - raise ValueError( - f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" - f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `components.unet` or your `mask_image` or `image` input." - ) - @staticmethod - def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): +# num_channels_unet = components.unet.config.in_channels +# if num_channels_unet == 9: +# # default case for runwayml/stable-diffusion-inpainting +# if block_state.mask is None or block_state.masked_image_latents is None: +# raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") +# num_channels_latents = block_state.latents.shape[1] +# num_channels_mask = block_state.mask.shape[1] +# num_channels_masked_image = block_state.masked_image_latents.shape[1] +# if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: +# raise ValueError( +# f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" +# f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" +# f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" +# f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" +# " `components.unet` or your `mask_image` or `image` input." 
+# ) +# @staticmethod +# def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): - accepted_kwargs = set(inspect.signature(func).parameters.keys()) - extra_kwargs = {} - for key, value in kwargs.items(): - if key in accepted_kwargs and key not in exclude_kwargs: - extra_kwargs[key] = value +# accepted_kwargs = set(inspect.signature(func).parameters.keys()) +# extra_kwargs = {} +# for key, value in kwargs.items(): +# if key in accepted_kwargs and key not in exclude_kwargs: +# extra_kwargs[key] = value - return extra_kwargs +# return extra_kwargs - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: +# @torch.no_grad() +# def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - self.check_inputs(components, block_state) - block_state.device = components._execution_device - print(f" block_state: {block_state}") +# block_state = self.get_block_state(state) +# self.check_inputs(components, block_state) +# block_state.device = components._execution_device +# print(f" block_state: {block_state}") - controlnet = unwrap_module(components.controlnet) +# controlnet = unwrap_module(components.controlnet) - # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - # YiYI TODO: refactor scheduler_kwargs and support unet kwargs - block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) - block_state.extra_controlnet_kwargs = self.prepare_extra_kwargs(controlnet.forward, exclude_kwargs=["controlnet_cond", "conditioning_scale", "guess_mode"], **block_state.controlnet_kwargs) +# # Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline +# block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) +# block_state.extra_controlnet_kwargs = self.prepare_extra_kwargs(controlnet.forward, exclude_kwargs=["controlnet_cond", "conditioning_scale", "guess_mode"], **block_state.controlnet_kwargs) - block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) +# block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - # (1) setup guider - # disable for LCMs - block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False - if block_state.disable_guidance: - components.guider.disable() - else: - components.guider.enable() - components.guider.set_input_fields( - prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), - add_time_ids=("add_time_ids", "negative_add_time_ids"), - pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), - ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), - ) +# # (1) setup guider +# # disable for LCMs +# block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False +# if block_state.disable_guidance: +# components.guider.disable() +# else: +# components.guider.enable() +# components.guider.set_input_fields( +# prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), +# add_time_ids=("add_time_ids", "negative_add_time_ids"), +# pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), +# ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), +# ) - # (5) Denoise loop - with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: - for i, t in enumerate(block_state.timesteps): +# # (5) Denoise loop +# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: +# for i, t in enumerate(block_state.timesteps): - # prepare latent input for unet - block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) - # adjust latent input for inpainting - block_state.num_channels_unet = components.unet.config.in_channels - if block_state.num_channels_unet == 9: - block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) +# # prepare latent input for unet +# block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) +# # adjust latent input for inpainting +# block_state.num_channels_unet = components.unet.config.in_channels +# if block_state.num_channels_unet == 9: +# block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) - # cond_scale (controlnet input) - if isinstance(block_state.controlnet_keep[i], list): - block_state.cond_scale = [c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])] - else: - block_state.controlnet_cond_scale = block_state.conditioning_scale - if isinstance(block_state.controlnet_cond_scale, list): - block_state.controlnet_cond_scale = block_state.controlnet_cond_scale[0] - block_state.cond_scale = block_state.controlnet_cond_scale * block_state.controlnet_keep[i] +# # cond_scale (controlnet input) +# if isinstance(block_state.controlnet_keep[i], list): +# block_state.cond_scale = 
[c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])] +# else: +# block_state.controlnet_cond_scale = block_state.conditioning_scale +# if isinstance(block_state.controlnet_cond_scale, list): +# block_state.controlnet_cond_scale = block_state.controlnet_cond_scale[0] +# block_state.cond_scale = block_state.controlnet_cond_scale * block_state.controlnet_keep[i] - # default controlnet output/unet input for guess mode + conditional path - block_state.down_block_res_samples_zeros = None - block_state.mid_block_res_sample_zeros = None +# # default controlnet output/unet input for guess mode + conditional path +# block_state.down_block_res_samples_zeros = None +# block_state.mid_block_res_sample_zeros = None - # guided denoiser step - components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) - guider_state = components.guider.prepare_inputs(block_state) +# # guided denoiser step +# components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) +# guider_state = components.guider.prepare_inputs(block_state) - for guider_state_batch in guider_state: - components.guider.prepare_models(components.unet) +# for guider_state_batch in guider_state: +# components.guider.prepare_models(components.unet) - # Prepare additional conditionings - guider_state_batch.added_cond_kwargs = { - "text_embeds": guider_state_batch.pooled_prompt_embeds, - "time_ids": guider_state_batch.add_time_ids, - } - if guider_state_batch.ip_adapter_embeds is not None: - guider_state_batch.added_cond_kwargs["image_embeds"] = guider_state_batch.ip_adapter_embeds +# # Prepare additional conditionings +# guider_state_batch.added_cond_kwargs = { +# "text_embeds": guider_state_batch.pooled_prompt_embeds, +# "time_ids": guider_state_batch.add_time_ids, +# } +# if guider_state_batch.ip_adapter_embeds is not None: +# guider_state_batch.added_cond_kwargs["image_embeds"] = guider_state_batch.ip_adapter_embeds - # Prepare controlnet additional conditionings - guider_state_batch.controlnet_added_cond_kwargs = { - "text_embeds": guider_state_batch.pooled_prompt_embeds, - "time_ids": guider_state_batch.add_time_ids, - } +# # Prepare controlnet additional conditionings +# guider_state_batch.controlnet_added_cond_kwargs = { +# "text_embeds": guider_state_batch.pooled_prompt_embeds, +# "time_ids": guider_state_batch.add_time_ids, +# } - if block_state.guess_mode and not components.guider.is_conditional: - # guider always run uncond batch first, so these tensors should be set already - guider_state_batch.down_block_res_samples = block_state.down_block_res_samples_zeros - guider_state_batch.mid_block_res_sample = block_state.mid_block_res_sample_zeros - else: - guider_state_batch.down_block_res_samples, guider_state_batch.mid_block_res_sample = components.controlnet( - block_state.scaled_latents, - t, - encoder_hidden_states=guider_state_batch.prompt_embeds, - controlnet_cond=block_state.controlnet_cond, - conditioning_scale=block_state.conditioning_scale, - guess_mode=block_state.guess_mode, - added_cond_kwargs=guider_state_batch.controlnet_added_cond_kwargs, - return_dict=False, - **block_state.extra_controlnet_kwargs, - ) +# if block_state.guess_mode and not components.guider.is_conditional: +# # guider always run uncond batch first, so these tensors should be set already +# guider_state_batch.down_block_res_samples = block_state.down_block_res_samples_zeros +# guider_state_batch.mid_block_res_sample = 
block_state.mid_block_res_sample_zeros +# else: +# guider_state_batch.down_block_res_samples, guider_state_batch.mid_block_res_sample = components.controlnet( +# block_state.scaled_latents, +# t, +# encoder_hidden_states=guider_state_batch.prompt_embeds, +# controlnet_cond=block_state.controlnet_cond, +# conditioning_scale=block_state.conditioning_scale, +# guess_mode=block_state.guess_mode, +# added_cond_kwargs=guider_state_batch.controlnet_added_cond_kwargs, +# return_dict=False, +# **block_state.extra_controlnet_kwargs, +# ) - if block_state.down_block_res_samples_zeros is None: - block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in guider_state_batch.down_block_res_samples] - if block_state.mid_block_res_sample_zeros is None: - block_state.mid_block_res_sample_zeros = torch.zeros_like(guider_state_batch.mid_block_res_sample) +# if block_state.down_block_res_samples_zeros is None: +# block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in guider_state_batch.down_block_res_samples] +# if block_state.mid_block_res_sample_zeros is None: +# block_state.mid_block_res_sample_zeros = torch.zeros_like(guider_state_batch.mid_block_res_sample) - guider_state_batch.noise_pred = components.unet( - block_state.scaled_latents, - t, - encoder_hidden_states=guider_state_batch.prompt_embeds, - timestep_cond=block_state.timestep_cond, - cross_attention_kwargs=block_state.cross_attention_kwargs, - added_cond_kwargs=guider_state_batch.added_cond_kwargs, - down_block_additional_residuals=guider_state_batch.down_block_res_samples, - mid_block_additional_residual=guider_state_batch.mid_block_res_sample, - return_dict=False, - )[0] - components.guider.cleanup_models(components.unet) +# guider_state_batch.noise_pred = components.unet( +# block_state.scaled_latents, +# t, +# encoder_hidden_states=guider_state_batch.prompt_embeds, +# timestep_cond=block_state.timestep_cond, +# cross_attention_kwargs=block_state.cross_attention_kwargs, +# added_cond_kwargs=guider_state_batch.added_cond_kwargs, +# down_block_additional_residuals=guider_state_batch.down_block_res_samples, +# mid_block_additional_residual=guider_state_batch.mid_block_res_sample, +# return_dict=False, +# )[0] +# components.guider.cleanup_models(components.unet) - # Perform guidance - block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_state) +# # Perform guidance +# block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_state) - # Perform scheduler step using the predicted output - block_state.latents_dtype = block_state.latents.dtype - block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] +# # Perform scheduler step using the predicted output +# block_state.latents_dtype = block_state.latents.dtype +# block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] - if block_state.latents.dtype != block_state.latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - block_state.latents = block_state.latents.to(block_state.latents_dtype) +# if block_state.latents.dtype != block_state.latents_dtype: +# if torch.backends.mps.is_available(): +# # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 +# block_state.latents = block_state.latents.to(block_state.latents_dtype) - # adjust latent for inpainting - if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: - block_state.init_latents_proper = block_state.image_latents - if i < len(block_state.timesteps) - 1: - block_state.noise_timestep = block_state.timesteps[i + 1] - block_state.init_latents_proper = components.scheduler.add_noise( - block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) - ) +# # adjust latent for inpainting +# if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: +# block_state.init_latents_proper = block_state.image_latents +# if i < len(block_state.timesteps) - 1: +# block_state.noise_timestep = block_state.timesteps[i + 1] +# block_state.init_latents_proper = components.scheduler.add_noise( +# block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) +# ) - block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents +# block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents - if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): - progress_bar.update() +# if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): +# progress_bar.update() - self.add_block_state(state, block_state) +# self.add_block_state(state, block_state) - return components, state +# return components, state class StableDiffusionXLControlNetUnionInputStep(PipelineBlock): @@ -3004,13 +3005,13 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock): @property def intermediates_outputs(self) -> List[OutputParam]: return [ - OutputParam("controlnet_cond", type_hint=List[torch.Tensor], description="The processed control images", kwargs_type="controlnet_kwargs"), + OutputParam("controlnet_cond", type_hint=List[torch.Tensor], description="The processed control images"), OutputParam("control_type_idx", type_hint=List[int], description="The control mode indices", kwargs_type="controlnet_kwargs"), OutputParam("control_type", type_hint=torch.Tensor, description="The control type tensor that specifies which control type is active", kwargs_type="controlnet_kwargs"), OutputParam("control_guidance_start", type_hint=float, description="The controlnet guidance start value"), OutputParam("control_guidance_end", type_hint=float, description="The controlnet guidance end value"), OutputParam("conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"), - OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used", kwargs_type="controlnet_kwargs"), + OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"), ] diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular_denoise_loop.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular_denoise_loop.py new file mode 100644 index 0000000000..92c07854fc --- /dev/null +++ 
b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular_denoise_loop.py
@@ -0,0 +1,729 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from tqdm.auto import tqdm
+
+from ...configuration_utils import FrozenDict
+from ...models import ControlNetModel, UNet2DConditionModel
+from ...schedulers import EulerDiscreteScheduler
+from ...utils import logging
+from ...utils.torch_utils import unwrap_module
+from ..modular_pipeline import (
+    PipelineBlock,
+    PipelineState,
+    LoopSequentialPipelineBlocks,
+    InputParam,
+    OutputParam,
+    BlockState,
+    ComponentSpec,
+)
+from ...guiders import ClassifierFreeGuidance
+from .pipeline_stable_diffusion_xl_modular import StableDiffusionXLModularLoader
+from dataclasses import asdict
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+# YiYi experimenting composable denoise loop
+# loop step (1): prepare latent input for denoiser
+class StableDiffusionXLDenoiseLoopLatentsStep(PipelineBlock):
+
+    model_name = "stable-diffusion-xl"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("scheduler", EulerDiscreteScheduler),
+        ]
+
+    @property
+    def description(self) -> str:
+        return "step within the denoising loop that prepares the latent input for the denoiser"
+
+    @property
+    def intermediates_inputs(self) -> List[str]:
+        return [
+            InputParam(
+                "latents",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."
+            ),
+        ]
+
+    @property
+    def intermediates_outputs(self) -> List[OutputParam]:
+        return [OutputParam("scaled_latents", type_hint=torch.Tensor, description="The scaled latents input for the denoiser")]
+
+    @torch.no_grad()
+    def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int):
+
+        block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t)
+
+        return components, block_state
+
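+# A minimal sketch of a custom loop step (hypothetical, for illustration; `MyLoggingLoopStep`
+# is not part of this PR): every loop step follows the contract
+# `__call__(components, block_state, i, t) -> (components, block_state)` and only needs to
+# read/write attributes on the shared `block_state`:
+#
+#     class MyLoggingLoopStep(PipelineBlock):
+#         model_name = "stable-diffusion-xl"
+#
+#         @property
+#         def description(self) -> str:
+#             return "no-op step that just logs the current timestep"
+#
+#         @torch.no_grad()
+#         def __call__(self, components, block_state, i, t):
+#             logger.info(f"loop step {i}: timestep {t}")
+#             return components, block_state
+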
+# loop step (1): prepare latent input for denoiser (with inpainting)
+class StableDiffusionXLDenoiseLoopInpaintLatentsStep(PipelineBlock):
+
+    model_name = "stable-diffusion-xl"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("scheduler", EulerDiscreteScheduler),
+            ComponentSpec("unet", UNet2DConditionModel),
+        ]
+
+    @property
+    def description(self) -> str:
+        return "step within the denoising loop that prepares the latent input for the denoiser"
+
+    @property
+    def intermediates_inputs(self) -> List[str]:
+        return [
+            InputParam(
+                "latents",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."
+            ),
+            InputParam(
+                "mask",
+                type_hint=Optional[torch.Tensor],
+                description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step."
+            ),
+            InputParam(
+                "masked_image_latents",
+                type_hint=Optional[torch.Tensor],
+                description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step."
+            ),
+        ]
+
+    @property
+    def intermediates_outputs(self) -> List[OutputParam]:
+        return [OutputParam("scaled_latents", type_hint=torch.Tensor, description="The scaled latents input for the denoiser")]
+
+    @staticmethod
+    def check_inputs(components, block_state):
+
+        num_channels_unet = components.num_channels_unet
+        if num_channels_unet == 9:
+            # default case for runwayml/stable-diffusion-inpainting
+            if block_state.mask is None or block_state.masked_image_latents is None:
+                raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet")
+            num_channels_latents = block_state.latents.shape[1]
+            num_channels_mask = block_state.mask.shape[1]
+            num_channels_masked_image = block_state.masked_image_latents.shape[1]
+            if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet:
+                raise ValueError(
+                    f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects"
+                    f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+                    f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+                    f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+                    " `components.unet` or your `mask_image` or `image` input."
+                )
+
+    @torch.no_grad()
+    def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int):
+
+        self.check_inputs(components, block_state)
+
+        block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t)
+        if components.num_channels_unet == 9:
+            block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1)
+
+        return components, block_state
+
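+# Channel bookkeeping for the concat above (sketch; runwayml/stable-diffusion-inpainting is
+# the canonical 9-channel case checked in `check_inputs`): 4 latent + 1 mask + 4 masked-image
+# channels must add up to `unet.config.in_channels`:
+#
+#     latents = torch.randn(1, 4, 128, 128)
+#     mask = torch.randn(1, 1, 128, 128)
+#     masked_image_latents = torch.randn(1, 4, 128, 128)
+#     unet_input = torch.cat([latents, mask, masked_image_latents], dim=1)
+#     assert unet_input.shape[1] == 9
+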
+# loop step (2): denoise the latents with guidance
+class StableDiffusionXLDenoiseLoopDenoiserStep(PipelineBlock):
+
+    model_name = "stable-diffusion-xl"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec(
+                "guider",
+                ClassifierFreeGuidance,
+                config=FrozenDict({"guidance_scale": 7.5}),
+                default_creation_method="from_config"),
+            ComponentSpec("unet", UNet2DConditionModel),
+        ]
+
+    @property
+    def description(self) -> str:
+        return (
+            "Step within the denoising loop that denoises the latents with guidance"
+        )
+
+    @property
+    def inputs(self) -> List[Tuple[str, Any]]:
+        return [
+            InputParam("cross_attention_kwargs"),
+        ]
+
+    @property
+    def intermediates_inputs(self) -> List[str]:
+        return [
+            InputParam(
+                "scaled_latents",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The prepared latents input to use for the denoiser. Can be generated in latent step within the denoise loop."
+            ),
+            InputParam(
+                "num_inference_steps",
+                required=True,
+                type_hint=int,
+                description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step."
+            ),
+            InputParam(
+                "timestep_cond",
+                type_hint=Optional[torch.Tensor],
+                description="The guidance scale embedding to use for Latent Consistency Models (LCMs). Can be generated in prepare_additional_conditioning step."
+            ),
+            InputParam(
+                kwargs_type="guider_input_fields",
+                description=(
+                    "All conditional model inputs that need to be prepared with guider. "
+                    "It should contain prompt_embeds/negative_prompt_embeds, "
+                    "add_time_ids/negative_add_time_ids, "
+                    "pooled_prompt_embeds/negative_pooled_prompt_embeds, "
+                    "and ip_adapter_embeds/negative_ip_adapter_embeds (optional). "
+                    "Please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
+                )
+            ),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int):
+
+        # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
+        # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
+        guider_input_fields = {
+            "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
+            "time_ids": ("add_time_ids", "negative_add_time_ids"),
+            "text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"),
+            "image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"),
+        }
+
+        components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
+
+        # Prepare mini-batches according to guidance method and `guider_input_fields`
+        # Each guider_state_batch will have .prompt_embeds, .time_ids, .text_embeds, .image_embeds.
+        # e.g. for CFG, we prepare two batches: one for uncond, one for cond
+        # for the first batch, guider_state_batch.prompt_embeds corresponds to block_state.prompt_embeds
+        # for the second batch, guider_state_batch.prompt_embeds corresponds to block_state.negative_prompt_embeds
+        guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
+
+        # run the denoiser for each guidance batch
+        for guider_state_batch in guider_state:
+            components.guider.prepare_models(components.unet)
+            cond_kwargs = guider_state_batch.as_dict()
+            cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
+            prompt_embeds = cond_kwargs.pop("prompt_embeds")
+
+            # Predict the noise residual
+            # store the noise_pred in guider_state_batch so that we can apply guidance across all batches
+            guider_state_batch.noise_pred = components.unet(
+                block_state.scaled_latents,
+                t,
+                encoder_hidden_states=prompt_embeds,
+                timestep_cond=block_state.timestep_cond,
+                cross_attention_kwargs=block_state.cross_attention_kwargs,
+                added_cond_kwargs=cond_kwargs,
+                return_dict=False,
+            )[0]
+            components.guider.cleanup_models(components.unet)
+
+        # Perform guidance
+        block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)
+
+        return components, block_state
+
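+# Sketch of what `guider.prepare_inputs` yields for plain CFG (illustrative; field names
+# follow the `guider_input_fields` mapping above, guider internals are assumed):
+#
+#     # batch 0 (cond):   guider_state_batch.prompt_embeds is block_state.prompt_embeds
+#     # batch 1 (uncond): guider_state_batch.prompt_embeds is block_state.negative_prompt_embeds
+#     # after the loop, `components.guider(guider_state)` combines the per-batch noise_pred
+#     # tensors, e.g. noise_uncond + guidance_scale * (noise_cond - noise_uncond) for CFG
+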
return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("cross_attention_kwargs"), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "controlnet_cond", + required=True, + type_hint=torch.Tensor, + description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "conditioning_scale", + type_hint=float, + description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "guess_mode", + required=True, + type_hint=bool, + description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "controlnet_keep", + required=True, + type_hint=List[float], + description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "scaled_latents", + required=True, + type_hint=torch.Tensor, + description="The prepared latents input to use for the denoiser. Can be generated in latent step within the denoise loop." + ), + InputParam( + "timestep_cond", + type_hint=Optional[torch.Tensor], + description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + kwargs_type="guider_input_fields", + description=( + "All conditional model inputs that need to be prepared with guider. " + "It should contain prompt_embeds/negative_prompt_embeds, " + "add_time_ids/negative_add_time_ids, " + "pooled_prompt_embeds/negative_pooled_prompt_embeds, " + "and ip_adapter_embeds/negative_ip_adapter_embeds (optional)." + "please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" + ) + ), + InputParam( + kwargs_type="controlnet_kwargs", + description=( + "additional kwargs for controlnet (e.g. control_type_idx and control_type from the controlnet union input step )" + "please add `kwargs_type=controlnet_kwargs` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" + ) + ) + ] + + @staticmethod + def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + + accepted_kwargs = set(inspect.signature(func).parameters.keys()) + extra_kwargs = {} + for key, value in kwargs.items(): + if key in accepted_kwargs and key not in exclude_kwargs: + extra_kwargs[key] = value + + return extra_kwargs + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + + extra_controlnet_kwargs = self.prepare_extra_kwargs(components.controlnet.forward, **block_state.controlnet_kwargs) + + # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds) + # to the corresponding (cond, uncond) fields on block_state. (e.g. 
+    @torch.no_grad()
+    def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int):
+
+        extra_controlnet_kwargs = self.prepare_extra_kwargs(components.controlnet.forward, **block_state.controlnet_kwargs)
+
+        # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
+        # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
+        guider_input_fields = {
+            "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
+            "time_ids": ("add_time_ids", "negative_add_time_ids"),
+            "text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"),
+            "image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"),
+        }
+
+        # cond_scale for the timestep (controlnet input)
+        if isinstance(block_state.controlnet_keep[i], list):
+            block_state.cond_scale = [c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])]
+        else:
+            controlnet_cond_scale = block_state.conditioning_scale
+            if isinstance(controlnet_cond_scale, list):
+                controlnet_cond_scale = controlnet_cond_scale[0]
+            block_state.cond_scale = controlnet_cond_scale * block_state.controlnet_keep[i]
+
+        # default controlnet output/unet input for guess mode + conditional path
+        block_state.down_block_res_samples_zeros = None
+        block_state.mid_block_res_sample_zeros = None
+
+        # guided denoiser step
+        components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
+
+        # Prepare mini-batches according to guidance method and `guider_input_fields`
+        # Each guider_state_batch will have .prompt_embeds, .time_ids, .text_embeds, .image_embeds.
+        # e.g. for CFG, we prepare two batches: one for uncond, one for cond
+        # for the first batch, guider_state_batch.prompt_embeds corresponds to block_state.prompt_embeds
+        # for the second batch, guider_state_batch.prompt_embeds corresponds to block_state.negative_prompt_embeds
+        guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
+
+        # run the denoiser for each guidance batch
+        for guider_state_batch in guider_state:
+            components.guider.prepare_models(components.unet)
+
+            # Prepare additional conditionings
+            added_cond_kwargs = {
+                "text_embeds": guider_state_batch.text_embeds,
+                "time_ids": guider_state_batch.time_ids,
+            }
+            if hasattr(guider_state_batch, "image_embeds") and guider_state_batch.image_embeds is not None:
+                added_cond_kwargs["image_embeds"] = guider_state_batch.image_embeds
+
+            # Prepare controlnet additional conditionings
+            controlnet_added_cond_kwargs = {
+                "text_embeds": guider_state_batch.text_embeds,
+                "time_ids": guider_state_batch.time_ids,
+            }
+            # run controlnet for the guidance batch
+            if block_state.guess_mode and not components.guider.is_conditional:
+                # the guider always runs the uncond batch first, so these tensors should be set already
+                down_block_res_samples = block_state.down_block_res_samples_zeros
+                mid_block_res_sample = block_state.mid_block_res_sample_zeros
+            else:
+                down_block_res_samples, mid_block_res_sample = components.controlnet(
+                    block_state.scaled_latents,
+                    t,
+                    encoder_hidden_states=guider_state_batch.prompt_embeds,
+                    controlnet_cond=block_state.controlnet_cond,
+                    conditioning_scale=block_state.cond_scale,
+                    guess_mode=block_state.guess_mode,
+                    added_cond_kwargs=controlnet_added_cond_kwargs,
+                    return_dict=False,
+                    **extra_controlnet_kwargs,
+                )
+
+            # assign it to block_state so it will be available for the uncond guidance batch
+            if block_state.down_block_res_samples_zeros is None:
+                block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in down_block_res_samples]
+            if block_state.mid_block_res_sample_zeros is None:
+                block_state.mid_block_res_sample_zeros = torch.zeros_like(mid_block_res_sample)
+
+            # Predict the noise
+            # store the noise_pred in guider_state_batch so we can apply guidance across all batches
+            guider_state_batch.noise_pred = components.unet(
+                block_state.scaled_latents,
+                t,
+                encoder_hidden_states=guider_state_batch.prompt_embeds,
+                timestep_cond=block_state.timestep_cond,
+                cross_attention_kwargs=block_state.cross_attention_kwargs,
+                added_cond_kwargs=added_cond_kwargs,
+                down_block_additional_residuals=down_block_res_samples,
+                mid_block_additional_residual=mid_block_res_sample,
+                return_dict=False,
+            )[0]
+            components.guider.cleanup_models(components.unet)
+
+        # Perform guidance
+        block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)
+
+        return components, block_state
+
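+# The per-step scale applied to the ControlNet residuals above (sketch with illustrative
+# numbers; `controlnet_keep` is produced upstream by the controlnet input step):
+#
+#     conditioning_scale = 0.8
+#     controlnet_keep = [1.0, 1.0, 0.0]   # e.g. control window covers 2 of 3 steps
+#     cond_scale_per_step = [conditioning_scale * k for k in controlnet_keep]
+#     # -> [0.8, 0.8, 0.0]: the residuals are fully disabled once keep drops to 0.0
+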
+# loop step (3): scheduler step to update latents
+class StableDiffusionXLDenoiseLoopUpdateLatentsStep(PipelineBlock):
+
+    model_name = "stable-diffusion-xl"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("scheduler", EulerDiscreteScheduler),
+        ]
+
+    @property
+    def description(self) -> str:
+        return "step within the denoising loop that updates the latents with the scheduler step"
+
+    @property
+    def inputs(self) -> List[Tuple[str, Any]]:
+        return [
+            InputParam("generator"),
+            InputParam("eta", default=0.0),
+        ]
+
+    @property
+    def intermediates_inputs(self) -> List[str]:
+        return [
+            InputParam(
+                "latents",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."
+            ),
+        ]
+
+    @property
+    def intermediates_outputs(self) -> List[OutputParam]:
+        return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
+
+    # YiYi TODO: move this out of here
+    @staticmethod
+    def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs):
+
+        accepted_kwargs = set(inspect.signature(func).parameters.keys())
+        extra_kwargs = {}
+        for key, value in kwargs.items():
+            if key in accepted_kwargs and key not in exclude_kwargs:
+                extra_kwargs[key] = value
+
+        return extra_kwargs
+
+    @torch.no_grad()
+    def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int):
+
+        # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta)
+
+        # Perform scheduler step using the predicted output
+        block_state.latents_dtype = block_state.latents.dtype
+        block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **block_state.scheduler_step_kwargs, return_dict=False)[0]
+
+        if block_state.latents.dtype != block_state.latents_dtype:
+            if torch.backends.mps.is_available():
+                # some platforms (e.g. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                block_state.latents = block_state.latents.to(block_state.latents_dtype)
+
+        return components, block_state
+
+class StableDiffusionXLDenoiseLoopInpaintUpdateLatentsStep(PipelineBlock):
+
+    model_name = "stable-diffusion-xl"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("scheduler", EulerDiscreteScheduler),
+            ComponentSpec("unet", UNet2DConditionModel),
+        ]
+
+    @property
+    def description(self) -> str:
+        return "step within the denoising loop that updates the latents with the scheduler step, and blends the known image region back in for inpainting"
+
+    @property
+    def inputs(self) -> List[Tuple[str, Any]]:
+        return [
+            InputParam("generator"),
+            InputParam("eta", default=0.0),
+        ]
+
+    @property
+    def intermediates_inputs(self) -> List[str]:
+        return [
+            InputParam(
+                "timesteps",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The timesteps to use for the denoising process. Can be generated in set_timesteps step."
+            ),
+            InputParam(
+                "mask",
+                type_hint=Optional[torch.Tensor],
+                description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step."
+            ),
+            InputParam(
+                "noise",
+                type_hint=Optional[torch.Tensor],
+                description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step."
+            ),
+            InputParam(
+                "image_latents",
+                type_hint=Optional[torch.Tensor],
+                description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step."
+            ),
+        ]
+
+    @property
+    def intermediates_outputs(self) -> List[OutputParam]:
+        return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
+
+    @staticmethod
+    def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs):
+
+        accepted_kwargs = set(inspect.signature(func).parameters.keys())
+        extra_kwargs = {}
+        for key, value in kwargs.items():
+            if key in accepted_kwargs and key not in exclude_kwargs:
+                extra_kwargs[key] = value
+
+        return extra_kwargs
+
+    def check_inputs(self, components, block_state):
+        if components.num_channels_unet == 4:
+            if block_state.image_latents is None:
+                raise ValueError(f"image_latents is required for this step {self.__class__.__name__}")
+            if block_state.mask is None:
+                raise ValueError(f"mask is required for this step {self.__class__.__name__}")
+            if block_state.noise is None:
+                raise ValueError(f"noise is required for this step {self.__class__.__name__}")
+
+    @torch.no_grad()
+    def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int):
+
+        self.check_inputs(components, block_state)
+
+        # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta)
+
+        # Perform scheduler step using the predicted output
+        block_state.latents_dtype = block_state.latents.dtype
+        block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **block_state.scheduler_step_kwargs, return_dict=False)[0]
+
+        if block_state.latents.dtype != block_state.latents_dtype:
+            if torch.backends.mps.is_available():
+                # some platforms (e.g. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                block_state.latents = block_state.latents.to(block_state.latents_dtype)
+
+        # adjust latent for inpainting
+        if components.num_channels_unet == 4:
+            block_state.init_latents_proper = block_state.image_latents
+            if i < len(block_state.timesteps) - 1:
+                block_state.noise_timestep = block_state.timesteps[i + 1]
+                block_state.init_latents_proper = components.scheduler.add_noise(
+                    block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep])
+                )
+
+            block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents
+
+        return components, block_state
+
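+# The blend above keeps the known image region and only lets the model's output through
+# where the mask is 1 (a minimal numeric sketch, mask convention as in the step above):
+#
+#     mask = torch.tensor([0.0, 1.0])      # 0 = keep original, 1 = inpaint
+#     init = torch.tensor([10.0, 10.0])    # re-noised image latents
+#     denoised = torch.tensor([3.0, 3.0])  # scheduler output
+#     blended = (1 - mask) * init + mask * denoised
+#     # -> tensor([10., 3.])
+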
+# the loop wrapper that iterates over the timesteps
+class StableDiffusionXLDenoiseLoop(LoopSequentialPipelineBlocks):
+
+    model_name = "stable-diffusion-xl"
+
+    @property
+    def description(self) -> str:
+        return (
+            "Step that iteratively denoises the latents for the text-to-image/image-to-image/inpainting generation process"
+        )
+
+    @property
+    def loop_expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec(
+                "guider",
+                ClassifierFreeGuidance,
+                config=FrozenDict({"guidance_scale": 7.5}),
+                default_creation_method="from_config"),
+            ComponentSpec("scheduler", EulerDiscreteScheduler),
+            ComponentSpec("unet", UNet2DConditionModel),
+        ]
+
+    @property
+    def loop_intermediates_inputs(self) -> List[InputParam]:
+        return [
+            InputParam(
+                "timesteps",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The timesteps to use for the denoising process. Can be generated in set_timesteps step."
+            ),
+            InputParam(
+                "num_inference_steps",
+                required=True,
+                type_hint=int,
+                description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step."
+            ),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        # disable guidance for LCM-style unets (distilled with a guidance embedding)
+        block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False
+        if block_state.disable_guidance:
+            components.guider.disable()
+        else:
+            components.guider.enable()
+
+        block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0)
+
+        with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
+            for i, t in enumerate(block_state.timesteps):
+                components, block_state = self.loop_step(components, block_state, i=i, t=t)
+                if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0):
+                    progress_bar.update()
+
+        self.add_block_state(state, block_state)
+
+        return components, state
+
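+# Sketch of what `self.loop_step` does on each timestep (assumed wiring, inferred from the
+# sub-blocks' `__call__(components, block_state, i, t)` signatures and the call above):
+# each registered block runs in order on the shared block_state:
+#
+#     def loop_step(self, components, block_state, i, t):
+#         for block in self.blocks.values():
+#             components, block_state = block(components, block_state, i=i, t=t)
+#         return components, block_state
+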
+# composed denoise loops: each variant picks a (prepare_latents, denoiser, update_latents) triplet
+class StableDiffusionXLDenoiseStep(StableDiffusionXLDenoiseLoop):
+    block_classes = [StableDiffusionXLDenoiseLoopLatentsStep, StableDiffusionXLDenoiseLoopDenoiserStep, StableDiffusionXLDenoiseLoopUpdateLatentsStep]
+    block_names = ["prepare_latents", "denoiser", "update_latents"]
+
+class StableDiffusionXLControlNetDenoiseStep(StableDiffusionXLDenoiseLoop):
+    block_classes = [StableDiffusionXLDenoiseLoopLatentsStep, StableDiffusionXLDenoiseLoopControlNetDenoiserStep, StableDiffusionXLDenoiseLoopUpdateLatentsStep]
+    block_names = ["prepare_latents", "denoiser", "update_latents"]
+
+class StableDiffusionXLInpaintDenoiseStep(StableDiffusionXLDenoiseLoop):
+    block_classes = [StableDiffusionXLDenoiseLoopInpaintLatentsStep, StableDiffusionXLDenoiseLoopDenoiserStep, StableDiffusionXLDenoiseLoopInpaintUpdateLatentsStep]
+    block_names = ["prepare_latents", "denoiser", "update_latents"]
+
+class StableDiffusionXLInpaintControlNetDenoiseStep(StableDiffusionXLDenoiseLoop):
+    block_classes = [StableDiffusionXLDenoiseLoopInpaintLatentsStep, StableDiffusionXLDenoiseLoopControlNetDenoiserStep, StableDiffusionXLDenoiseLoopInpaintUpdateLatentsStep]
+    block_names = ["prepare_latents", "denoiser", "update_latents"]
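+# Usage sketch (hypothetical; assumes `loader` is a StableDiffusionXLModularLoader with
+# unet/scheduler/guider loaded, and `state` is a PipelineState already populated by the
+# upstream blocks: prompt embeddings, timesteps, initial latents, etc.):
+#
+#     denoise_step = StableDiffusionXLDenoiseStep()
+#     components, state = denoise_step(loader, state)
+#     # the denoised latents now live in the pipeline state under "latents"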