Getting Started with Modular Diffusers: A Comprehensive Overview
With Modular Diffusers, we introduce a unified pipeline system that simplifies how you work with diffusion models. Instead of creating separate pipelines for each task, Modular Diffusers lets you:
Write Only What's New: You won't need to rewrite the entire pipeline from scratch. You can create pipeline blocks just for your new workflow's unique aspects and reuse existing blocks for existing functionalities.
Assemble Like LEGO®: You can mix and match blocks in flexible ways. This allows you to write dedicated blocks for specific workflows, and then assemble different blocks into a pipeline that can be used more conveniently for multiple workflows.
In this guide, we will focus on how to build pipelines this way using blocks we officially support in 🧨 Diffusers! Writing your own pipeline blocks, and how they work under the hood, is covered in more detail in a separate guide. For advanced users who want to build complete workflows from scratch, we provide an end-to-end example in the Developer Guide that covers everything from writing custom pipeline blocks to deploying your workflow as a UI node.
Let's get started! The Modular Diffusers Framework consists of three main components:
- ModularPipelineBlocks
- PipelineState & BlockState
- ModularPipeline
ModularPipelineBlocks
Pipeline blocks are the fundamental building blocks of the Modular Diffusers system. All pipeline blocks inherit from the base class ModularPipelineBlocks, including:
- PipelineBlock: The most granular block; you define the computation logic.
- SequentialPipelineBlocks: A multi-block composed of multiple blocks that run sequentially, passing outputs as inputs to the next block.
- LoopSequentialPipelineBlocks: A special type of multi-block that forms loops.
- AutoPipelineBlocks: A multi-block composed of multiple blocks that are selected at runtime based on the inputs.
All blocks have a consistent interface defining their requirements (components, configs, inputs, outputs) and computation logic. They can be used standalone or combined into larger blocks. Blocks are designed to be assembled into workflows for tasks such as image generation, video creation, and inpainting.
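To make three of the block types above concrete, here is a toy, dependency-free sketch in plain Python. `ToyBlock`, `ToySequentialBlocks`, and `ToyAutoBlocks` are hypothetical classes written for illustration only; they are not the Diffusers API, and the real classes also track components, configs, and intermediate inputs/outputs.

```python
from typing import Callable, Dict, List, Tuple

State = Dict[str, object]


class ToyBlock:
    """Hypothetical single block: declares its inputs and holds the computation."""

    def __init__(self, inputs: List[str], fn: Callable[..., State]):
        self.inputs = inputs
        self.fn = fn

    def __call__(self, state: State) -> State:
        # Read only the declared inputs, run the computation, write results back.
        state.update(self.fn(**{key: state.get(key) for key in self.inputs}))
        return state


class ToySequentialBlocks:
    """Hypothetical multi-block: runs every sub-block in order on one shared state."""

    def __init__(self, sub_blocks: List[Callable[[State], State]]):
        self.sub_blocks = sub_blocks

    def __call__(self, state: State) -> State:
        for block in self.sub_blocks:
            state = block(state)
        return state


class ToyAutoBlocks:
    """Hypothetical multi-block: runs the first sub-block whose trigger input is set."""

    def __init__(self, options: List[Tuple[str, Callable[[State], State]]]):
        self.options = options  # (trigger input name, block) in priority order

    def __call__(self, state: State) -> State:
        for trigger, block in self.options:
            if state.get(trigger) is not None:
                return block(state)
        return state  # no trigger provided: the step is skipped


encode = ToyBlock(["prompt"], lambda prompt: {"embeds": [len(w) for w in prompt.split()]})
txt2img = ToyBlock(["embeds"], lambda embeds: {"images": f"txt2img:{sum(embeds)}"})
img2img = ToyBlock(["embeds", "image"], lambda embeds, image: {"images": f"img2img:{image}"})

pipe = ToySequentialBlocks([
    encode,
    ToyAutoBlocks([("image", img2img), ("embeds", txt2img)]),
])
print(pipe({"prompt": "a cat"})["images"])                      # txt2img:4
print(pipe({"prompt": "a cat", "image": "dog.png"})["images"])  # img2img:dog.png
```

The same `encode` block is reused unchanged in both runs; only the auto block's choice differs, which is the "write only what's new" idea in miniature.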
It is easy to use a ModularPipelineBlocks officially supported in 🧨 Diffusers:
from diffusers.modular_pipelines.stable_diffusion_xl import StableDiffusionXLTextEncoderStep
text_encoder_block = StableDiffusionXLTextEncoderStep()
This is a single PipelineBlock. You'll see that this text encoder block uses two text encoders and two tokenizers, as well as a guider component. It takes user inputs such as prompt and negative_prompt, and returns text embedding outputs such as prompt_embeds and negative_prompt_embeds.
>>> text_encoder_block
StableDiffusionXLTextEncoderStep(
Class: PipelineBlock
Description: Text Encoder step that generate text_embeddings to guide the image generation
Components:
text_encoder (`CLIPTextModel`)
text_encoder_2 (`CLIPTextModelWithProjection`)
tokenizer (`CLIPTokenizer`)
tokenizer_2 (`CLIPTokenizer`)
guider (`ClassifierFreeGuidance`)
Configs:
force_zeros_for_empty_prompt (default: True)
Inputs:
prompt=None, prompt_2=None, negative_prompt=None, negative_prompt_2=None, cross_attention_kwargs=None, clip_skip=None
Intermediates:
- outputs: prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
)
More commonly, you can create a SequentialPipelineBlocks using a block classes preset from 🧨 Diffusers.
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS
t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
This creates a SequentialPipelineBlocks, which is a multi-block composed of other blocks. Unlike single blocks (like the text_encoder_block we saw earlier), this multi-block has a sub_blocks attribute that contains the sub-blocks (text_encoder, input, set_timesteps, prepare_latents, prepare_add_cond, denoise, decode). Its requirements for components, inputs, and intermediate inputs are combined from the blocks that compose it. At runtime, it executes its sub-blocks sequentially and passes the pipeline state from one block to the next.
>>> t2i_blocks
SequentialPipelineBlocks(
Class: ModularPipelineBlocks
Description:
Components:
text_encoder (`CLIPTextModel`)
text_encoder_2 (`CLIPTextModelWithProjection`)
tokenizer (`CLIPTokenizer`)
tokenizer_2 (`CLIPTokenizer`)
guider (`ClassifierFreeGuidance`)
scheduler (`EulerDiscreteScheduler`)
unet (`UNet2DConditionModel`)
vae (`AutoencoderKL`)
image_processor (`VaeImageProcessor`)
Configs:
force_zeros_for_empty_prompt (default: True)
Sub-Blocks:
[0] text_encoder (StableDiffusionXLTextEncoderStep)
Description: Text Encoder step that generate text_embeddings to guide the image generation
[1] input (StableDiffusionXLInputStep)
Description: Input processing step that:
1. Determines `batch_size` and `dtype` based on `prompt_embeds`
2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`
All input tensors are expected to have either batch_size=1 or match the batch_size
of prompt_embeds. The tensors will be duplicated across the batch dimension to
have a final batch_size of batch_size * num_images_per_prompt.
[2] set_timesteps (StableDiffusionXLSetTimestepsStep)
Description: Step that sets the scheduler's timesteps for inference
[3] prepare_latents (StableDiffusionXLPrepareLatentsStep)
Description: Prepare latents step that prepares the latents for the text-to-image generation process
[4] prepare_add_cond (StableDiffusionXLPrepareAdditionalConditioningStep)
Description: Step that prepares the additional conditioning for the text-to-image generation process
[5] denoise (StableDiffusionXLDenoiseStep)
Description: Denoise step that iteratively denoise the latents.
Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method
At each iteration, it runs blocks defined in `sub_blocks` sequencially:
- `StableDiffusionXLLoopBeforeDenoiser`
- `StableDiffusionXLLoopDenoiser`
- `StableDiffusionXLLoopAfterDenoiser`
This block supports both text2img and img2img tasks.
[6] decode (StableDiffusionXLDecodeStep)
Description: Step that decodes the denoised latents into images
)
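The denoise sub-block above is a loop block: at each timestep it runs the same three sub-blocks (before-denoiser, denoiser, after-denoiser) in order. A toy sketch of that control flow, assuming a plain dict as the state and simple arithmetic standing in for the UNet and the scheduler; all function names here are hypothetical, not the Diffusers implementation.

```python
from typing import Callable, Dict, List

State = Dict[str, float]
LoopStep = Callable[[State, int], State]


def before_denoiser(state: State, t: int) -> State:
    # Toy stand-in for preparing the model input at timestep t.
    state["model_input"] = state["latents"]
    return state


def denoiser(state: State, t: int) -> State:
    # Toy stand-in for the UNet: pretend the predicted noise is 10% of the input.
    state["noise_pred"] = 0.1 * state["model_input"]
    return state


def after_denoiser(state: State, t: int) -> State:
    # Toy stand-in for scheduler.step(): remove the predicted noise.
    state["latents"] = state["latents"] - state["noise_pred"]
    return state


def run_loop(state: State, timesteps: List[int], sub_blocks: List[LoopStep]) -> State:
    for t in timesteps:
        for block in sub_blocks:  # same sub-blocks, every iteration, in order
            state = block(state, t)
    return state


final = run_loop({"latents": 100.0}, [999, 500, 1],
                 [before_denoiser, denoiser, after_denoiser])
print(round(final["latents"], 3))  # 72.9
```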
The block classes preset (TEXT2IMAGE_BLOCKS) we used is just a dictionary that maps names to ModularPipelineBlocks classes:
>>> TEXT2IMAGE_BLOCKS
InsertableDict([
0: ('text_encoder', <class 'diffusers.modular_pipelines.stable_diffusion_xl.encoders.StableDiffusionXLTextEncoderStep'>),
1: ('input', <class 'diffusers.modular_pipelines.stable_diffusion_xl.before_denoise.StableDiffusionXLInputStep'>),
2: ('set_timesteps', <class 'diffusers.modular_pipelines.stable_diffusion_xl.before_denoise.StableDiffusionXLSetTimestepsStep'>),
3: ('prepare_latents', <class 'diffusers.modular_pipelines.stable_diffusion_xl.before_denoise.StableDiffusionXLPrepareLatentsStep'>),
4: ('prepare_add_cond', <class 'diffusers.modular_pipelines.stable_diffusion_xl.before_denoise.StableDiffusionXLPrepareAdditionalConditioningStep'>),
5: ('denoise', <class 'diffusers.modular_pipelines.stable_diffusion_xl.denoise.StableDiffusionXLDenoiseLoop'>),
6: ('decode', <class 'diffusers.modular_pipelines.stable_diffusion_xl.decoders.StableDiffusionXLDecodeStep'>)
])
When we create a SequentialPipelineBlocks from this preset, it instantiates each block class into actual block objects. Its sub_blocks attribute now contains these instantiated objects:
>>> t2i_blocks.sub_blocks
InsertableDict([
0: ('text_encoder', <obj 'diffusers.modular_pipelines.stable_diffusion_xl.encoders.StableDiffusionXLTextEncoderStep'>),
1: ('input', <obj 'diffusers.modular_pipelines.stable_diffusion_xl.before_denoise.StableDiffusionXLInputStep'>),
2: ('set_timesteps', <obj 'diffusers.modular_pipelines.stable_diffusion_xl.before_denoise.StableDiffusionXLSetTimestepsStep'>),
3: ('prepare_latents', <obj 'diffusers.modular_pipelines.stable_diffusion_xl.before_denoise.StableDiffusionXLPrepareLatentsStep'>),
4: ('prepare_add_cond', <obj 'diffusers.modular_pipelines.stable_diffusion_xl.before_denoise.StableDiffusionXLPrepareAdditionalConditioningStep'>),
5: ('denoise', <obj 'diffusers.modular_pipelines.stable_diffusion_xl.denoise.StableDiffusionXLDenoiseStep'>),
6: ('decode', <obj 'diffusers.modular_pipelines.stable_diffusion_xl.decoders.StableDiffusionXLDecodeStep'>)
])
Note that both the block classes preset and the sub_blocks attribute are InsertableDict objects. This is a custom dictionary that extends OrderedDict with the ability to insert items at specific positions. You can perform all standard dictionary operations (get, set, delete) plus insert items at any index, which is particularly useful for reordering or inserting blocks in the middle of a pipeline.
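A toy sketch of the idea behind InsertableDict, using only the standard library. `ToyInsertableDict` is a hypothetical class, not the Diffusers implementation, but it supports the same kinds of operations shown next (insert at an index, replace by key, pop).

```python
from collections import OrderedDict


class ToyInsertableDict(OrderedDict):
    """Hypothetical stand-in: an OrderedDict that can also insert at a position."""

    def insert(self, key, value, index):
        # Drop any existing entry with this key, then rebuild the order
        # with (key, value) placed at `index`.
        items = [(k, v) for k, v in self.items() if k != key]
        items.insert(index, (key, value))
        self.clear()
        self.update(items)


blocks = ToyInsertableDict([("text_encoder", "TextEncoderStep"),
                            ("denoise", "DenoiseStep"),
                            ("decode", "DecodeStep")])
blocks.insert("ip_adapter", "IPAdapterStep", 0)  # insert at the front
blocks["denoise"] = "CustomDenoiseStep"          # swapping keeps the position
removed = blocks.pop("text_encoder")             # removing returns the value
print(list(blocks.items()))
```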
Add a block:
# Add a block class to the preset
BLOCKS.insert("block_name", BlockClass, index)
# Add a block instance to the `sub_blocks` attribute
t2i_blocks.sub_blocks.insert("block_name", block_instance, index)
Remove a block:
# remove a block class from preset
BLOCKS.pop("text_encoder")
# split out a block instance on its own
text_encoder_block = t2i_blocks.sub_blocks.pop("text_encoder")
Swap block:
# Replace block class in preset
BLOCKS["prepare_latents"] = CustomPrepareLatents
# Replace in sub_blocks attribute
t2i_blocks.sub_blocks["prepare_latents"] = CustomPrepareLatents()
This means you can mix-and-match blocks in very flexible ways. Let's see some real examples:
Example 1: Adding IP-Adapter to the Block Classes Preset
Let's make a new block classes preset by inserting IP-Adapter at index 0 (before the text_encoder block), and create a text-to-image pipeline with IP-Adapter support:
from diffusers.modular_pipelines.stable_diffusion_xl import StableDiffusionXLAutoIPAdapterStep
CUSTOM_BLOCKS = TEXT2IMAGE_BLOCKS.copy()
CUSTOM_BLOCKS.insert("ip_adapter", StableDiffusionXLAutoIPAdapterStep, 0)
custom_blocks = SequentialPipelineBlocks.from_blocks_dict(CUSTOM_BLOCKS)
Example 2: Extracting a block from a multi-block
You can extract a block instance from the multi-block to use it independently. A common pattern is to use text_encoder to process prompts once, then reuse the text embedding outputs to generate multiple images with different settings (schedulers, seeds, inference steps). We can do this by simply extracting the text_encoder block from the pipeline:
# this gives you StableDiffusionXLTextEncoderStep()
>>> text_encoder_blocks = t2i_blocks.sub_blocks.pop("text_encoder")
>>> text_encoder_blocks
The multi-block now has fewer components and no longer contains the text_encoder block. If you check its docstring, t2i_blocks.doc, you will see that it no longer accepts prompt as input; you will need to pass the embeddings instead.
>>> t2i_blocks
SequentialPipelineBlocks(
Class: ModularPipelineBlocks
Description:
Components:
scheduler (`EulerDiscreteScheduler`)
guider (`ClassifierFreeGuidance`)
unet (`UNet2DConditionModel`)
vae (`AutoencoderKL`)
image_processor (`VaeImageProcessor`)
Blocks:
[0] input (StableDiffusionXLInputStep)
Description: Input processing step that:
1. Determines `batch_size` and `dtype` based on `prompt_embeds`
2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`
All input tensors are expected to have either batch_size=1 or match the batch_size
of prompt_embeds. The tensors will be duplicated across the batch dimension to
have a final batch_size of batch_size * num_images_per_prompt.
[1] set_timesteps (StableDiffusionXLSetTimestepsStep)
Description: Step that sets the scheduler's timesteps for inference
[2] prepare_latents (StableDiffusionXLPrepareLatentsStep)
Description: Prepare latents step that prepares the latents for the text-to-image generation process
[3] prepare_add_cond (StableDiffusionXLPrepareAdditionalConditioningStep)
Description: Step that prepares the additional conditioning for the text-to-image generation process
[4] denoise (StableDiffusionXLDenoiseLoop)
Description: Denoise step that iteratively denoise the latents.
Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method
At each iteration, it runs blocks defined in `blocks` sequencially:
- `StableDiffusionXLLoopBeforeDenoiser`
- `StableDiffusionXLLoopDenoiser`
- `StableDiffusionXLLoopAfterDenoiser`
[5] decode (StableDiffusionXLDecodeStep)
Description: Step that decodes the denoised latents into images
)
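Why did prompt disappear from the inputs? A sequential block only asks the user for inputs that no earlier sub-block produces. A toy sketch of that bookkeeping, assuming each sub-block is described only by lists of input and output names; the helper and the declarations are hypothetical, and the real Diffusers logic also tracks components, configs, and optional inputs.

```python
from typing import Dict, List, Set, Tuple

# Hypothetical (inputs, outputs) declarations for a few sub-blocks.
Spec = Tuple[List[str], List[str]]


def required_user_inputs(sub_blocks: Dict[str, Spec]) -> List[str]:
    produced: Set[str] = set()
    required: List[str] = []
    for inputs, outputs in sub_blocks.values():  # dicts keep insertion order
        for name in inputs:
            if name not in produced and name not in required:
                required.append(name)
        produced.update(outputs)
    return required


sub_blocks = {
    "text_encoder": (["prompt"], ["prompt_embeds"]),
    "input": (["prompt_embeds", "num_images_per_prompt"], ["batch_size"]),
    "denoise": (["prompt_embeds", "batch_size"], ["latents"]),
    "decode": (["latents"], ["images"]),
}
print(required_user_inputs(sub_blocks))  # ['prompt', 'num_images_per_prompt']

sub_blocks.pop("text_encoder")
print(required_user_inputs(sub_blocks))  # ['prompt_embeds', 'num_images_per_prompt']
```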
💡 You can find all the block classes presets we support for each model in ALL_BLOCKS.
# For Stable Diffusion XL
from diffusers.modular_pipelines.stable_diffusion_xl import ALL_BLOCKS
ALL_BLOCKS
# For other models...
from diffusers.modular_pipelines.<model_name> import ALL_BLOCKS
Each model provides a dictionary that maps all supported tasks/techniques to their corresponding block classes presets. For SDXL, it is:
ALL_BLOCKS = {
"text2img": TEXT2IMAGE_BLOCKS,
"img2img": IMAGE2IMAGE_BLOCKS,
"inpaint": INPAINT_BLOCKS,
"controlnet": CONTROLNET_BLOCKS,
"ip_adapter": IP_ADAPTER_BLOCKS,
"auto": AUTO_BLOCKS,
}
We will not go over how to write your own ModularPipelineBlocks but you can learn more about it here.
This covers the essentials of pipeline blocks! You may have noticed that we haven't discussed how to load or run pipeline blocks - that's because pipeline blocks are not runnable by themselves. They are essentially "definitions" - they define the specifications and computational steps for a pipeline, but they do not contain any model states. To actually run them, you need to convert them into a ModularPipeline object.
PipelineState & BlockState
PipelineState and BlockState manage dataflow between pipeline blocks. PipelineState acts as the global state container that ModularPipelineBlocks operate on - each block gets a local view (BlockState) of the relevant variables it needs from PipelineState, performs its operations, and then updates PipelineState with any changes.
You typically don't need to manually create or manage these state objects. The ModularPipeline automatically creates and manages them for you. However, understanding their roles is important for developing custom pipeline blocks.
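A toy sketch of the global-state/local-view pattern described above. `ToyPipelineState`, `ToyBlockState`, and the helper functions are hypothetical stand-ins written for illustration; they show only the dataflow (read the variables a block declares, modify them, write them back), not the real classes' API.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class ToyPipelineState:
    """Hypothetical global container shared by all blocks."""
    values: Dict[str, Any] = field(default_factory=dict)


@dataclass
class ToyBlockState:
    """Hypothetical local view holding only the variables a block asked for."""
    values: Dict[str, Any]


def get_block_state(state: ToyPipelineState, names: List[str]) -> ToyBlockState:
    return ToyBlockState({n: state.values.get(n) for n in names})


def set_block_state(state: ToyPipelineState, block_state: ToyBlockState) -> None:
    # Write the block's modified or new variables back to the global state.
    state.values.update(block_state.values)


def toy_input_step(state: ToyPipelineState) -> ToyPipelineState:
    bs = get_block_state(state, ["prompt_embeds", "num_images_per_prompt"])
    bs.values["batch_size"] = len(bs.values["prompt_embeds"]) * bs.values["num_images_per_prompt"]
    set_block_state(state, bs)
    return state


state = ToyPipelineState({"prompt_embeds": [[0.1], [0.2]], "num_images_per_prompt": 2, "seed": 0})
state = toy_input_step(state)
print(state.values["batch_size"])  # 4
```

The block never sees seed: it only reads the two variables it declared, and everything else in the global state passes through untouched.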
ModularPipeline
ModularPipeline is the main interface to create and execute pipelines in the Modular Diffusers system.
Modular Repo
ModularPipeline only works with modular repositories. You can find an example modular repo here.
Instead of the model_index.json that DiffusionPipeline uses to configure component loading, modular repositories work with modular_model_index.json. Let's walk through the difference.
In standard model_index.json, each component entry is a (library, class) tuple:
"text_encoder": [
"transformers",
"CLIPTextModel"
],
In modular_model_index.json, each component entry contains 3 elements: (library, class, loading_specs {})
- library and class: Information about the actual component loaded in the pipeline at the time of saving (will be null if not loaded).
- loading_specs: A dictionary containing all information required to load this component, including repo, revision, subfolder, variant, and type_hint.
"text_encoder": [
null, # library (same as model_index.json)
null, # class (same as model_index.json)
{ # loading specs map (unique to modular_model_index.json)
"repo": "stabilityai/stable-diffusion-xl-base-1.0", # can be a different repo
"revision": null,
"subfolder": "text_encoder",
"type_hint": [ # (library, class) for the expected component class
"transformers",
"CLIPTextModel"
],
"variant": null
}
],
Some components may not have a repo field. These cannot be loaded from a repository and can only be created with the default config from the pipeline:
"image_processor": [
"diffusers",
"VaeImageProcessor",
{
"type_hint": [
"diffusers",
"VaeImageProcessor"
]
}
],
Unlike standard repositories where components must be in subfolders within the same repo, modular repositories can fetch components from different repositories based on the loading_specs dictionary. e.g. the text_encoder component will be fetched from the "text_encoder" folder in stabilityai/stable-diffusion-xl-base-1.0 while other components come from different repositories.
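Because each entry is a plain [library, class, loading_specs] triple, you can inspect a modular_model_index.json with the standard json module. The sketch below uses a small hand-written fragment (the `#` comments in the snippets above are not valid JSON, so they are dropped) and reports where each component would be fetched from. The helper name `describe` is hypothetical, and the sketch simply skips non-component keys such as _class_name.

```python
import json

INDEX = """
{
  "_class_name": "StableDiffusionXLModularLoader",
  "text_encoder": [null, null, {"repo": "stabilityai/stable-diffusion-xl-base-1.0",
                                "revision": null, "subfolder": "text_encoder",
                                "type_hint": ["transformers", "CLIPTextModel"],
                                "variant": null}],
  "vae": [null, null, {"repo": "madebyollin/sdxl-vae-fp16-fix", "revision": null,
                       "subfolder": null, "type_hint": ["diffusers", "AutoencoderKL"],
                       "variant": null}],
  "image_processor": ["diffusers", "VaeImageProcessor",
                      {"type_hint": ["diffusers", "VaeImageProcessor"]}]
}
"""


def describe(index: dict) -> dict:
    """Map each component name to the repo/subfolder it would be loaded from."""
    sources = {}
    for name, entry in index.items():
        if name.startswith("_") or not isinstance(entry, list):
            continue  # skip metadata such as _class_name
        _, _, specs = entry  # [library, class, loading_specs]
        if specs.get("repo") is None:
            sources[name] = "default config"
        else:
            subfolder = specs.get("subfolder")
            sources[name] = specs["repo"] + (f"/{subfolder}" if subfolder else "")
    return sources


for name, source in describe(json.loads(INDEX)).items():
    print(f"{name}: {source}")
```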
Creating a ModularPipeline from ModularPipelineBlocks
Each ModularPipelineBlocks has an init_pipeline method that can initialize a ModularPipeline object based on its component and configuration specifications.
Let's convert our t2i_blocks (which we created earlier) into a runnable ModularPipeline:
# We already have this from earlier
t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
# Now convert it to a ModularPipeline
modular_repo_id = "YiYiXu/modular-loader-t2i"
t2i_pipeline = t2i_blocks.init_pipeline(modular_repo_id)
💡 We recommend using ModularPipeline with the Components Manager by passing a components_manager:
>>> from diffusers.modular_pipelines import ComponentsManager
>>> components = ComponentsManager()
>>> pipeline = t2i_blocks.init_pipeline(modular_repo_id, components_manager=components)
This helps you to:
- Detect and manage duplicated models (warns when trying to register an existing model)
- Easily reuse components across different pipelines
- Apply offloading strategies across multiple pipelines
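To illustrate the first point, here is a toy registry that shares components by name and warns when the same object is registered twice. `ToyComponentsRegistry` is hypothetical, and detecting duplicates by object identity is an assumption made for illustration; it does not describe how ComponentsManager works internally.

```python
import warnings
from typing import Any, Dict


class ToyComponentsRegistry:
    """Hypothetical registry: shares objects by name and warns on duplicates."""

    def __init__(self) -> None:
        self._components: Dict[str, Any] = {}

    def add(self, name: str, component: Any) -> str:
        for existing_name, existing in self._components.items():
            if existing is component:  # same object already registered
                warnings.warn(f"{name!r} is already registered as {existing_name!r}")
                return existing_name
        self._components[name] = component
        return name

    def get(self, name: str) -> Any:
        return self._components[name]


registry = ToyComponentsRegistry()
unet = object()  # stand-in for a loaded model
registry.add("unet", unet)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    key = registry.add("unet_copy", unet)  # same object again: warning, no new entry
print(key, len(caught))  # unet 1
```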
You can read more about Components Manager here
Creating a ModularPipeline with from_pretrained
You can create a ModularPipeline from a Hugging Face Hub repository with the from_pretrained method, as long as it's a modular repo:
# Note: this is not yet supported; support still needs to be added
from diffusers import ModularPipeline
pipeline = ModularPipeline.from_pretrained(repo_id, components_manager=..., collection=...)
Loading custom code is also supported:
from diffusers import ModularPipeline
modular_repo_id = "YiYiXu/modular-diffdiff"
diffdiff_pipeline = ModularPipeline.from_pretrained(modular_repo_id, trust_remote_code=True)
Loading components into a ModularPipeline
Unlike DiffusionPipeline, when you create a ModularPipeline instance (whether using from_pretrained or converting from pipeline blocks), its components aren't loaded automatically. You need to load the model components explicitly, for example with load_default_components:
import torch
# This will load ALL the expected components into the pipeline
t2i_pipeline.load_default_components(torch_dtype=torch.float16)
t2i_pipeline.to("cuda")
All expected components are now loaded into the pipeline. You can also partially load specific components using the names argument. For example, to only load unet and vae:
>>> t2i_pipeline.load_components(names=["unet", "vae"], torch_dtype=torch.float16)
You can inspect the pipeline's loading status through its loader attribute to understand what components are expected to load, which ones are already loaded, how they were loaded, and what loading specs are available. The loader is synced with the modular_model_index.json from the repository you used during init_pipeline() - it takes the loading specs that match the pipeline's component requirements.
For example, if your pipeline needs a text_encoder component, the loader will include the loading spec for text_encoder from the modular repo. If the pipeline doesn't need a component (like controlnet in a basic text-to-image pipeline), that component won't appear in the loader even if it exists in the modular repo.
The loader has the same structure as modular_model_index.json - each component entry contains the (library, class, loading_specs) format. You'll need to understand that structure to properly read the loading status below.
💡 How to read the loader:
- library and class fields: Show info about actually loaded components. If null, the component is not loaded yet.
- loading_specs: If it does not have a repo field, or if repo is null, the component cannot be loaded from a repository and can only be created with default config by the pipeline.
Let's inspect t2i_pipeline.loader. All the components the pipeline expects to load are listed as entries. The guider and image_processor components were created using default config: their library and class fields are populated, meaning they are initialized, but their loading spec dict has no loading-related fields such as repo. The vae and unet components were loaded using their respective loading specs. The rest of the components (scheduler, text_encoder, text_encoder_2, tokenizer, tokenizer_2) are not loaded yet (their library and class fields are null), but you can examine their loading specs to see where they would be loaded from when you call load_components().
>>> t2i_pipeline.loader
StableDiffusionXLModularLoader {
"_class_name": "StableDiffusionXLModularLoader",
"_diffusers_version": "0.34.0.dev0",
"force_zeros_for_empty_prompt": true,
"guider": [
"diffusers",
"ClassifierFreeGuidance",
{
"type_hint": [
"diffusers",
"ClassifierFreeGuidance"
]
}
],
"image_processor": [
"diffusers",
"VaeImageProcessor",
{
"type_hint": [
"diffusers",
"VaeImageProcessor"
]
}
],
"scheduler": [
null,
null,
{
"repo": "stabilityai/stable-diffusion-xl-base-1.0",
"revision": null,
"subfolder": "scheduler",
"type_hint": [
"diffusers",
"EulerDiscreteScheduler"
],
"variant": null
}
],
"text_encoder": [
null,
null,
{
"repo": "stabilityai/stable-diffusion-xl-base-1.0",
"revision": null,
"subfolder": "text_encoder",
"type_hint": [
"transformers",
"CLIPTextModel"
],
"variant": null
}
],
"text_encoder_2": [
null,
null,
{
"repo": "stabilityai/stable-diffusion-xl-base-1.0",
"revision": null,
"subfolder": "text_encoder_2",
"type_hint": [
"transformers",
"CLIPTextModelWithProjection"
],
"variant": null
}
],
"tokenizer": [
null,
null,
{
"repo": "stabilityai/stable-diffusion-xl-base-1.0",
"revision": null,
"subfolder": "tokenizer",
"type_hint": [
"transformers",
"CLIPTokenizer"
],
"variant": null
}
],
"tokenizer_2": [
null,
null,
{
"repo": "stabilityai/stable-diffusion-xl-base-1.0",
"revision": null,
"subfolder": "tokenizer_2",
"type_hint": [
"transformers",
"CLIPTokenizer"
],
"variant": null
}
],
"unet": [
"diffusers",
"UNet2DConditionModel",
{
"repo": "RunDiffusion/Juggernaut-XL-v9",
"revision": null,
"subfolder": "unet",
"type_hint": [
"diffusers",
"UNet2DConditionModel"
],
"variant": "fp16"
}
],
"vae": [
"diffusers",
"AutoencoderKL",
{
"repo": "madebyollin/sdxl-vae-fp16-fix",
"revision": null,
"subfolder": null,
"type_hint": [
"diffusers",
"AutoencoderKL"
],
"variant": null
}
]
}
There are also a few properties that can provide a quick summary of component loading status:
# All components expected by the pipeline
>>> t2i_pipeline.loader.component_names
['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'guider', 'scheduler', 'unet', 'vae', 'image_processor']
# Components that are not loaded yet (will be loaded with from_pretrained)
>>> t2i_pipeline.loader.null_component_names
['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'scheduler']
# Components that will be loaded from pretrained models
>>> t2i_pipeline.loader.pretrained_component_names
['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'scheduler', 'unet', 'vae']
# Components that are created with default config (no repo needed)
>>> t2i_pipeline.loader.config_component_names
['guider', 'image_processor']
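These properties can be reasoned about directly from the (library, class, loading_specs) triples. A toy re-derivation in plain Python, assuming "not loaded" means library and class are both null and "pretrained" means the spec has a non-null repo; the real loader may apply additional rules, and the `summarize` helper is hypothetical. The entries are a trimmed copy of the loader output shown earlier.

```python
from typing import Dict, List

# Trimmed copy of the loader entries shown above: name -> [library, class, specs]
ENTRIES: Dict[str, list] = {
    "text_encoder": [None, None, {"repo": "stabilityai/stable-diffusion-xl-base-1.0"}],
    "guider": ["diffusers", "ClassifierFreeGuidance",
               {"type_hint": ["diffusers", "ClassifierFreeGuidance"]}],
    "scheduler": [None, None, {"repo": "stabilityai/stable-diffusion-xl-base-1.0"}],
    "unet": ["diffusers", "UNet2DConditionModel", {"repo": "RunDiffusion/Juggernaut-XL-v9"}],
    "vae": ["diffusers", "AutoencoderKL", {"repo": "madebyollin/sdxl-vae-fp16-fix"}],
    "image_processor": ["diffusers", "VaeImageProcessor",
                        {"type_hint": ["diffusers", "VaeImageProcessor"]}],
}


def summarize(entries: Dict[str, list]) -> Dict[str, List[str]]:
    return {
        "component_names": list(entries),
        # library/class still null: not loaded yet
        "null_component_names": [n for n, (lib, cls, _) in entries.items()
                                 if lib is None and cls is None],
        # a repo in the spec: loadable from a pretrained checkpoint
        "pretrained_component_names": [n for n, (_, _, s) in entries.items() if s.get("repo")],
        # no repo: created from default config
        "config_component_names": [n for n, (_, _, s) in entries.items() if not s.get("repo")],
    }


for key, names in summarize(ENTRIES).items():
    print(key, names)
```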
Modifying Loading Specs
When you call pipeline.load_components(names=) or pipeline.load_default_components(), it uses the loading specs from the modular repository's modular_model_index.json. The pipeline's loader attribute is synced with these specs - it shows you exactly what will be loaded and from where.
You can change where components are loaded from by default by modifying the modular_model_index.json in the repository. You can change any field in the loading specs: repo, subfolder, variant, revision, etc.
# Original spec in modular_model_index.json
"unet": [
null, null,
{
"repo": "stabilityai/stable-diffusion-xl-base-1.0",
"subfolder": "unet",
"variant": "fp16"
}
]
# Modified spec - changed repo only (subfolder and variant are unchanged)
"unet": [
null, null,
{
"repo": "RunDiffusion/Juggernaut-XL-v9",
"subfolder": "unet",
"variant": "fp16"
}
]
When you call pipeline.load_components(...)/pipeline.load_default_components(), it will now load from the new repository by default.
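If you maintain your own modular repo, the edit above can be scripted. A minimal sketch that rewrites only the repo field of one entry with the json module, assuming a local copy of modular_model_index.json (written to a temporary directory here) that you upload back to the repository yourself; `set_component_repo` is a hypothetical helper, not a Diffusers function.

```python
import json
import tempfile
from pathlib import Path


def set_component_repo(index_path: Path, component: str, new_repo: str) -> dict:
    """Point one component's loading spec at a different repository."""
    index = json.loads(index_path.read_text())
    specs = index[component][2]  # [library, class, loading_specs]
    specs["repo"] = new_repo     # subfolder and variant are left unchanged
    index_path.write_text(json.dumps(index, indent=2))
    return index


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "modular_model_index.json"
    path.write_text(json.dumps({
        "unet": [None, None, {"repo": "stabilityai/stable-diffusion-xl-base-1.0",
                              "subfolder": "unet", "variant": "fp16"}]
    }))
    updated = set_component_repo(path, "unet", "RunDiffusion/Juggernaut-XL-v9")
    print(updated["unet"][2])
```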
Updating components in a ModularPipeline
Similar to DiffusionPipeline, you can load components separately to replace the default ones in the pipeline. In the Modular Diffusers system, however, you need to use ComponentSpec to load or create them.
ComponentSpec defines how to create or load components and can actually create them using its create() method (for ConfigMixin objects) or load() method (wrapper around from_pretrained()). When a component is loaded with a ComponentSpec, it gets tagged with a unique ID that encodes its creation parameters, allowing you to always extract the original specification using ComponentSpec.from_component(). In Modular Diffusers, all pretrained models should be loaded using ComponentSpec objects.
So instead of
from diffusers import UNet2DConditionModel
import torch
unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", variant="fp16", torch_dtype=torch.float16)
You should do
import torch
from diffusers import ComponentSpec, UNet2DConditionModel
unet_spec = ComponentSpec(name="unet", type_hint=UNet2DConditionModel, repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", variant="fp16")
unet2 = unet_spec.load(torch_dtype=torch.float16)
The key difference is that the second unet (the one we load with ComponentSpec) retains its loading specs, so you can extract and recreate it:
# extract the spec; you can call spec.load() to recreate the component
>>> spec = ComponentSpec.from_component("unet", unet2)
>>> spec
ComponentSpec(name='unet', type_hint=<class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>, description=None, config=None, repo='stabilityai/stable-diffusion-xl-base-1.0', subfolder='unet', variant='fp16', revision=None, default_creation_method='from_pretrained')
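The round trip works because the loaded object remembers how it was created. A toy sketch of that idea, assuming a hypothetical `ToySpec` dataclass that stores its own fields on the created object under a made-up key; this is not how Diffusers implements ComponentSpec internally, and the "model" is a plain dict rather than a real UNet.

```python
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass(frozen=True)
class ToySpec:
    name: str
    repo: str
    subfolder: Optional[str] = None
    variant: Optional[str] = None

    def load(self) -> dict:
        # Stand-in for from_pretrained(): build a fake component and tag it
        # with the parameters used to create it.
        component = {"weights": f"{self.repo}/{self.subfolder or ''}"}
        component["_toy_spec"] = asdict(self)
        return component

    @classmethod
    def from_component(cls, component: dict) -> "ToySpec":
        # Recover the original spec from the tag.
        return cls(**component["_toy_spec"])


spec = ToySpec("unet", "stabilityai/stable-diffusion-xl-base-1.0", "unet", "fp16")
unet = spec.load()
recovered = ToySpec.from_component(unet)
print(recovered == spec)  # True
```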
To replace the unet in the pipeline
t2i_pipeline.update_components(unet=unet2)
Not only is the unet component swapped, but its loading specs are also updated from "RunDiffusion/Juggernaut-XL-v9" to "stabilityai/stable-diffusion-xl-base-1.0". This means that if you save the pipeline now and load it back with from_pretrained, the new pipeline will by default load the SDXL original unet.
>>> t2i_pipeline.loader
StableDiffusionXLModularLoader {
...
"unet": [
"diffusers",
"UNet2DConditionModel",
{
"repo": "stabilityai/stable-diffusion-xl-base-1.0",
"revision": null,
"subfolder": "unet",
"type_hint": [
"diffusers",
"UNet2DConditionModel"
],
"variant": "fp16"
}
],
...
}
Running a ModularPipeline
The API to run the ModularPipeline is very similar to how you would run a regular DiffusionPipeline:
>>> image = pipeline(prompt="a cat", num_inference_steps=15, output="images")[0]
There are a few key differences though:
- You can also pass a PipelineState object directly to the pipeline instead of individual arguments.
- If you do not specify the output argument, it returns the PipelineState object.
- You can pass a list as output, e.g. pipeline(... output=["images", "latents"]) will return a dictionary containing both the generated image and the final denoised latents.
Under the hood, ModularPipeline's __call__ method is a wrapper around the pipeline blocks' __call__ method: it creates a PipelineState object and populates it with user inputs, then returns the output to the user based on the output argument. It also ensures that all pipeline-level config and components are exposed to all pipeline blocks by preparing and passing a components input.
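The output-selection rules listed above fit in a few lines. A toy sketch, assuming the intermediate state is a plain dict; `select_output` is a hypothetical helper, not a Diffusers function, and placeholder strings stand in for images and tensors.

```python
from typing import Any, Dict, List, Optional, Union


def select_output(state: Dict[str, Any],
                  output: Optional[Union[str, List[str]]] = None) -> Any:
    if output is None:
        return state                               # no output: the whole state
    if isinstance(output, str):
        return state[output]                       # one name: that value
    return {name: state[name] for name in output}  # list: dict of values


state = {"images": ["<PIL image>"], "latents": "<tensor>", "prompt_embeds": "<tensor>"}
print(select_output(state, "images")[0])
print(list(select_output(state, ["images", "latents"])))
print(select_output(state) is state)  # True
```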
You can inspect the docstring of a ModularPipeline to check what arguments the pipeline accepts and how to specify the output you want. It will list all available outputs (basically everything in the intermediate pipeline state) so you can choose from the list.
Important: It is important to always check the docstring because arguments can be different from standard pipelines that you're familiar with. For example, in Modular Diffusers we standardized controlnet image input as control_image, but regular pipelines have inconsistencies over the names, e.g. controlnet text-to-image uses image while SDXL controlnet img2img uses control_image.
Note: The output list might be longer than you expected - it includes everything in the intermediate state that you can choose to return. Most of the time, you'll just want output="images" or output="latents".
t2i_pipeline.doc
Text-to-Image, Image-to-Image, and Inpainting
These are minimal inference examples for our basic tasks: text-to-image, image-to-image, and inpainting. The process to create the different pipelines is the same; the only difference is the block classes preset. Inference is also largely the same as with standard pipelines, but please always check .doc for the correct input names and remember to pass output="images".
import torch
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS
# create pipeline from official blocks preset
blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
modular_repo_id = "YiYiXu/modular-loader-t2i"
pipeline = blocks.init_pipeline(modular_repo_id)
pipeline.load_default_components(torch_dtype=torch.float16)
pipeline.to("cuda")
# run pipeline, need to pass a "output=images" argument
image = pipeline(prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", output="images")[0]
image.save("modular_t2i_out.png")
import torch
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.stable_diffusion_xl import IMAGE2IMAGE_BLOCKS
from diffusers.utils import load_image
# create pipeline from blocks preset
blocks = SequentialPipelineBlocks.from_blocks_dict(IMAGE2IMAGE_BLOCKS)
modular_repo_id = "YiYiXu/modular-loader-t2i"
pipeline = blocks.init_pipeline(modular_repo_id)
pipeline.load_default_components(torch_dtype=torch.float16)
pipeline.to("cuda")
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png"
init_image = load_image(url)
prompt = "a dog catching a frisbee in the jungle"
image = pipeline(prompt=prompt, image=init_image, strength=0.8, output="images")[0]
image.save("modular_i2i_out.png")
import torch
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS
from diffusers.utils import load_image
# create pipeline from blocks preset
blocks = SequentialPipelineBlocks.from_blocks_dict(INPAINT_BLOCKS)
modular_repo_id = "YiYiXu/modular-loader-t2i"
pipeline = blocks.init_pipeline(modular_repo_id)
pipeline.load_default_components(torch_dtype=torch.float16)
pipeline.to("cuda")
img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png"
mask_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-inpaint-mask.png"
init_image = load_image(img_url)
mask_image = load_image(mask_url)
prompt = "A deep sea diver floating"
image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, strength=0.85, output="images")[0]
image.save("modular_inpaint_out.png")
ControlNet
For ControlNet, we provide one auto block you can place at the denoise step. Let's create it and inspect it to see what it tells us.
💡 How to explore new tasks: When you want to figure out how to do a specific task in Modular Diffusers, it is a good idea to start by checking what block classes presets we offer in ALL_BLOCKS. Then create the block instance and inspect it - it will show you the required components, description, and sub-blocks. This is crucial for understanding what each block does and what it needs.
>>> from diffusers.modular_pipelines.stable_diffusion_xl import ALL_BLOCKS
>>> ALL_BLOCKS["controlnet"]
InsertableDict([
0: ('denoise', <class 'diffusers.modular_pipelines.stable_diffusion_xl.modular_blocks_presets.StableDiffusionXLAutoControlnetStep'>)
])
>>> controlnet_blocks = ALL_BLOCKS["controlnet"]["denoise"]()
>>> controlnet_blocks
StableDiffusionXLAutoControlnetStep(
Class: SequentialPipelineBlocks
====================================================================================================
This pipeline contains blocks that are selected at runtime based on inputs.
Trigger Inputs: {'mask', 'control_mode', 'control_image', 'controlnet_cond'}
Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('mask')`).
====================================================================================================
Description: Controlnet auto step that prepare the controlnet input and denoise the latents. It works for both controlnet and controlnet_union and supports text2img, img2img and inpainting tasks. (it should be replace at 'denoise' step)
Components:
controlnet (`ControlNetUnionModel`)
control_image_processor (`VaeImageProcessor`)
scheduler (`EulerDiscreteScheduler`)
unet (`UNet2DConditionModel`)
guider (`ClassifierFreeGuidance`)
Sub-Blocks:
[0] controlnet_input (StableDiffusionXLAutoControlNetInputStep)
Description: Controlnet Input step that prepare the controlnet input.
This is an auto pipeline block that works for both controlnet and controlnet_union.
(it should be called right before the denoise step) - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided.
- `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided. - if neither `control_mode` nor `control_image` is provided, step will be skipped.
[1] controlnet_denoise (StableDiffusionXLAutoControlNetDenoiseStep)
Description: Denoise step that iteratively denoise the latents with controlnet. This is a auto pipeline block that using controlnet for text2img, img2img and inpainting tasks.This block should not be used without a controlnet_cond input - `StableDiffusionXLInpaintControlNetDenoiseStep` (inpaint_controlnet_denoise) is used when mask is provided. - `StableDiffusionXLControlNetDenoiseStep` (controlnet_denoise) is used when mask is not provided but controlnet_cond is provided. - If neither mask nor controlnet_cond are provided, step will be skipped.
)
💡 Auto Blocks: This is the first time we've met Auto Blocks! AutoPipelineBlocks automatically adapt to your inputs by combining multiple workflows with conditional logic. This is why a single block can handle all tasks and both ControlNet types. See the Auto Blocks Guide for more details.
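To make the runtime selection concrete, here is a toy model in plain Python (not the diffusers implementation) of how the controlnet_input auto step described above picks a sub-block. It checks the trigger inputs from the most specific rule to the least specific, and skips the step when nothing matches. The function name is hypothetical. Only the rule itself comes from the block's printed description.

```python
from typing import Optional


def select_controlnet_input_step(inputs: dict) -> Optional[str]:
    """Toy version of the rule printed in the controlnet_input description.

    Returns the name of the sub-block that would run, or None when the step
    is skipped. Inputs set to None count as "not provided".
    """
    provided = {name for name, value in inputs.items() if value is not None}
    # ControlNet Union needs both a control mode and a control image
    if {"control_mode", "control_image"} <= provided:
        return "StableDiffusionXLControlNetUnionInputStep"
    # a plain ControlNet only needs the control image
    if "control_image" in provided:
        return "StableDiffusionXLControlNetInputStep"
    # neither trigger input was provided, so the step is skipped
    return None


print(select_controlnet_input_step({"control_image": "canny.png"}))
print(select_controlnet_input_step({"control_image": "pose.png", "control_mode": 3}))
print(select_controlnet_input_step({"prompt": "a bird"}))
```

The order of the checks matters. If the plain control_image rule came first, a ControlNet Union request would never reach the union step.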
The StableDiffusionXLAutoControlnetStep printout shows that the block has two steps (prepare inputs, then denoise) and supports text2img, img2img, and inpainting with both ControlNet and ControlNet Union. Most importantly, it tells us to place it at the 'denoise' step. Let's do exactly that:
import torch
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS, StableDiffusionXLAutoControlnetStep
from diffusers.utils import load_image
# create pipeline from blocks preset
blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
# these two lines apply controlnet
controlnet_blocks = StableDiffusionXLAutoControlnetStep()
blocks.sub_blocks["denoise"] = controlnet_blocks
Before we convert the blocks into a pipeline and load its components, let's inspect the blocks and their docs again to make sure they were assembled correctly. You should see that controlnet and control_image_processor are now listed under Components, so we should initialize the pipeline with a repo that contains loading specs for these two components.
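One way to think about this check: the Components list of a SequentialPipelineBlocks is the combined set of components its sub-blocks declare. Swapping in the ControlNet step therefore adds whatever the new step needs. Below is a minimal sketch using plain sets. The ControlNet step's list is copied from the printout above. The lists for the stock text-to-image steps are assumptions for illustration only.

```python
# Components each step declares. The controlnet entry is copied from the
# StableDiffusionXLAutoControlnetStep printout; the others are assumed for illustration.
base_steps = {
    "text_encoder": {"text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2", "guider"},
    "denoise": {"unet", "scheduler", "guider"},
    "decode": {"vae"},
}
controlnet_step = {"controlnet", "control_image_processor", "scheduler", "unet", "guider"}


def pipeline_components(steps: dict) -> set:
    """Union of the components declared by every step."""
    result = set()
    for declared in steps.values():
        result |= declared
    return result


before = pipeline_components(base_steps)

# assigning to an existing key swaps the step but keeps it in the same slot
after_steps = dict(base_steps)
after_steps["denoise"] = controlnet_step
after = pipeline_components(after_steps)

# components the modular repo must now provide loading specs for
print(sorted(after - before))
```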
# make sure to use a modular repo that includes controlnet
modular_repo_id = "YiYiXu/modular-demo-auto"
pipeline = blocks.init_pipeline(modular_repo_id)
pipeline.load_default_components(torch_dtype=torch.float16)
pipeline.to("cuda")
# generate
canny_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
)
image = pipeline(
prompt="a bird", controlnet_conditioning_scale=0.5, control_image=canny_image, output="images"
)[0]
image.save("modular_control_out.png")
IP-Adapter
Challenge time! Before we show you how to apply IP-Adapter, try doing it yourself! Use the same process we just walked you through with ControlNet: check the official blocks preset, inspect the block instance and its docstring (.doc), and adapt a regular IP-Adapter example to modular.
Let's walk through the steps:
- Check blocks preset
>>> from diffusers.modular_pipelines.stable_diffusion_xl import ALL_BLOCKS
>>> ALL_BLOCKS["ip_adapter"]
InsertableDict([
0: ('ip_adapter', <class 'diffusers.modular_pipelines.stable_diffusion_xl.modular_blocks_presets.StableDiffusionXLAutoIPAdapterStep'>)
])
- inspect the block & doc
>>> from diffusers.modular_pipelines.stable_diffusion_xl import StableDiffusionXLAutoIPAdapterStep
>>> ip_adapter_blocks = StableDiffusionXLAutoIPAdapterStep()
>>> ip_adapter_blocks
StableDiffusionXLAutoIPAdapterStep(
Class: AutoPipelineBlocks
====================================================================================================
This pipeline contains blocks that are selected at runtime based on inputs.
Trigger Inputs: {'ip_adapter_image'}
Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('ip_adapter_image')`).
====================================================================================================
Description: Run IP Adapter step if `ip_adapter_image` is provided. This step should be placed before the 'input' step.
Components:
image_encoder (`CLIPVisionModelWithProjection`)
feature_extractor (`CLIPImageProcessor`)
unet (`UNet2DConditionModel`)
guider (`ClassifierFreeGuidance`)
Sub-Blocks:
• ip_adapter [trigger: ip_adapter_image] (StableDiffusionXLIPAdapterStep)
Description: IP Adapter step that prepares ip adapter image embeddings.
Note that this step only prepares the embeddings - in order for it to work correctly, you need to load ip adapter weights into unet via ModularPipeline.loader.
e.g. pipeline.loader.load_ip_adapter() and pipeline.loader.set_ip_adapter_scale().
See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin) for more details
)
- Follow the instructions to build
import torch
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS, StableDiffusionXLAutoIPAdapterStep
# create pipeline from official blocks preset
blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
# insert ip_adapter_blocks before the input step (index 1) as instructed
ip_adapter_blocks = StableDiffusionXLAutoIPAdapterStep()
blocks.sub_blocks.insert("ip_adapter", ip_adapter_blocks, 1)
# inspect the blocks before you convert them into a pipeline,
# and make sure to use a repo that contains the loading specs for all components
# for ip-adapter, you need image_encoder & feature_extractor
modular_repo_id = "YiYiXu/modular-demo-auto"
pipeline = blocks.init_pipeline(modular_repo_id)
pipeline.load_default_components(torch_dtype=torch.float16)
pipeline.loader.load_ip_adapter(
"h94/IP-Adapter",
subfolder="sdxl_models",
weight_name="ip-adapter_sdxl.bin"
)
pipeline.loader.set_ip_adapter_scale(0.8)
pipeline.to("cuda")
- adapt an example to modular
We are using this one from our IP-Adapter doc!
from diffusers.utils import load_image
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_diner.png")
image = pipeline(
prompt="a polar bear sitting in a chair drinking a milkshake",
ip_adapter_image=image,
negative_prompt="deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
output="images"
)[0]
image.save("modular_ipa_out.png")
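The build step above calls sub_blocks.insert("ip_adapter", ip_adapter_blocks, 1) to place the new step at position 1. The sketch below mimics that positional insert on a plain list of step names so you can see where the IP-Adapter step lands. The helper insert_step is hypothetical. The step names are assumed to match the TEXT2IMAGE_BLOCKS preset, so print blocks to confirm the real order.

```python
def insert_step(order: list, name: str, index: int) -> list:
    """Return a new step order with `name` inserted at `index` (toy stand-in for sub_blocks.insert)."""
    if name in order:
        raise ValueError(f"step {name!r} already exists")
    return order[:index] + [name] + order[index:]


# assumed step order of the text-to-image preset
t2i_order = ["text_encoder", "input", "set_timesteps", "prepare_latents",
             "prepare_add_cond", "denoise", "decode"]

new_order = insert_step(t2i_order, "ip_adapter", 1)
print(new_order)
# the IP-Adapter step now runs right before the 'input' step, as its description asks
assert new_order.index("ip_adapter") + 1 == new_order.index("input")
```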
Building Advanced Workflows: The Modular Way
We've learned the basic components of the Modular Diffusers system. Now let's tie everything together with a more practical example that demonstrates the true power of Modular Diffusers: working with multiple pipelines that share components.
In this example, we'll generate latents from a text-to-image pipeline, then refine them with an image-to-image pipeline. We will use IP-adapter, LoRA, and ControlNet.
Base Text-to-Image
Let's set up the text-to-image workflow. Instead of putting all blocks into one complete pipeline, we'll create separate text_blocks for encoding prompts, t2i_blocks for generating latents, and decoder_blocks for decoding latents into final images.
import torch
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.stable_diffusion_xl import ALL_BLOCKS
# create t2i blocks, then pop out the text_encoder and decode steps so that we can use them as standalone pipelines
t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(ALL_BLOCKS["text2img"])
text_blocks = t2i_blocks.sub_blocks.pop("text_encoder")
decoder_blocks = t2i_blocks.sub_blocks.pop("decode")
Next, convert them into runnable pipelines. We'll use a Components Manager with an auto offloading strategy.
Components Manager: Create one manager and pass it to init_pipeline along with a collection name. All models loaded by that pipeline will be added to the manager under that collection.
Auto Offloading: All components are placed on CPU and only moved to the device right before their forward pass. The manager monitors device memory and may move components off the device to make space for new ones. Unlike DiffusionPipeline.enable_model_cpu_offload(), this works across all components in the manager and across all your workflows.
from diffusers import ComponentsManager
# Set up the components manager and turn on auto offloading
components = ComponentsManager()
components.enable_auto_cpu_offload(device="cuda")
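To build some intuition for auto offloading, here is a toy simulation. It uses a fixed device budget, requires each component to be on the device for its forward pass, and evicts the least recently used components when space runs out. The LRU policy, the 8 GB budget, and the ToyOffloader class are assumptions made for illustration, and the real ComponentsManager strategy may differ. The sizes match the ones the components table reports later in this section.

```python
from collections import OrderedDict


class ToyOffloader:
    """Keeps at most `budget_gb` of components on the device, evicting the least recently used."""

    def __init__(self, budget_gb: float):
        self.budget_gb = budget_gb
        self.on_device = OrderedDict()  # name -> size in GB, least recently used first

    def used(self) -> float:
        return sum(self.on_device.values())

    def before_forward(self, name: str, size_gb: float) -> list:
        """Move `name` to the device; return the names moved back to CPU to make room."""
        if name in self.on_device:
            self.on_device.move_to_end(name)  # already resident: just mark as recently used
            return []
        evicted = []
        while self.on_device and self.used() + size_gb > self.budget_gb:
            old_name, _ = self.on_device.popitem(last=False)  # offload the oldest entry
            evicted.append(old_name)
        self.on_device[name] = size_gb
        return evicted


sizes = {"text_encoder": 0.23, "text_encoder_2": 1.29, "unet": 4.78, "vae": 0.16, "controlnet": 2.33}
offloader = ToyOffloader(budget_gb=8.0)
for step in ["text_encoder", "text_encoder_2", "unet", "controlnet", "vae"]:
    print(step, "evicts", offloader.before_forward(step, sizes[step]))
```

With these numbers, loading controlnet pushes both text encoders back to CPU, while unet, controlnet, and vae fit together on the device.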
Since we have a modular setup where different pipelines may share components, we recommend using a standalone loader to load components all at once and add them to each pipeline with update_components().
💡 Load components without pipeline blocks:
- blocks.init_pipeline(repo) creates a pipeline with a built-in loader that only includes the components its blocks need
- StableDiffusionXLModularLoader.from_pretrained(repo) sets up a standalone loader that includes everything in the repo's modular_model_index.json
from diffusers import StableDiffusionXLModularLoader
t2i_repo = "YiYiXu/modular-demo-auto"
t2i_loader = StableDiffusionXLModularLoader.from_pretrained(t2i_repo, components_manager=components, collection="t2i")
text_node = text_blocks.init_pipeline(t2i_repo, components_manager=components)
decoder_node = decoder_blocks.init_pipeline(t2i_repo, components_manager=components)
t2i_pipe = t2i_blocks.init_pipeline(t2i_repo, components_manager=components)
We'll load the components with t2i_loader. You can get the list of all loadable components from the loader's pretrained_component_names property.
>>> t2i_loader.pretrained_component_names
['controlnet', 'image_encoder', 'scheduler', 'text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'unet', 'vae']
It includes controlnet, and image_encoder for IP-Adapter, which we don't need yet. We'll load them anyway, since they stay on CPU and we'll use them later in this guide. You can also choose what to load with the names argument.
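A single standalone loader is enough because every node's requirements are a subset of what the repo's modular_model_index.json lists. Here is a quick check with plain sets. It uses the pretrained_component_names printed above and the null_component_names that the decoder and text nodes report later in this guide. The t2i_pipe entry is an assumption based on the "unet & scheduler" comment in its setup step.

```python
# everything the standalone loader can load (t2i_loader.pretrained_component_names)
loadable = {"controlnet", "image_encoder", "scheduler", "text_encoder", "text_encoder_2",
            "tokenizer", "tokenizer_2", "unet", "vae"}

# what each node still needs (its loader.null_component_names)
needs = {
    "decoder_node": {"vae"},
    "text_node": {"text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"},
    "t2i_pipe": {"unet", "scheduler"},  # assumed from the setup comment for t2i_pipe
}

for node, required in needs.items():
    missing = required - loadable
    status = "ok" if not missing else f"missing {sorted(missing)}"
    print(f"{node}: {status}")

# loaded now but not used by any node yet (kept on CPU for later sections)
unused = loadable - set().union(*needs.values())
print(sorted(unused))
```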
import torch
# inspect before you load
# t2i_loader
t2i_loader.load(t2i_loader.pretrained_component_names, torch_dtype=torch.float16)
All the models are registered to components manager under the collection "t2i".
>>> components
Components:
============================================================================================================================================================
Models:
------------------------------------------------------------------------------------------------------------------------------------------------------------
Name | Class | Device: act(exec)| Dtype | Size (GB)| Load ID | Collection
------------------------------------------------------------------------------------------------------------------------------------------------------------
vae | AutoencoderKL | cpu(cuda:0) | torch.float16| 0.16 | SG161222/RealVisXL_V4.0|vae|null|null | t2i
image_encoder | CLIPVisionModelWithProjection| cpu(cuda:0) | torch.float16| 3.44 | h94/IP-Adapter|sdxl_models/image_encoder|null|null | t2i
text_encoder | CLIPTextModel | cpu(cuda:0) | torch.float16| 0.23 | SG161222/RealVisXL_V4.0|text_encoder|null|null | t2i
unet | UNet2DConditionModel | cpu(cuda:0) | torch.float16| 4.78 | SG161222/RealVisXL_V4.0|unet|null|null | t2i
text_encoder_2 | CLIPTextModelWithProjection | cpu(cuda:0) | torch.float16| 1.29 | SG161222/RealVisXL_V4.0|text_encoder_2|null|null | t2i
controlnet | ControlNetModel | cpu(cuda:0) | torch.float16| 2.33 | diffusers/controlnet-canny-sdxl-1.0|null|null|null | t2i
------------------------------------------------------------------------------------------------------------------------------------------------------------
Other Components:
------------------------------------------------------------------------------------------------------------------------------------------------------------
Name | Class | Collection
------------------------------------------------------------------------------------------------------------------------------------------------------------
tokenizer_2 | CLIPTokenizer | t2i
tokenizer | CLIPTokenizer | t2i
scheduler | EulerDiscreteScheduler | t2i
------------------------------------------------------------------------------------------------------------------------------------------------------------
Additional Component Info:
==================================================
Let's add the loaded components to each pipeline. We'll follow this pattern for each pipeline:
- Check what components the pipeline needs: inspect pipeline.loader or use loader.null_component_names
- Get them from the components manager: use its search_components()/get_one()/get_components_by_names() methods
- Update the pipeline: pipeline.update_components()
- Verify the components are loaded correctly: inspect pipeline.loader as well as the components manager
We will start with decoder_node. First, check what components it needs:
>>> decoder_node.loader.null_component_names
['vae']
The pipeline only needs a vae. Looking at the components manager table, there's only one VAE available:
Name | Class | Device: act(exec)| Dtype | Size (GB)| Load ID | Collection
----------------------------------------------------------------------------------------------------------------------
vae | AutoencoderKL| cpu(cuda:0) | torch.float16| 0.16 | SG161222/RealVisXL_V4.0|vae|null|null | t2i
Since there's only one VAE, we can get it using its unique Load ID:
vae = components.get_one(load_id="SG161222/RealVisXL_V4.0|vae|null|null")
decoder_node.update_components(vae=vae)
Verify it's correctly loaded:
decoder_node.loader
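The Load ID column packs several fields into one string separated by |, with null standing in for an unset field. The parser below is hypothetical and not part of diffusers. It assumes the four fields are repo, subfolder, variant, and revision, in that order. The two sample IDs are copied from the components table.

```python
from typing import NamedTuple, Optional


class LoadId(NamedTuple):
    repo: str
    subfolder: Optional[str]
    variant: Optional[str]
    revision: Optional[str]


def parse_load_id(load_id: str) -> LoadId:
    """Split an assumed 'repo|subfolder|variant|revision' string, mapping 'null' to None."""
    parts = load_id.split("|")
    if len(parts) != 4:
        raise ValueError(f"expected 4 '|'-separated fields, got {len(parts)}: {load_id!r}")
    return LoadId(*[None if part == "null" else part for part in parts])


print(parse_load_id("SG161222/RealVisXL_V4.0|vae|null|null"))
print(parse_load_id("diffusers/controlnet-canny-sdxl-1.0|null|null|null"))
```

Under this reading, two components share a Load ID only if they were loaded from the same repo, subfolder, variant, and revision. That is why a Load ID works as a unique lookup key for get_one.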
Now let's do the same for text_node. Get the list of components the pipeline needs to load:
>>> text_node.loader.null_component_names
['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2']
Pass the list directly to the components manager to get the components, then add them to the pipeline:
text_components = components.get_components_by_names(text_node.loader.null_component_names)
# Add components to pipeline
text_node.update_components(**text_components)
# Verify components are loaded
assert not text_node.loader.null_component_names
text_node.loader
Finally, let's set up t2i_pipe:
# Get unet & scheduler from components manager and add to pipeline
comps = components.get_components_by_names(t2i_pipe.loader.null_component_names)
t2i_pipe.update_components(**comps)
# Verify everything is loaded
assert not t2i_pipe.loader.null_component_names
t2i_pipe.loader
# Verify components manager hasn't changed (we only reused existing components)
components
Now we can generate an image with the t2i pipeline.
First, run the prompt through text_node to get the prompt embeddings.
💡 Don't forget to check text_node.doc to find out which outputs are available, and set the output argument accordingly.
prompt = "an astronaut"
text_embeddings = text_node(prompt=prompt, output=["prompt_embeds","negative_prompt_embeds", "pooled_prompt_embeds", "negative_pooled_prompt_embeds"])
Now generate latents with the t2i pipeline, then decode them with decoder_node.
generator = torch.Generator(device="cuda").manual_seed(0)
latents_t2i = t2i_pipe(**text_embeddings, num_inference_steps=25, generator=generator, output="latents")
image = decoder_node(latents=latents_t2i, output="images")[0]
image.save("modular_part2_t2i.png")
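The output argument behaves differently depending on what you pass. The calls above use a single name (output="latents", output="images") to get one value back, and a list of names to get a dict that can be unpacked with ** into the next call. Here is a minimal plain-Python sketch of that selection logic. A dict stands in for the pipeline state, and select_outputs and fake_t2i are hypothetical helpers, not diffusers code.

```python
from typing import Any, Union


def select_outputs(state: dict, output: Union[str, list]) -> Any:
    """Toy stand-in for the `output` argument of a modular pipeline call."""
    if isinstance(output, str):
        return state[output]  # one name -> just that value
    return {name: state[name] for name in output}  # list of names -> dict


state = {"prompt_embeds": "E", "negative_prompt_embeds": "N", "latents": "L"}
embeds = select_outputs(state, ["prompt_embeds", "negative_prompt_embeds"])


def fake_t2i(prompt_embeds, negative_prompt_embeds, **kwargs):
    # a dict of outputs can be splatted straight into the next call, like **text_embeddings above
    return f"denoised({prompt_embeds}, {negative_prompt_embeds})"


print(fake_t2i(**embeds))
print(select_outputs(state, "latents"))
```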
LoRA
Now let's add a LoRA to our pipeline. With the modular approach, we can reuse intermediate outputs from blocks that would otherwise need to be re-run. Let's load the LoRA weights and see what happens:
t2i_loader.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy_face")
components
Notice that the "Additional Component Info" section shows that only the unet component has the LoRA adapter loaded. This means we can skip the text encoding step and reuse the existing embeddings; only the denoising and decoding steps need to run again.
Components:
============================================================================================================================================================
...
Additional Component Info:
==================================================
unet:
Adapters: ['toy_face']
🔍 Alternatively, you can find a component's ID and then use get_model_info to get detailed metadata about that component:
id = components.get_ids("unet")[0]
components.get_model_info(id)
# {'model_id': 'unet_6c2b839d-ec39-4ce9-8741-333ba6d25932', 'added_time': 1751101289.203884, 'collection': 't2i', 'class_name': 'UNet2DConditionModel', 'size_gb': 4.940812595188618, 'adapters': ['toy_face'], 'has_hook': True, 'execution_device': device(type='cuda', index=0)}
generator = torch.Generator(device="cuda").manual_seed(0)
latents_lora = t2i_pipe(**text_embeddings, num_inference_steps=25, generator=generator, output="latents")
image = decoder_node(latents=latents_lora, output="images")[0]
image.save("modular_part2_lora.png")
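The reasoning behind skipping text encoding can be written down as a small dependency check. A stage must be re-run if one of its components changed, or if it consumes the output of a stage that is re-run. Here is a toy sketch with the three nodes from this section. The stage-to-component mapping mirrors the null_component_names checks earlier, and stages_to_rerun is a hypothetical helper, not a diffusers API.

```python
# stages in execution order: (name, components it uses, stages whose outputs it consumes)
stages = [
    ("text_node", {"text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"}, []),
    ("t2i_pipe", {"unet", "scheduler"}, ["text_node"]),
    ("decoder_node", {"vae"}, ["t2i_pipe"]),
]


def stages_to_rerun(changed: set) -> list:
    """Return the stages whose cached outputs are stale after `changed` components were modified."""
    stale = []
    for name, components, inputs_from in stages:
        if components & changed or any(source in stale for source in inputs_from):
            stale.append(name)
    return stale


# loading the LoRA only touched the unet, so the prompt embeddings can be reused
print(stages_to_rerun({"unet"}))
# a change to a text encoder would invalidate everything downstream
print(stages_to_rerun({"text_encoder"}))
```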
IP-adapter
IP-adapter can also be used as a standalone pipeline. We can generate the embeddings once and reuse them for different workflows.
from diffusers.utils import load_image
ipa_blocks = ALL_BLOCKS["ip_adapter"]["ip_adapter"]()
ipa_node = ipa_blocks.init_pipeline(t2i_repo, components_manager=components)
comps = components.get_components_by_names(ipa_node.loader.null_component_names)
ipa_node.update_components(**comps)
t2i_loader.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
t2i_loader.set_ip_adapter_scale(0.6)
# check it's correctly loaded
assert not ipa_node.loader.null_component_names
ipa_node.loader
# find out inputs/outputs
print(ipa_node.doc)
ip_adapter_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img5.png")
ipa_embeddings = ipa_node(ip_adapter_image=ip_adapter_image, output=["ip_adapter_embeds","negative_ip_adapter_embeds"])
generator = torch.Generator(device="cuda").manual_seed(0)
latents_ipa = t2i_pipe(**text_embeddings, **ipa_embeddings, num_inference_steps=25, generator=generator, output="latents")
image = decoder_node(latents=latents_ipa, output="images")[0]
image.save("modular_part2_lora_ipa.png")
ControlNet
We can create a new ControlNet workflow by modifying the pipeline blocks and reusing components as much as possible, then see how it affects the generation.
We want to use a different ControlNet from the one that's already loaded.
from diffusers import ComponentSpec, ControlNetModel
control_blocks = ALL_BLOCKS["controlnet"]["denoise"]()
# update the t2i_blocks and create pipeline
t2i_blocks.sub_blocks["denoise"] = control_blocks
t2i_control_pipe = t2i_blocks.init_pipeline(t2i_repo, components_manager=components)
# fetch controlnet_pose separately, since we need to change its name when adding it to the pipeline
controlnet_spec = ComponentSpec(name="controlnet_pose", type_hint=ControlNetModel, repo="thibaud/controlnet-openpose-sdxl-1.0")
controlnet = controlnet_spec.load(torch_dtype=torch.float16)
t2i_control_pipe.update_components(controlnet=controlnet)
# fetch the rest of the components from the components manager
comps = components.get_components_by_names(t2i_control_pipe.loader.null_component_names)
t2i_control_pipe.update_components(**comps)
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/controlnet/person_pose.png")
generator = torch.Generator(device="cuda").manual_seed(0)
latents_control = t2i_control_pipe(**text_embeddings, **ipa_embeddings, control_image=control_image, num_inference_steps=25, generator=generator, output="latents")
image = decoder_node(latents=latents_control, output="images")[0]
image.save("modular_part2_lora_ipa_control.png")
Refiner
Now set up the refiner workflow. For the refiner blocks, we remove image_encoder, since the refiner works directly with the latents from the text-to-image workflow, and decode, since we already have a dedicated decoder. We keep text_encoder because the SDXL refiner encodes text prompts differently from the text-to-image pipeline, so we cannot share it.
# Create refiner blocks
# - remove image_encoder since we'll use latents from t2i
# - remove decode since we already created a separate decoder_node
refiner_blocks = SequentialPipelineBlocks.from_blocks_dict(ALL_BLOCKS["img2img"])
refiner_blocks.sub_blocks.pop("image_encoder")
refiner_blocks.sub_blocks.pop("decode")
Create the refiner pipeline. The refiner has a different unet and uses only one text encoder, so its loading specs are hosted in a different repo. We pass the same components manager to the refiner pipeline, along with a unique "refiner" collection.
refiner_repo = "YiYiXu/modular_refiner"
refiner_pipe = refiner_blocks.init_pipeline(refiner_repo, components_manager=components, collection="refiner")
We want to reuse components from the t2i pipeline in the refiner as much as possible. First, let's check the loading status of the refiner pipeline to understand what components are needed:
>>> refiner_pipe.loader
Looking at the loader output, you can see that text_encoder and tokenizer have empty loading spec maps (their repo fields are null). This is because the refiner pipeline does not use these two components, so they are not listed in the modular_model_index.json in refiner_repo. The unet is different from the one we loaded for text-to-image, so we load it fresh. The remaining components (vae, text_encoder_2, tokenizer_2, and scheduler) are already available in the t2i collection, so we can reuse them instead of loading duplicates.
refiner_pipe.load_components(names="unet", torch_dtype=torch.float16)
# verify loaded correctly
refiner_pipe.loader
# verify it's registered to the components manager under the refiner collection
components
Now let's reuse the components from the t2i pipeline in the refiner. We use | to select multiple components from the components manager at once:
# Reuse components from t2i pipeline (select everything at once)
reuse_components = components.search_components("text_encoder_2|scheduler|vae|tokenizer_2")
refiner_pipe.update_components(**reuse_components)
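As a rough model of what that search selects, the sketch below makes two assumptions. First, | means "match any of these exact names". Second, each registered ID is the component name plus an _&lt;uuid&gt; suffix, as in the warnings the call prints. This is a simplified, hypothetical re-implementation, and the real search_components may support more pattern syntax. The real method returns the component objects keyed by name, which is why its result can be unpacked into update_components. This toy returns IDs instead to stay self-contained. The first four IDs are copied from the warnings, and the last two are made-up placeholders.

```python
registered_ids = [
    "text_encoder_2_238ae9a7-c864-4837-a8a2-f58ed753b2d0",
    "tokenizer_2_b795af3d-f048-4b07-a770-9e8237a2be2d",
    "scheduler_e3435f63-266a-4427-9383-eb812e830fe8",
    "vae_357eee6a-4a06-46f1-be83-494f7d60ca69",
    "text_encoder_00000000-0000-0000-0000-000000000000",  # hypothetical placeholder ID
    "tokenizer_00000000-0000-0000-0000-000000000001",     # hypothetical placeholder ID
]


def base_name(component_id: str) -> str:
    """Strip the trailing '_<uuid>'; uuids use hyphens, so the last underscore is the separator."""
    return component_id.rsplit("_", 1)[0]


def search_exact_any(pattern: str, ids: list) -> dict:
    """Toy search: 'a|b|c' selects components whose base name is exactly a, b, or c."""
    wanted = set(pattern.split("|"))
    return {base_name(cid): cid for cid in ids if base_name(cid) in wanted}


selected = search_exact_any("text_encoder_2|scheduler|vae|tokenizer_2", registered_ids)
# exact matching means text_encoder_2 does not drag in text_encoder
print(sorted(selected))
```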
You'll see warnings indicating that these components already exist in the components manager:
component 'text_encoder_2' already exists as 'text_encoder_2_238ae9a7-c864-4837-a8a2-f58ed753b2d0'
component 'tokenizer_2' already exists as 'tokenizer_2_b795af3d-f048-4b07-a770-9e8237a2be2d'
component 'scheduler' already exists as 'scheduler_e3435f63-266a-4427-9383-eb812e830fe8'
component 'vae' already exists as 'vae_357eee6a-4a06-46f1-be83-494f7d60ca69'
These warnings are expected and indicate that the components manager is correctly identifying that these components are already loaded. The system will reuse the existing components rather than creating duplicates.
Let's check the components manager again to see the updated state. You should see text_encoder_2, vae, tokenizer_2, and scheduler now appear under both "t2i" and "refiner" collections.
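The same component can appear under two collections without being loaded twice. One way this can work is that the manager recognizes an object it already holds and simply records the extra collection. The toy registry below illustrates that idea by comparing object identity. The ToyRegistry class, its ID scheme, and its warning text are simplified stand-ins for illustration, not the ComponentsManager implementation.

```python
import uuid


class ToyRegistry:
    def __init__(self):
        self.objects = {}      # component id -> object (holding a reference keeps identity stable)
        self.collections = {}  # component id -> set of collection names
        self.warnings = []

    def find(self, obj: object):
        """Return the id of an already registered object, or None."""
        for component_id, existing in self.objects.items():
            if existing is obj:
                return component_id
        return None

    def add(self, name: str, obj: object, collection: str) -> str:
        component_id = self.find(obj)
        if component_id is not None:
            self.warnings.append(f"component '{name}' already exists as '{component_id}'")
        else:
            component_id = f"{name}_{uuid.uuid4()}"
            self.objects[component_id] = obj
            self.collections[component_id] = set()
        self.collections[component_id].add(collection)
        return component_id


registry = ToyRegistry()
vae = object()  # stands in for the loaded AutoencoderKL
first = registry.add("vae", vae, "t2i")
second = registry.add("vae", vae, "refiner")  # same object, second collection
print(first == second, sorted(registry.collections[first]))
print(registry.warnings[0])
```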
Now let's refine!
# refine the latents from base text-to-image workflow
refined_latents = refiner_pipe(image_latents=latents_t2i, prompt=prompt, num_inference_steps=10, output="latents")
refined_image = decoder_node(latents=refined_latents, output="images")[0]
refined_image.save("modular_part2_t2i_refine_out.png")
# refine the latents from the text-to-image lora workflow
refined_latents = refiner_pipe(image_latents=latents_lora, prompt=prompt, num_inference_steps=10, output="latents")
refined_image = decoder_node(latents=refined_latents, output="images")[0]
refined_image.save("modular_part2_lora_refine_out.png")
# refine the latents from the text-to-image + lora + ip-adapter workflow
refined_latents = refiner_pipe(image_latents=latents_ipa, prompt=prompt, num_inference_steps=10, output="latents")
refined_image = decoder_node(latents=refined_latents, output="images")[0]
refined_image.save("modular_part2_ipa_refine_out.png")
# refine the latents from the text-to-image + lora + ip-adapter + controlnet workflow
refined_latents = refiner_pipe(image_latents=latents_control, prompt=prompt, num_inference_steps=10, output="latents")
refined_image = decoder_node(latents=refined_latents, output="images")[0]
refined_image.save("modular_part2_control_refine_out.png")
Results
Here are the results from our modular pipeline examples.
Base Text-to-Image Generation
| Base Text-to-Image | Base Text-to-Image (Refined) |
|---|---|
| ![]() | ![]() |
LoRA
| LoRA | LoRA (Refined) |
|---|---|
| ![]() | ![]() |
LoRA + IP-Adapter
| LoRA + IP-Adapter | LoRA + IP-Adapter (Refined) |
|---|---|
| ![]() | ![]() |
ControlNet + LoRA + IP-Adapter
| ControlNet + LoRA + IP-Adapter | ControlNet + LoRA + IP-Adapter (Refined) |
|---|---|
| ![]() | ![]() |







