mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Flux Control(Depth/Canny) + Inpaint (#10192)
* flux_control_inpaint - failing test_flux_different_prompts * removing test_flux_different_prompts? * fix style * fix from PR comments * fix style * reducing guidance_scale in demo * Update src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py Co-authored-by: hlky <hlky@hlky.ac> * make * prepare_latents is not copied from * update docs * typos --------- Co-authored-by: affromero <ubuntu@ip-172-31-17-146.ec2.internal> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: hlky <hlky@hlky.ac>
This commit is contained in:
@@ -400,6 +400,8 @@
|
||||
title: DiT
|
||||
- local: api/pipelines/flux
|
||||
title: Flux
|
||||
- local: api/pipelines/control_flux_inpaint
|
||||
title: FluxControlInpaint
|
||||
- local: api/pipelines/hunyuandit
|
||||
title: Hunyuan-DiT
|
||||
- local: api/pipelines/hunyuan_video
|
||||
|
||||
89
docs/source/en/api/pipelines/control_flux_inpaint.md
Normal file
89
docs/source/en/api/pipelines/control_flux_inpaint.md
Normal file
@@ -0,0 +1,89 @@
|
||||
<!--Copyright 2024 The HuggingFace Team, The Black Forest Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# FluxControlInpaint
|
||||
|
||||
FluxControlInpaintPipeline is an implementation of Inpainting for Flux.1 Depth/Canny models. It is a pipeline that allows you to inpaint images using the Flux.1 Depth/Canny models. The pipeline takes an image and a mask as input and returns the inpainted image.
|
||||
|
||||
FLUX.1 Depth and Canny [dev] is a 12 billion parameter rectified flow transformer capable of generating an image based on a text description while following the structure of a given input image. **This is not a ControlNet model**.
|
||||
|
||||
| Control type | Developer | Link |
|
||||
| -------- | ---------- | ---- |
|
||||
| Depth | [Black Forest Labs](https://huggingface.co/black-forest-labs) | [Link](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev) |
|
||||
| Canny | [Black Forest Labs](https://huggingface.co/black-forest-labs) | [Link](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev) |
|
||||
|
||||
|
||||
<Tip>
|
||||
|
||||
Flux can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more. For an exhaustive list of resources, check out [this gist](https://gist.github.com/sayakpaul/b664605caf0aa3bf8585ab109dd5ac9c).
|
||||
|
||||
</Tip>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxControlInpaintPipeline
|
||||
from diffusers.models.transformers import FluxTransformer2DModel
|
||||
from transformers import T5EncoderModel
|
||||
from diffusers.utils import load_image, make_image_grid
|
||||
from image_gen_aux import DepthPreprocessor # https://github.com/huggingface/image_gen_aux
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
pipe = FluxControlInpaintPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-Depth-dev",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
# use following lines if you have GPU constraints
|
||||
# ---------------------------------------------------------------
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
"sayakpaul/FLUX.1-Depth-dev-nf4", subfolder="transformer", torch_dtype=torch.bfloat16
|
||||
)
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained(
|
||||
"sayakpaul/FLUX.1-Depth-dev-nf4", subfolder="text_encoder_2", torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe.transformer = transformer
|
||||
pipe.text_encoder_2 = text_encoder_2
|
||||
pipe.enable_model_cpu_offload()
|
||||
# ---------------------------------------------------------------
|
||||
pipe.to("cuda")
|
||||
|
||||
prompt = "a blue robot singing opera with human-like expressions"
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
|
||||
|
||||
head_mask = np.zeros_like(image)
|
||||
head_mask[65:580,300:642] = 255
|
||||
mask_image = Image.fromarray(head_mask)
|
||||
|
||||
processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
|
||||
control_image = processor(image)[0].convert("RGB")
|
||||
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
image=image,
|
||||
control_image=control_image,
|
||||
mask_image=mask_image,
|
||||
num_inference_steps=30,
|
||||
strength=0.9,
|
||||
guidance_scale=10.0,
|
||||
generator=torch.Generator().manual_seed(42),
|
||||
).images[0]
|
||||
make_image_grid([image, control_image, mask_image, output.resize(image.size)], rows=1, cols=4).save("output.png")
|
||||
```
|
||||
|
||||
## FluxControlInpaintPipeline
|
||||
[[autodoc]] FluxControlInpaintPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
|
||||
## FluxPipelineOutput
|
||||
[[autodoc]] pipelines.flux.pipeline_output.FluxPipelineOutput
|
||||
@@ -277,6 +277,7 @@ else:
|
||||
"CogView3PlusPipeline",
|
||||
"CycleDiffusionPipeline",
|
||||
"FluxControlImg2ImgPipeline",
|
||||
"FluxControlInpaintPipeline",
|
||||
"FluxControlNetImg2ImgPipeline",
|
||||
"FluxControlNetInpaintPipeline",
|
||||
"FluxControlNetPipeline",
|
||||
@@ -765,6 +766,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
CogView3PlusPipeline,
|
||||
CycleDiffusionPipeline,
|
||||
FluxControlImg2ImgPipeline,
|
||||
FluxControlInpaintPipeline,
|
||||
FluxControlNetImg2ImgPipeline,
|
||||
FluxControlNetInpaintPipeline,
|
||||
FluxControlNetPipeline,
|
||||
|
||||
@@ -128,6 +128,7 @@ else:
|
||||
]
|
||||
_import_structure["flux"] = [
|
||||
"FluxControlPipeline",
|
||||
"FluxControlInpaintPipeline",
|
||||
"FluxControlImg2ImgPipeline",
|
||||
"FluxControlNetPipeline",
|
||||
"FluxControlNetImg2ImgPipeline",
|
||||
@@ -539,6 +540,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
)
|
||||
from .flux import (
|
||||
FluxControlImg2ImgPipeline,
|
||||
FluxControlInpaintPipeline,
|
||||
FluxControlNetImg2ImgPipeline,
|
||||
FluxControlNetInpaintPipeline,
|
||||
FluxControlNetPipeline,
|
||||
|
||||
@@ -26,6 +26,7 @@ else:
|
||||
_import_structure["pipeline_flux"] = ["FluxPipeline"]
|
||||
_import_structure["pipeline_flux_control"] = ["FluxControlPipeline"]
|
||||
_import_structure["pipeline_flux_control_img2img"] = ["FluxControlImg2ImgPipeline"]
|
||||
_import_structure["pipeline_flux_control_inpaint"] = ["FluxControlInpaintPipeline"]
|
||||
_import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"]
|
||||
_import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"]
|
||||
_import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"]
|
||||
@@ -44,6 +45,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pipeline_flux import FluxPipeline
|
||||
from .pipeline_flux_control import FluxControlPipeline
|
||||
from .pipeline_flux_control_img2img import FluxControlImg2ImgPipeline
|
||||
from .pipeline_flux_control_inpaint import FluxControlInpaintPipeline
|
||||
from .pipeline_flux_controlnet import FluxControlNetPipeline
|
||||
from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline
|
||||
from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline
|
||||
|
||||
1141
src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py
Normal file
1141
src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -392,6 +392,21 @@ class FluxControlImg2ImgPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class FluxControlInpaintPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class FluxControlNetImg2ImgPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
215
tests/pipelines/flux/test_pipeline_flux_control_inpaint.py
Normal file
215
tests/pipelines/flux/test_pipeline_flux_control_inpaint.py
Normal file
@@ -0,0 +1,215 @@
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
FluxControlInpaintPipeline,
|
||||
FluxTransformer2DModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_pipelines_common import (
|
||||
PipelineTesterMixin,
|
||||
check_qkv_fusion_matches_attn_procs_length,
|
||||
check_qkv_fusion_processors_exist,
|
||||
)
|
||||
|
||||
|
||||
class FluxControlInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
pipeline_class = FluxControlInpaintPipeline
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
|
||||
batch_params = frozenset(["prompt"])
|
||||
|
||||
# there is no xformers processor for Flux
|
||||
test_xformers_attention = False
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
transformer = FluxTransformer2DModel(
|
||||
patch_size=1,
|
||||
in_channels=8,
|
||||
out_channels=4,
|
||||
num_layers=1,
|
||||
num_single_layers=1,
|
||||
attention_head_dim=16,
|
||||
num_attention_heads=2,
|
||||
joint_attention_dim=32,
|
||||
pooled_projection_dim=32,
|
||||
axes_dims_rope=[4, 4, 8],
|
||||
)
|
||||
clip_text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
hidden_act="gelu",
|
||||
projection_dim=32,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder = CLIPTextModel(clip_text_encoder_config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
sample_size=32,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
block_out_channels=(4,),
|
||||
layers_per_block=1,
|
||||
latent_channels=1,
|
||||
norm_num_groups=1,
|
||||
use_quant_conv=False,
|
||||
use_post_quant_conv=False,
|
||||
shift_factor=0.0609,
|
||||
scaling_factor=1.5035,
|
||||
)
|
||||
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
|
||||
return {
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"text_encoder_2": text_encoder_2,
|
||||
"tokenizer": tokenizer,
|
||||
"tokenizer_2": tokenizer_2,
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device="cpu").manual_seed(seed)
|
||||
|
||||
image = Image.new("RGB", (8, 8), 0)
|
||||
control_image = Image.new("RGB", (8, 8), 0)
|
||||
mask_image = Image.new("RGB", (8, 8), 255)
|
||||
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"control_image": control_image,
|
||||
"generator": generator,
|
||||
"image": image,
|
||||
"mask_image": mask_image,
|
||||
"strength": 0.8,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 30.0,
|
||||
"height": 8,
|
||||
"width": 8,
|
||||
"max_sequence_length": 48,
|
||||
"output_type": "np",
|
||||
}
|
||||
return inputs
|
||||
|
||||
# def test_flux_different_prompts(self):
|
||||
# pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
|
||||
# inputs = self.get_dummy_inputs(torch_device)
|
||||
# output_same_prompt = pipe(**inputs).images[0]
|
||||
|
||||
# inputs = self.get_dummy_inputs(torch_device)
|
||||
# inputs["prompt_2"] = "a different prompt"
|
||||
# output_different_prompts = pipe(**inputs).images[0]
|
||||
|
||||
# max_diff = np.abs(output_same_prompt - output_different_prompts).max()
|
||||
|
||||
# # Outputs should be different here
|
||||
# # For some reasons, they don't show large differences
|
||||
# assert max_diff > 1e-6
|
||||
|
||||
def test_flux_prompt_embeds(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
output_with_prompt = pipe(**inputs).images[0]
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
prompt = inputs.pop("prompt")
|
||||
|
||||
(prompt_embeds, pooled_prompt_embeds, text_ids) = pipe.encode_prompt(
|
||||
prompt,
|
||||
prompt_2=None,
|
||||
device=torch_device,
|
||||
max_sequence_length=inputs["max_sequence_length"],
|
||||
)
|
||||
output_with_embeds = pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
**inputs,
|
||||
).images[0]
|
||||
|
||||
max_diff = np.abs(output_with_prompt - output_with_embeds).max()
|
||||
assert max_diff < 1e-4
|
||||
|
||||
def test_fused_qkv_projections(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
original_image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
|
||||
# to the pipeline level.
|
||||
pipe.transformer.fuse_qkv_projections()
|
||||
assert check_qkv_fusion_processors_exist(
|
||||
pipe.transformer
|
||||
), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
|
||||
assert check_qkv_fusion_matches_attn_procs_length(
|
||||
pipe.transformer, pipe.transformer.original_attn_processors
|
||||
), "Something wrong with the attention processors concerning the fused QKV projections."
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
image_slice_fused = image[0, -3:, -3:, -1]
|
||||
|
||||
pipe.transformer.unfuse_qkv_projections()
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
image_slice_disabled = image[0, -3:, -3:, -1]
|
||||
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
|
||||
), "Fusion of QKV projections shouldn't affect the outputs."
|
||||
assert np.allclose(
|
||||
image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
|
||||
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
|
||||
), "Original outputs should match when fused QKV projections are disabled."
|
||||
|
||||
def test_flux_image_output_shape(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
height_width_pairs = [(32, 32), (72, 57)]
|
||||
for height, width in height_width_pairs:
|
||||
expected_height = height - height % (pipe.vae_scale_factor * 2)
|
||||
expected_width = width - width % (pipe.vae_scale_factor * 2)
|
||||
|
||||
inputs.update({"height": height, "width": width})
|
||||
image = pipe(**inputs).images[0]
|
||||
output_height, output_width, _ = image.shape
|
||||
assert (output_height, output_width) == (expected_height, expected_width)
|
||||
Reference in New Issue
Block a user