an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# End-to-End Developer Guide: Building with Modular Diffusers

[[open-in-colab]]
In this tutorial we will walk through the process of adding a new pipeline to the modular framework, using differential diffusion as our example. We'll cover the complete workflow from implementation to deployment: implementing the new pipeline, ensuring compatibility with existing tools, sharing the code on Hugging Face Hub, and deploying it as a UI node.

We'll also demonstrate the 4-step framework process we use for implementing new basic pipelines in the modular system.
1. **Start with an existing pipeline as a base**
   - Identify which existing pipeline is most similar to the one you want to implement
   - Determine which parts of the pipeline need modification

2. **Build a working pipeline structure first**
   - Assemble the complete pipeline structure
   - Use existing blocks wherever possible
   - For new blocks, create placeholders (e.g. you can copy from similar blocks and change the name) without implementing custom logic just yet

3. **Set up an example**
   - Create a simple inference script with expected inputs/outputs

4. **Implement your custom logic and test incrementally**
   - Add the custom logic to the blocks you want to change
   - Test incrementally, inspecting pipeline states and debugging as needed
Let's see how this works with the Differential Diffusion example.

## Differential Diffusion Pipeline

### Start with an existing pipeline

Differential diffusion (https://differential-diffusion.github.io/) is an image-to-image workflow, so it makes sense for us to start with the preset of pipeline blocks used to build the img2img pipeline (`IMAGE2IMAGE_BLOCKS`) and see how we can build this new pipeline with them.
```py
>>> from diffusers.modular_pipelines.stable_diffusion_xl import IMAGE2IMAGE_BLOCKS
>>> IMAGE2IMAGE_BLOCKS = InsertableDict([
...     ("text_encoder", StableDiffusionXLTextEncoderStep),
...     ("image_encoder", StableDiffusionXLVaeEncoderStep),
...     ("input", StableDiffusionXLInputStep),
...     ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
...     ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep),
...     ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
...     ("denoise", StableDiffusionXLDenoiseStep),
...     ("decode", StableDiffusionXLDecodeStep)
... ])
```
Note that "denoise" (`StableDiffusionXLDenoiseLoop`) is a loop that contains 3 loop blocks (more on SequentialLoopBlocks [here](https://colab.research.google.com/drive/1iVRjy_tOfmmm4gd0iVe0_Rl3c6cBzVqi?usp=sharing))
|
||||
Note that "denoise" (`StableDiffusionXLDenoiseStep`) is a `LoopSequentialPipelineBlocks` that contains 3 loop blocks (more on LoopSequentialPipelineBlocks [here](https://huggingface.co/docs/diffusers/modular_diffusers/write_own_pipeline_block#loopsequentialpipelineblocks))
|
||||
|
||||
```py
>>> denoise_blocks = IMAGE2IMAGE_BLOCKS["denoise"]()
>>> print(denoise_blocks)
```
```out
StableDiffusionXLDenoiseStep(
  Class: StableDiffusionXLDenoiseLoopWrapper

  Description: Denoise step that iteratively denoise the latents.
      ...
      - `StableDiffusionXLLoopBeforeDenoiser`
      - `StableDiffusionXLLoopDenoiser`
      - `StableDiffusionXLLoopAfterDenoiser`
      This block supports both text2img and img2img tasks.

  Components:
      ...
      guider (`ClassifierFreeGuidance`)
      unet (`UNet2DConditionModel`)

  Sub-Blocks:
    [0] before_denoiser (StableDiffusionXLLoopBeforeDenoiser)
       Description: step within the denoising loop that prepare the latent input for the denoiser. This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)

    ...
)
```
Let's compare standard image-to-image and differential diffusion. The img2img pipeline adds the same noise level across all pixels, based on a single `strength` parameter; differential diffusion instead uses a change map in which each pixel value determines when that region starts denoising. Regions with lower change map values get "frozen" earlier in the denoising process by replacing them with noised original latents, effectively giving them fewer denoising steps and thus preserving more of the original image.

Therefore, the key differences in the pipeline implementation are:

1. The `prepare_latents` step (which prepares the change map and pre-computes noised latents for all timesteps)
2. The `denoise` step (which selectively applies denoising based on the change map)
3. Since differential diffusion doesn't use the `strength` parameter, we'll use the text-to-image `set_timesteps` step instead of the image-to-image version

To implement differential diffusion, we can reuse most blocks from the image-to-image and text-to-image workflows, only modifying the `prepare_latents` step and the first part of the `denoise` step (i.e. `before_denoiser (StableDiffusionXLLoopBeforeDenoiser)`). The sketch below illustrates the change-map mechanics.
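To make that concrete, here is a minimal, self-contained sketch of how per-pixel change map values turn into per-step "keep denoising" masks. This is illustrative code with our own variable names, not the pipeline implementation:

```py
import torch

num_inference_steps = 4
# change map: 1.0 = change freely, small values = freeze early (keep the original)
change_map = torch.tensor([[1.00, 0.75],
                           [0.50, 0.25]])

# one threshold per denoising step: 0/4, 1/4, 2/4, 3/4
thresholds = torch.arange(num_inference_steps) / num_inference_steps

# masks[i] is True where a region is still being denoised at step i
masks = change_map > thresholds.view(-1, 1, 1)

print(masks[0])  # step 0: all regions are being denoised
print(masks[3])  # step 3: only the region with value 1.0 is still changing
```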
Differential diffusion shares the same overall pipeline structure as img2img. Here's a flowchart showing that structure and the changes we need to make:


### Build a Working Pipeline Structure

OK, now that we've identified the blocks to modify, let's build the pipeline skeleton first. At this stage, our goal is to get the pipeline structure working end-to-end (even though it still just does the img2img behavior). We can simply create placeholder blocks by copying from existing ones:
```py
>>> class SDXLDiffDiffPrepareLatentsStep(PipelineBlock):
...     # ... same implementation as StableDiffusionXLImg2ImgPrepareLatentsStep
>>>
>>> class SDXLDiffDiffLoopBeforeDenoiser(PipelineBlock):
...     # ... same implementation as StableDiffusionXLLoopBeforeDenoiser
```
`SDXLDiffDiffLoopBeforeDenoiser` is the part of the denoise loop we need to change. Let's use it to assemble a `SDXLDiffDiffDenoiseStep`.
```py
>>> class SDXLDiffDiffDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
...     block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLLoopDenoiser, StableDiffusionXLLoopAfterDenoiser]
...     block_names = ["before_denoiser", "denoiser", "after_denoiser"]
```

Now we can put together our differential diffusion pipeline.

```py
>>> DIFFDIFF_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()
>>> DIFFDIFF_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"]
>>> DIFFDIFF_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep
>>> DIFFDIFF_BLOCKS["denoise"] = SDXLDiffDiffDenoiseStep
>>>
>>> dd_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_BLOCKS)
>>> print(dd_blocks)
>>> # At this point, the pipeline works exactly like img2img since our blocks are just copies
```
### Set up an example

Now our blocks should compile without errors, so we can move on to the next step: let's set up a simple example that we can run as we build the pipeline. Diff-diff uses the same model checkpoints as SDXL, so we can fetch the models from a regular SDXL repo.
```py
>>> dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
>>> dd_pipeline.load_default_components(torch_dtype=torch.float16)
>>> dd_pipeline.to("cuda")
```
We will use this example script:
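A minimal sketch of the script follows; the asset URLs are placeholders, and `diffdiff_map` is the change map image (dark regions stay close to the original, bright regions are free to change):

```py
>>> import torch
>>> from diffusers.utils import load_image
>>>
>>> # placeholder assets: a photo of an apple and a left-to-right gradient change map
>>> image = load_image("https://example.com/apple.png")
>>> diffdiff_map = load_image("https://example.com/gradient_map.png")
>>>
>>> image = dd_pipeline(
...     prompt="a green pear",
...     negative_prompt="blurry",
...     image=image,
...     diffdiff_map=diffdiff_map,
...     output="images",
... )[0]
```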
If you run the script right now, you will get a complaint about the unexpected input `diffdiff_map`, and you will get the same result as the original img2img pipeline.
### Implement your custom logic and test incrementally

Let's modify the pipeline so that we get the expected result with this example script.
We'll start with the `prepare_latents` step, as it is the first step called right after the `input` step. The main changes are:

- a new user input `diffdiff_map`
- a new component `mask_processor` to process the `diffdiff_map`
- new intermediate inputs:
  - `timesteps` instead of `latent_timestep`, so we can pre-compute the noised latents for all timesteps
  - `num_inference_steps` to create the `diffdiff_masks`
- new intermediate outputs `diffdiff_masks` and `original_latents`
<Tip>

💡 It's helpful to test incrementally as you make changes. For example, after we add `diffdiff_map` as an input in this step, we can run `print(dd_pipeline)` to check that the new input shows up as expected.

</Tip>
Once we make sure all the variables we need are available in the block state, we can implement the diff-diff logic inside `__call__`. We create 2 new variables: the change map masks `diffdiff_masks` and the pre-computed noised latents for all timesteps, `original_latents` (sketched below).
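For intuition, pre-computing the noised originals for every timestep looks roughly like this standalone sketch. This is our own illustrative code using the scheduler's `add_noise`, not the block's actual implementation:

```py
import torch
from diffusers import EulerDiscreteScheduler

scheduler = EulerDiscreteScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps=4)

image_latents = torch.randn(1, 4, 64, 64)  # stand-in for the VAE-encoded input image
noise = torch.randn_like(image_latents)

# one noised copy of the original latents per timestep:
# original_latents[i] is what the unchanged image should look like at step i
original_latents = torch.stack(
    [scheduler.add_noise(image_latents, noise, t.unsqueeze(0)) for t in scheduler.timesteps]
)
print(original_latents.shape)  # torch.Size([4, 1, 4, 64, 64])
```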
Here are the key changes we made to implement differential diffusion:

**1. Modified `prepare_latents` step:**
```diff
class SDXLDiffDiffPrepareLatentsStep(PipelineBlock):
    model_name = "stable-diffusion-xl"

    @property
    def description(self) -> str:
        return (
-           "Step that prepares the latents for the image-to-image generation process"
+           "Step that prepares the latents for the differential diffusion generation process"
        )

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
+           InputParam("diffdiff_map", required=True),
        ]

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKL),
            ComponentSpec("scheduler", EulerDiscreteScheduler),
+           ComponentSpec(
+               "mask_processor",
+               VaeImageProcessor,
+               config=FrozenDict({"do_normalize": False, "do_convert_grayscale": True}),
+               default_creation_method="from_config",
+           )
        ]

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        return [
            InputParam("generator"),
-           InputParam("latent_timestep", required=True, type_hint=torch.Tensor),
+           InputParam("timesteps", type_hint=torch.Tensor),
+           InputParam("num_inference_steps", type_hint=int),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
+           OutputParam("original_latents", type_hint=torch.Tensor),
+           OutputParam("diffdiff_masks", type_hint=torch.Tensor),
        ]

    @torch.no_grad()
    def __call__(self, components, state: PipelineState):
        block_state = self.get_block_state(state)
        # ... existing logic ...

+       # process the change map and create the per-step masks
+       latent_height = block_state.image_latents.shape[-2]
+       latent_width = block_state.image_latents.shape[-1]
+       diffdiff_map = components.mask_processor.preprocess(block_state.diffdiff_map, height=latent_height, width=latent_width)
+       diffdiff_map = diffdiff_map.squeeze(0).to(block_state.device)
+       thresholds = torch.arange(block_state.num_inference_steps, dtype=diffdiff_map.dtype) / block_state.num_inference_steps
+       thresholds = thresholds.unsqueeze(1).unsqueeze(1).to(block_state.device)
+       block_state.diffdiff_masks = diffdiff_map > (thresholds + (block_state.denoising_start or 0))
+       block_state.original_latents = block_state.latents

        self.add_block_state(state, block_state)
```
Now let's modify the `before_denoiser` step, where we use the diff-diff masks to freeze certain regions in the latents before each denoising step.

**2. Modified `before_denoiser` step:**
```diff
class SDXLDiffDiffLoopBeforeDenoiser(PipelineBlock):
    model_name = "stable-diffusion-xl"

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop for differential diffusion that prepare the latent input for the denoiser"
        )

+   @property
+   def inputs(self) -> List[Tuple[str, Any]]:
+       return [
+           InputParam("denoising_start"),
+       ]

    @property
    def intermediate_inputs(self) -> List[str]:
        return [
            InputParam("latents", required=True, type_hint=torch.Tensor),
+           InputParam("original_latents", type_hint=torch.Tensor),
+           InputParam("diffdiff_masks", type_hint=torch.Tensor),
        ]

    @torch.no_grad()
    def __call__(self, components, block_state, i, t):
+       # apply differential diffusion: freeze regions by swapping in the noised original latents
+       if i == 0 and block_state.denoising_start is None:
+           block_state.latents = block_state.original_latents[:1]
+       else:
+           block_state.mask = block_state.diffdiff_masks[i].unsqueeze(0).unsqueeze(1)
+           # cast mask to the same type as latents
+           block_state.mask = block_state.mask.to(block_state.latents.dtype)
+           block_state.latents = block_state.original_latents[i] * block_state.mask + block_state.latents * (1 - block_state.mask)

        # ... rest of existing logic ...
        block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t)

        return components, block_state
```
That's all there is to it! We've just created a simple sequential pipeline by mixing and matching some existing and new pipeline blocks.
<Tip>

💡 You can inspect the pipeline you built with `print()`.

</Tip>
Now we use the process we prepared in step 2 to build the pipeline and inspect it.

```py
>>> dd_pipeline
SequentialPipelineBlocks(
  Class: ModularPipelineBlocks
  ...
    [5] prepare_add_cond (StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep)
       Description: Step that prepares the additional conditioning for the image-to-image/inpainting generation process

    [6] denoise (SDXLDiffDiffDenoiseStep)
       Description: Pipeline block that iteratively denoise the latents over `timesteps`. The specific steps with each iteration can be customized with `sub_blocks` attributes

    [7] decode (StableDiffusionXLDecodeStep)
  ...
)
```
Run the example now, and you should see an apple with its right half transformed into a green pear.


## Adding IP-adapter

We provide an auto IP-adapter block that you can plug-and-play into your modular workflow. It's an `AutoPipelineBlocks`, so it will only run when the user passes an IP adapter image. In this tutorial, we'll focus on how to package it into your differential diffusion workflow. To learn more about `AutoPipelineBlocks`, see [here](https://huggingface.co/docs/diffusers/modular_diffusers/write_own_pipeline_block#autopipelineblocks).

We talked about how to add IP-adapter into your workflow in the [getting-started guide](https://huggingface.co/docs/diffusers/modular_diffusers/quicktour#ip-adapter). Let's go ahead and create the IP-adapter block.
```py
>>> from diffusers.modular_pipelines.stable_diffusion_xl.encoders import StableDiffusionXLAutoIPAdapterStep
>>> ip_adapter_block = StableDiffusionXLAutoIPAdapterStep()
>>> print(ip_adapter_block)
```
It has 4 components: `unet` and `guider` are already used in diff-diff, but it also has two new ones: `image_encoder` and `feature_extractor`.
```out
ip adapter block: StableDiffusionXLAutoIPAdapterStep(
  Class: AutoPipelineBlocks

  ====================================================================================================
  This pipeline contains blocks that are selected at runtime based on inputs.
  Trigger Inputs: {'ip_adapter_image'}
  Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('ip_adapter_image')`).
  ====================================================================================================


  Description: Run IP Adapter step if `ip_adapter_image` is provided.


  Components:
      image_encoder (`CLIPVisionModelWithProjection`)
      feature_extractor (`CLIPImageProcessor`)
      unet (`UNet2DConditionModel`)
      guider (`ClassifierFreeGuidance`)

  Sub-Blocks:
    • ip_adapter [trigger: ip_adapter_image] (StableDiffusionXLIPAdapterStep)
       Description: IP Adapter step that handles all the ip adapter related tasks: Load/unload ip adapter weights into unet, prepare ip adapter image embeddings, etc. See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin) for more details

)
```
We can directly add the ip-adapter block instance to the `dd_blocks` we created before. The `sub_blocks` attribute is an `InsertableDict`, so we're able to insert it at a specific position (index `0` here), as sketched below.
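A minimal sketch of the insertion, assuming `InsertableDict.insert(name, block, index)` as the insertion API and the standard `h94/IP-Adapter` checkpoint layout:

```py
>>> dd_blocks.sub_blocks.insert("ip_adapter", ip_adapter_block, 0)
>>>
>>> # rebuild the pipeline, then load the IP-Adapter weights
>>> dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
>>> dd_pipeline.load_default_components(torch_dtype=torch.float16)
>>> dd_pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
>>> dd_pipeline.to("cuda")
```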
```out
SequentialPipelineBlocks(
  ...
    [6] prepare_add_cond (StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep)
       Description: Step that prepares the additional conditioning for the image-to-image/inpainting generation process

    [7] denoise (SDXLDiffDiffDenoiseStep)
       Description: Pipeline block that iteratively denoise the latents over `timesteps`. The specific steps with each iteration can be customized with `sub_blocks` attributes

    [8] decode (StableDiffusionXLDecodeStep)
  ...
)
```
## Adding ControlNet

From looking at the code workflow, differential diffusion only modifies the "before denoiser" part of the denoise loop, while ControlNet modifies the "denoiser" part (via `StableDiffusionXLControlNetLoopDenoiser`).

Intuitively, these two techniques are orthogonal and should combine naturally: differential diffusion controls how much the inference process can deviate from the original in each region, while ControlNet controls in what direction that change occurs.

With this understanding, let's assemble the `SDXLDiffDiffControlNetDenoiseStep`:
```py
>>> class SDXLDiffDiffControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
...     block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLControlNetLoopDenoiser, StableDiffusionXLLoopAfterDenoiser]
...     block_names = ["before_denoiser", "denoiser", "after_denoiser"]
>>>
>>> controlnet_denoise_block = SDXLDiffDiffControlNetDenoiseStep()
>>> # print(controlnet_denoise_block)
```
We provide an auto ControlNet input block that you can directly put into your workflow.
```py
>>> from diffusers.modular_pipelines.stable_diffusion_xl.modular_blocks_presets import StableDiffusionXLAutoControlNetInputStep
>>> control_input_block = StableDiffusionXLAutoControlNetInputStep()
>>> print(control_input_block)
```
```out
StableDiffusionXLAutoControlNetInputStep(
  Class: AutoPipelineBlocks

  ====================================================================================================
  This pipeline contains blocks that are selected at runtime based on inputs.
  Trigger Inputs: ['control_image', 'control_mode']
  ====================================================================================================


  Description: Controlnet Input step that prepare the controlnet input.
               This is an auto pipeline block that works for both controlnet and controlnet_union.
               (it should be called right before the denoise step)
               - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided.
               - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided.
               - if neither `control_mode` nor `control_image` is provided, step will be skipped.


  Components:
      controlnet (`ControlNetUnionModel`)
      control_image_processor (`VaeImageProcessor`)

  Sub-Blocks:
    • controlnet_union [trigger: control_mode] (StableDiffusionXLControlNetUnionInputStep)
       Description: step that prepares inputs for the ControlNetUnion model

    • controlnet [trigger: control_image] (StableDiffusionXLControlNetInputStep)
       Description: step that prepare inputs for controlnet

)
```
Let's assemble the blocks and run an example using ControlNet + differential diffusion, as sketched below. We use a tomato as the `control_image`, so in the output you can see that the right half that transformed into a pear has a tomato-like shape.
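A minimal sketch of the assembly and the call; the exact block order, the `insert` index, and the tomato `control_image` asset are assumptions:

```py
>>> DIFFDIFF_CN_BLOCKS = DIFFDIFF_BLOCKS.copy()
>>> DIFFDIFF_CN_BLOCKS["denoise"] = SDXLDiffDiffControlNetDenoiseStep
>>> DIFFDIFF_CN_BLOCKS.insert("controlnet_input", StableDiffusionXLAutoControlNetInputStep, 6)
>>>
>>> dd_pipeline = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_CN_BLOCKS).init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
>>> dd_pipeline.load_default_components(torch_dtype=torch.float16)
>>> dd_pipeline.to("cuda")
>>>
>>> control_image = load_image("https://example.com/tomato.png")  # placeholder asset
>>> image = dd_pipeline(
...     prompt="a green pear",
...     image=image,
...     diffdiff_map=diffdiff_map,
...     control_image=control_image,
...     output="images",
... )[0]
```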
Optionally, we can combine `SDXLDiffDiffControlNetDenoiseStep` and `SDXLDiffDiffDenoiseStep` into an `AutoPipelineBlocks` so that the same workflow can work with or without ControlNet.
```py
>>> class SDXLDiffDiffAutoDenoiseStep(AutoPipelineBlocks):
...     block_classes = [SDXLDiffDiffControlNetDenoiseStep, SDXLDiffDiffDenoiseStep]
...     block_names = ["controlnet_denoise", "denoise"]
...     block_trigger_inputs = ["controlnet_cond", None]
```
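With this setup, the ControlNet denoise path runs only when `controlnet_cond` is present in the pipeline state; otherwise the plain diff-diff denoise runs. A quick way to check which path a trigger input selects, per the help text printed by the auto blocks above (the exact call is a sketch):

```py
>>> auto_denoise_block = SDXLDiffDiffAutoDenoiseStep()
>>> print(auto_denoise_block.get_execution_blocks("controlnet_cond"))
```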
<Tip>

Note that it's perfectly fine not to use `AutoPipelineBlocks`. In fact, we recommend only using `AutoPipelineBlocks` to package your workflow at the end, once you've verified all your pipelines work as expected.

</Tip>