mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
diffusers/docs/source/en/modular_diffusers/quickstart.md
2026-01-26 21:36:24 +01:00


Quickstart

Modular Diffusers is a framework for quickly building flexible and customizable pipelines. At the core of Modular Diffusers are [ModularPipelineBlocks] that can be combined with other blocks to adapt to new workflows. The blocks are converted into a [ModularPipeline], a friendly user-facing interface for running generation tasks.

This guide shows you how to run a modular pipeline, understand its structure, and customize it by modifying the blocks that compose it.

Run a pipeline

[ModularPipeline] is the main interface for loading, running, and managing modular pipelines.

import torch
from diffusers import ModularPipeline

pipe = ModularPipeline.from_pretrained("Qwen/Qwen-Image")
pipe.load_components(torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe(
    prompt="cat wizard with red hat, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney",
).images[0]
image

[~ModularPipeline.from_pretrained] uses lazy loading: it reads the configuration to learn where to load each component from, but doesn't load any model weights until you call [~ModularPipeline.load_components]. This gives you control over when and how components are loaded (for example, which torch_dtype to use).
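To make the lazy-loading idea concrete, here is a minimal toy sketch in plain Python. ToyComponentSpec, ToyLazyPipeline, and fake_loader are hypothetical names invented for illustration, and the "placeholder/..." strings stand in for real repository locations; this is not how diffusers implements from_pretrained or load_components. It only shows the pattern: record where each component comes from up front, and load nothing until load_components() is called.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional

load_log: List[str] = []  # records every (fake) weight load, in order


def fake_loader(location: str) -> str:
    """Hypothetical loader: pretends to load weights and logs the call."""
    load_log.append(location)
    return f"weights from {location}"


@dataclass
class ToyComponentSpec:
    """Hypothetical record of where a component would be loaded from."""
    location: str
    loader: Callable[[str], Any] = fake_loader


class ToyLazyPipeline:
    """Toy illustration of lazy loading, not the diffusers implementation."""

    def __init__(self, specs: Dict[str, ToyComponentSpec]):
        self.specs = specs  # only the "where to load from" information is stored
        self.components: Dict[str, Optional[Any]] = {name: None for name in specs}

    def load_components(self) -> None:
        # Weights are loaded only when this method is called explicitly.
        for name, spec in self.specs.items():
            self.components[name] = spec.loader(spec.location)


pipe = ToyLazyPipeline({
    "vae": ToyComponentSpec("placeholder/vae"),
    "transformer": ToyComponentSpec("placeholder/transformer"),
})
print(load_log)  # [] because reading the specs loads nothing
pipe.load_components()
print(load_log)  # ['placeholder/vae', 'placeholder/transformer']
```

Separating "where from" (the specs) from "load now" (load_components) is what lets you pick settings such as the dtype at load time instead of at construction time.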

Learn more about creating and loading pipelines in the Creating a pipeline and Loading components guides.

Understand the structure

A [ModularPipeline] has two parts:

  • State: the loaded components (models, schedulers, processors) and configuration
  • Definition: the [ModularPipelineBlocks] that specify inputs, outputs, expected components and computation logic

The blocks define what the pipeline does. Access them through pipe.blocks.

print(pipe.blocks)
QwenImageAutoBlocks(
  Class: SequentialPipelineBlocks

  Description: Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.
      
      Supported workflows:
        - `text2image`: requires `prompt`
        - `image2image`: requires `prompt`, `image`
        - `inpainting`: requires `prompt`, `mask_image`, `image`
        - `controlnet_text2image`: requires `prompt`, `control_image`
        ...

  Components:
      text_encoder (`Qwen2_5_VLForConditionalGeneration`)
      vae (`AutoencoderKLQwenImage`)
      transformer (`QwenImageTransformer2DModel`)
      ...

  Sub-Blocks:
    [0] text_encoder (QwenImageAutoTextEncoderStep)
    [1] vae_encoder (QwenImageAutoVaeEncoderStep)
    [2] controlnet_vae_encoder (QwenImageOptionalControlNetVaeEncoderStep)
    [3] denoise (QwenImageAutoCoreDenoiseStep)
    [4] decode (QwenImageAutoDecodeStep)
)

The output shows:

  • The supported workflows (text2image, image2image, inpainting, etc.)
  • The components it expects (text_encoder, vae, transformer, etc.)
  • The sub-blocks it's composed of (text_encoder, vae_encoder, controlnet_vae_encoder, denoise, decode)
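The sub-block listing reflects how a sequential block runs: each sub-block reads values from a shared state and writes its outputs back for the next one. The toy sketch below models only that idea in plain Python. ToySequentialBlocks and the three step functions are hypothetical stand-ins (the "embeds(...)" strings replace real tensors), not the diffusers SequentialPipelineBlocks implementation.

```python
from typing import Callable, Dict, List, Tuple

State = Dict[str, object]
Step = Callable[[State], None]


class ToySequentialBlocks:
    """Hypothetical sequential container: runs named sub-blocks in order on one shared state."""

    def __init__(self, sub_blocks: List[Tuple[str, Step]]):
        self.sub_blocks: Dict[str, Step] = dict(sub_blocks)  # dicts preserve insertion order

    def __call__(self, **inputs) -> State:
        state: State = dict(inputs)
        for step in self.sub_blocks.values():
            step(state)  # each step reads what it needs and writes its outputs into state
        return state

    def __repr__(self) -> str:
        rows = [f"  [{i}] {name}" for i, name in enumerate(self.sub_blocks)]
        return "ToySequentialBlocks(\n" + "\n".join(rows) + "\n)"


def text_encoder(state: State) -> None:
    state["prompt_embeds"] = f"embeds({state['prompt']})"


def denoise(state: State) -> None:
    state["latents"] = f"denoised({state['prompt_embeds']})"


def decode(state: State) -> None:
    state["images"] = [f"image({state['latents']})"]


blocks = ToySequentialBlocks([("text_encoder", text_encoder), ("denoise", denoise), ("decode", decode)])
print(blocks)
result = blocks(prompt="cat wizard")
print(result["images"][0])  # image(denoised(embeds(cat wizard)))
```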

Workflows

QwenImageAutoBlocks is a [ConditionalPipelineBlocks], so this pipeline supports multiple workflows and adapts its behavior based on the inputs you provide. For example, if you pass image to the pipeline, it runs an image-to-image workflow instead of text-to-image.

from diffusers.utils import load_image

input_image = load_image("https://github.com/Trgtuan10/Image_storage/blob/main/cute_cat.png?raw=true")

image = pipe(
    prompt="cat wizard with red hat, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney",
    image=input_image,
).images[0]
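One way to picture how the pipeline adapts to its inputs is to match the provided arguments against the required inputs from the "Supported workflows" list printed earlier. The sketch below is a hypothetical toy dispatcher (pick_workflow is not a diffusers function) that picks the most specific workflow whose requirements are all met. The real selection logic in [ConditionalPipelineBlocks] may differ, and the table only covers the workflows shown before the "..." in that output.

```python
from typing import Dict, FrozenSet, Optional

# Required inputs per workflow, taken from the (truncated) "Supported workflows" list above.
WORKFLOW_REQUIREMENTS: Dict[str, FrozenSet[str]] = {
    "text2image": frozenset({"prompt"}),
    "image2image": frozenset({"prompt", "image"}),
    "inpainting": frozenset({"prompt", "mask_image", "image"}),
    "controlnet_text2image": frozenset({"prompt", "control_image"}),
}


def pick_workflow(**inputs) -> Optional[str]:
    """Toy rule: choose the workflow with the most requirements that are all satisfied.

    Arguments passed as None count as not provided. Ties (e.g. both image and
    control_image given) are broken by workflow name, which is arbitrary here.
    """
    provided = {key for key, value in inputs.items() if value is not None}
    candidates = [
        (len(required), name)
        for name, required in WORKFLOW_REQUIREMENTS.items()
        if required <= provided  # subset test: every required input is present
    ]
    if not candidates:
        return None
    return max(candidates)[1]


print(pick_workflow(prompt="cat wizard"))               # text2image
print(pick_workflow(prompt="cat wizard", image="img"))  # image2image
print(pick_workflow(prompt="cat wizard", image="img", mask_image="mask"))  # inpainting
```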

Use get_workflow() to extract the blocks for a specific workflow.

img2img_blocks = pipe.blocks.get_workflow("image2image")

Conditional blocks are convenient for users, but their conditional logic adds complexity when customizing or debugging. Extracting a workflow gives you the specific blocks relevant to your workflow, making it easier to work with. Learn more in the AutoPipelineBlocks guide.
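In toy terms, extracting a workflow amounts to resolving every conditional slot for one workflow name and dropping the slots that do not apply, which leaves a plain sequence with no branching. In the sketch below, TOY_AUTO_BLOCKS, the step names, and resolve_workflow are all hypothetical placeholders; they mimic the shape of the idea, not the diffusers API or the real QwenImage block classes.

```python
from typing import Dict, List, Optional, Union

# A slot is either a fixed block or a mapping from workflow name to block
# (None means the slot is skipped for that workflow). Hypothetical structure.
Slot = Union[str, Dict[str, Optional[str]]]

TOY_AUTO_BLOCKS: Dict[str, Slot] = {
    "text_encoder": "TextEncoderStep",
    "vae_encoder": {"text2image": None, "image2image": "VaeEncoderStep"},
    "denoise": {"text2image": "Text2ImageDenoiseStep", "image2image": "Image2ImageDenoiseStep"},
    "decode": "DecodeStep",
}


def resolve_workflow(blocks: Dict[str, Slot], workflow: str) -> List[str]:
    """Resolve each conditional slot for one workflow and drop slots that do not apply."""
    resolved: List[str] = []
    for name, slot in blocks.items():
        block = slot if isinstance(slot, str) else slot.get(workflow)
        if block is not None:
            resolved.append(f"{name}: {block}")
    return resolved


print(resolve_workflow(TOY_AUTO_BLOCKS, "image2image"))
# ['text_encoder: TextEncoderStep', 'vae_encoder: VaeEncoderStep',
#  'denoise: Image2ImageDenoiseStep', 'decode: DecodeStep']
print(resolve_workflow(TOY_AUTO_BLOCKS, "text2image"))
# ['text_encoder: TextEncoderStep', 'denoise: Text2ImageDenoiseStep', 'decode: DecodeStep']
```

The resolved list is easier to inspect and modify than the branching structure, which is the same reason extracting a workflow helps when customizing or debugging.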

Sub-blocks

QwenImageAutoBlocks is itself composed of smaller blocks: text_encoder, vae_encoder, controlnet_vae_encoder, denoise, and decode. Access them through the sub_blocks property.

The doc property is useful for seeing the full documentation of any block, including its inputs, outputs, and components.

vae_encoder_block = pipe.blocks.sub_blocks["vae_encoder"]
print(vae_encoder_block.doc)

This block can be converted to a pipeline and run on its own with [~ModularPipelineBlocks.init_pipeline].

vae_encoder_pipe = vae_encoder_block.init_pipeline()

# Reuse the VAE already loaded in `pipe` instead of loading it a second time
vae_encoder_pipe.update_components(vae=pipe.vae)

# Run just this block
image_latents = vae_encoder_pipe(image=input_image).image_latents
print(image_latents.shape)

Because update_components() reuses the VAE from the original pipeline instead of reloading it, only one copy of the VAE weights is kept in memory. Learn more in the Loading components guide.
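The memory saving comes from sharing one object rather than copying it: both pipelines end up holding a reference to the same VAE. The plain-Python sketch below illustrates that reference sharing with hypothetical ToyVAE and ToyPipeline classes. Its update_components method is named after the one used above but is a stand-in, not the diffusers implementation.

```python
class ToyVAE:
    """Placeholder for a large model; counts how many instances are created."""
    instances = 0

    def __init__(self):
        ToyVAE.instances += 1


class ToyPipeline:
    """Hypothetical pipeline that keeps its components in a dict."""

    def __init__(self, **components):
        self.components = dict(components)

    def update_components(self, **components):
        # Store references to existing objects; nothing is copied or reloaded.
        self.components.update(components)


full_pipe = ToyPipeline(vae=ToyVAE())
encoder_pipe = ToyPipeline()
encoder_pipe.update_components(vae=full_pipe.components["vae"])

print(encoder_pipe.components["vae"] is full_pipe.components["vae"])  # True
print(ToyVAE.instances)  # 1
```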

Since blocks are composable, you can modify the pipeline's definition by adding, removing, or swapping blocks to create new workflows. In the next section, we'll add a canny edge detection block to a ControlNet pipeline, so you can pass a regular image instead of a pre-processed canny edge map.

Compose new workflows

Let's add a canny edge detection block to a ControlNet pipeline. First, load a pre-built canny block from the Hub (see Building Custom Blocks to create your own).

from diffusers.modular_pipelines import ModularPipelineBlocks

# Load a canny block from the Hub
canny_block = ModularPipelineBlocks.from_pretrained(
    "diffusers-internal-dev/canny-filtering",
    trust_remote_code=True,
)

print(canny_block.doc)
class CannyBlock

  Inputs:
      image (`Union[Image, ndarray]`):
          Image to compute canny filter on
      low_threshold (`int`, *optional*, defaults to 50):
          Low threshold for the canny filter.
      high_threshold (`int`, *optional*, defaults to 200):
          High threshold for the canny filter.
      ...

  Outputs:
      control_image (`PIL.Image`):
          Canny map for input image
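The low_threshold and high_threshold inputs correspond to the two hysteresis thresholds of Canny edge detection: gradient magnitudes at or above the high threshold count as strong edges, and magnitudes between the two thresholds are kept only if they connect to a strong edge. The sketch below implements just that hysteresis step on a small, made-up gradient-magnitude grid, using the block's documented defaults of 50 and 200. It is not the code inside the Hub block; it skips the gradient computation and non-maximum suppression that precede hysteresis in a full Canny detector, and boundary conventions (>= versus >) differ between implementations.

```python
from collections import deque
from typing import Deque, List, Tuple


def hysteresis(mag: List[List[float]], low: float = 50, high: float = 200) -> List[List[int]]:
    """Hysteresis thresholding on a gradient-magnitude grid (1 = edge, 0 = no edge).

    Pixels >= high are strong edges. Pixels >= low are kept only if they are
    8-connected, possibly through other kept pixels, to a strong edge.
    """
    h, w = len(mag), len(mag[0])
    edges = [[0] * w for _ in range(h)]
    queue: Deque[Tuple[int, int]] = deque()

    # Seed the search with all strong edges.
    for y in range(h):
        for x in range(w):
            if mag[y][x] >= high:
                edges[y][x] = 1
                queue.append((y, x))

    # Grow from strong edges into connected weak pixels (breadth-first search).
    while queue:
        y, x = queue.popleft()
        for dy in (-1, 0, 1):
            for dx in (-1, 0, 1):
                ny, nx = y + dy, x + dx
                if 0 <= ny < h and 0 <= nx < w and not edges[ny][nx] and mag[ny][nx] >= low:
                    edges[ny][nx] = 1
                    queue.append((ny, nx))
    return edges


grid = [
    [250, 120, 10, 10],
    [10, 90, 10, 80],
    [10, 10, 10, 120],
]
print(hysteresis(grid))  # [[1, 1, 0, 0], [0, 1, 0, 0], [0, 0, 0, 0]]
```

The weak pixels on the right (80 and 120) are dropped because no strong edge touches them; lowering high_threshold to 100 would turn the 120 into a strong seed and keep both.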

Use get_workflow to extract the ControlNet workflow from [QwenImageAutoBlocks].

# Get the controlnet workflow that we want to work with
blocks = pipe.blocks.get_workflow("controlnet_text2image")
print(blocks.doc)
class SequentialPipelineBlocks

  Inputs:
      prompt (`str`):
          The prompt or prompts to guide image generation.
      control_image (`Image`):
          Control image for ControlNet conditioning.
      ...

It requires control_image as input. After inserting the canny block, the pipeline will accept a regular image instead.

# Insert the canny block at the beginning (index 0)
blocks.sub_blocks.insert("canny", canny_block, 0)

# Check the updated structure: CannyBlock is now listed as first sub-block
print(blocks)
# Check the updated doc: notice the pipeline now takes "image" as input
# even though it's a controlnet pipeline, because canny preprocesses it into control_image
print(blocks.doc)
class SequentialPipelineBlocks

  Inputs:
      image (`Union[Image, ndarray]`):
          Image to compute canny filter on
      low_threshold (`int`, *optional*, defaults to 50):
          Low threshold for the canny filter.
      high_threshold (`int`, *optional*, defaults to 200):
          High threshold for the canny filter.
      prompt (`str`):
          The prompt or prompts to guide image generation.
      ...

Now the pipeline takes image as input - the canny block will preprocess it into control_image automatically.
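The reason control_image disappears from the inputs is that, in a sequential pipeline, a value only needs to come from the user if no earlier block produces it. The sketch below computes user-facing inputs that way from hypothetical (name, inputs, outputs) signatures. The block names and input/output tuples are simplified placeholders loosely modeled on the doc output above, not the real QwenImage block definitions, and pipeline_inputs is not a diffusers function.

```python
from typing import List, Set, Tuple

# (block name, inputs it reads, outputs it writes); illustrative placeholders only.
BlockSig = Tuple[str, Tuple[str, ...], Tuple[str, ...]]


def pipeline_inputs(blocks: List[BlockSig]) -> List[str]:
    """Return the inputs a user must supply: those not produced by any earlier block."""
    produced: Set[str] = set()
    needed: List[str] = []
    for _name, inputs, outputs in blocks:
        for key in inputs:
            if key not in produced and key not in needed:
                needed.append(key)
        produced.update(outputs)
    return needed


controlnet_blocks: List[BlockSig] = [
    ("text_encoder", ("prompt",), ("prompt_embeds",)),
    ("controlnet_vae_encoder", ("control_image",), ("control_image_latents",)),
    ("denoise", ("prompt_embeds", "control_image_latents"), ("latents",)),
    ("decode", ("latents",), ("images",)),
]
print(pipeline_inputs(controlnet_blocks))  # ['prompt', 'control_image']

# Insert a canny-like block at index 0, mirroring sub_blocks.insert("canny", canny_block, 0).
canny: BlockSig = ("canny", ("image", "low_threshold", "high_threshold"), ("control_image",))
with_canny = list(controlnet_blocks)
with_canny.insert(0, canny)
print(pipeline_inputs(with_canny))  # ['image', 'low_threshold', 'high_threshold', 'prompt']
```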

Create a pipeline from the modified blocks and load a ControlNet model. We use [ComponentsManager] to enable CPU offloading for reduced memory usage (learn more in the ComponentsManager guide).

from diffusers import ComponentsManager

manager = ComponentsManager()
manager.enable_auto_cpu_offload(device="cuda:0")

pipeline = blocks.init_pipeline("Qwen/Qwen-Image", components_manager=manager)

pipeline.load_components(torch_dtype=torch.bfloat16)

# Load the ControlNet model
controlnet_spec = pipeline.get_component_spec("controlnet")
controlnet_spec.pretrained_model_name_or_path = "InstantX/Qwen-Image-ControlNet-Union"
controlnet = controlnet_spec.load(torch_dtype=torch.bfloat16)
pipeline.update_components(controlnet=controlnet)
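Conceptually, automatic CPU offloading keeps components in CPU memory and moves each one to the accelerator only when it is about to run. The toy simulation below tracks device strings instead of moving real tensors, and its policy (only the active component stays on the accelerator) is a deliberate simplification. ToyComponent and ToyOffloadManager are hypothetical, and the actual [ComponentsManager] strategy may keep several components resident depending on available memory.

```python
from typing import Dict, List


class ToyComponent:
    """Hypothetical component that only records which device it is on."""

    def __init__(self, name: str):
        self.name = name
        self.device = "cpu"


class ToyOffloadManager:
    """Toy offloading policy: before a component runs, move it to the accelerator
    and move every other managed component back to the CPU."""

    def __init__(self, device: str = "cuda:0"):
        self.device = device
        self.components: Dict[str, ToyComponent] = {}
        self.moves: List[str] = []  # log of simulated device transfers

    def add(self, component: ToyComponent) -> None:
        self.components[component.name] = component

    def use(self, name: str) -> ToyComponent:
        for other in self.components.values():
            if other.name != name and other.device != "cpu":
                other.device = "cpu"
                self.moves.append(f"{other.name} -> cpu")
        target = self.components[name]
        if target.device != self.device:
            target.device = self.device
            self.moves.append(f"{name} -> {self.device}")
        return target


manager = ToyOffloadManager(device="cuda:0")
for name in ["text_encoder", "controlnet", "transformer", "vae"]:
    manager.add(ToyComponent(name))

# Use the components in pipeline order; each use() triggers only the moves it needs.
for name in ["text_encoder", "controlnet", "transformer", "vae"]:
    manager.use(name)

print(manager.moves)
on_device = [c.name for c in manager.components.values() if c.device != "cpu"]
print(on_device)  # ['vae']
```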

Now run the pipeline - the canny block preprocesses the image for ControlNet.

from diffusers.utils import load_image

prompt = "cat wizard with red hat, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney"
image = load_image("https://github.com/Trgtuan10/Image_storage/blob/main/cute_cat.png?raw=true")

output = pipeline(
    prompt=prompt,
    image=image,
).images[0]
output

Next steps

Learn how to create your own blocks with custom logic in the Building Custom Blocks guide.

Use ComponentsManager to share models across multiple pipelines and manage memory efficiently.

Connect modular pipelines to Mellon, a visual node-based interface for building workflows. Custom blocks built with Modular Diffusers work out of the box with Mellon, with no UI code required. Read more in the Mellon guide.