
Merge branch 'main' into add-attentionmixin-qwen-image

Sayak Paul, 2025-08-22 21:04:20 +05:30, committed by GitHub

3 changed files with 82 additions and 6 deletions


@@ -316,6 +316,67 @@ if integrity_checker.test_image(image_):
raise ValueError("Your image has been flagged. Choose another prompt/image or try again.")
```
### Kontext Inpainting
`FluxKontextInpaintPipeline` enables image modification within a fixed mask region. It currently supports both text-based conditioning and image-reference conditioning.
<hfoptions id="kontext-inpaint">
<hfoption id="text-only">
```python
import torch
from diffusers import FluxKontextInpaintPipeline
from diffusers.utils import load_image
prompt = "Change the yellow dinosaur to green one"
img_url = (
"https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/dinosaur_input.jpeg?raw=true"
)
mask_url = (
"https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/dinosaur_mask.png?raw=true"
)
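# Download the source image and the binary mask that marks the region to repaint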
source = load_image(img_url)
mask = load_image(mask_url)
pipe = FluxKontextInpaintPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
)
pipe.to("cuda")
image = pipe(prompt=prompt, image=source, mask_image=mask, strength=1.0).images[0]
image.save("kontext_inpainting_normal.png")
```
</hfoption>
<hfoption id="image conditioning">
```python
import torch
from diffusers import FluxKontextInpaintPipeline
from diffusers.utils import load_image
pipe = FluxKontextInpaintPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
)
pipe.to("cuda")
prompt = "Replace this ball"
img_url = "https://images.pexels.com/photos/39362/the-ball-stadion-football-the-pitch-39362.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500"
mask_url = "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/ball_mask.png?raw=true"
image_reference_url = "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTah3x6OL_ECMBaZ5ZlJJhNsyC-OSMLWAI-xw&s"
source = load_image(img_url)
mask = load_image(mask_url)
image_reference = load_image(image_reference_url)
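# Soften the mask edges so the inpainted region blends smoothly with the surrounding image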
mask = pipe.mask_processor.blur(mask, blur_factor=12)
image = pipe(
prompt=prompt, image=source, mask_image=mask, image_reference=image_reference, strength=1.0
).images[0]
image.save("kontext_inpainting_ref.png")
```
</hfoption>
</hfoptions>
## Combining Flux Turbo LoRAs with Flux Control, Fill, and Redux
We can combine Flux Turbo LoRAs with Flux Control and other pipelines like Fill and Redux to enable few-step inference. The example below shows how to combine the Flux Control LoRA for depth with the turbo LoRA from [`ByteDance/Hyper-SD`](https://hf.co/ByteDance/Hyper-SD).
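A minimal sketch of the idea follows. The depth Control LoRA repository (`black-forest-labs/FLUX.1-Depth-dev-lora`), the 8-step LoRA file name from `ByteDance/Hyper-SD`, the `DepthPreprocessor` from the external `image_gen_aux` package, and the adapter weights are assumptions here; adjust them to your setup.
```python
import torch
from diffusers import FluxControlPipeline
from diffusers.utils import load_image
from huggingface_hub import hf_hub_download
from image_gen_aux import DepthPreprocessor  # assumed external preprocessor package

# Base Flux pipeline that accepts a control image.
pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Load the depth Control LoRA and the Hyper-SD turbo LoRA as named adapters (assumed repos/files).
pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora", adapter_name="depth")
pipe.load_lora_weights(
    hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"),
    adapter_name="turbo",
)
# Blend both adapters; the turbo LoRA typically works best with a small weight.
pipe.set_adapters(["depth", "turbo"], adapter_weights=[0.85, 0.125])

# Build the depth map that conditions generation (preprocessor checkpoint is an assumption).
processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
control_image = load_image("input.jpg")  # placeholder: any RGB image to extract depth from
control_image = processor(control_image)[0].convert("RGB")

# Few-step inference enabled by the turbo LoRA.
image = pipe(
    prompt="A robot made of exotic candies and chocolates",
    control_image=control_image,
    height=1024,
    width=1024,
    num_inference_steps=8,
    guidance_scale=10.0,
    generator=torch.Generator().manual_seed(0),
).images[0]
image.save("flux-control-depth-turbo.png")
```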
@@ -646,3 +707,15 @@ image.save("flux-fp8-dev.png")
[[autodoc]] FluxFillPipeline
- all
- __call__
## FluxKontextPipeline
[[autodoc]] FluxKontextPipeline
- all
- __call__
## FluxKontextInpaintPipeline
[[autodoc]] FluxKontextInpaintPipeline
- all
- __call__


@@ -162,6 +162,9 @@ Take a look at the [Quantization](./quantization/overview) section for more deta
## Optimizations
> [!TIP]
> Optimization is dependent on hardware specs such as memory. Use this [Space](https://huggingface.co/spaces/diffusers/optimized-diffusers-code) to generate code examples that include all of Diffusers' available memory and speed optimization techniques for any model you're using.
Modern diffusion models are very large and have billions of parameters. The iterative denoising process is also computationally intensive and slow. Diffusers provides techniques for reducing memory usage and boosting inference speed. These techniques can be combined with quantization to optimize for both memory usage and inference speed.
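As a small illustration, the sketch below combines one memory technique (model CPU offloading) with one speed technique (`torch.compile`); the checkpoint and prompt are only examples.
```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

# Memory: keep only the currently active component on the GPU, offloading the rest to the CPU.
pipe.enable_model_cpu_offload()

# Speed: compile the denoiser; the first call pays the compilation cost, later calls run faster.
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

image = pipe("a photo of an astronaut riding a horse on the moon").images[0]
image.save("optimized.png")
```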
### Memory usage


@@ -28,10 +28,10 @@ from diffusers import (
 )
 from diffusers.pipelines.bria import BriaPipeline
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     numpy_cosine_similarity_distance,
-    require_accelerator,
-    require_torch_gpu,
+    require_torch_accelerator,
     slow,
     torch_device,
 )
@@ -149,7 +149,7 @@ class BriaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     assert (output_height, output_width) == (expected_height, expected_width)

     @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
-    @require_accelerator
+    @require_torch_accelerator
     def test_save_load_float16(self, expected_max_diff=1e-2):
         components = self.get_dummy_components()
         for name, module in components.items():
@@ -237,7 +237,7 @@ class BriaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
 @slow
-@require_torch_gpu
+@require_torch_accelerator
 class BriaPipelineSlowTests(unittest.TestCase):
     pipeline_class = BriaPipeline
     repo_id = "briaai/BRIA-3.2"
@@ -245,12 +245,12 @@ class BriaPipelineSlowTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_inputs(self, device, seed=0):
         generator = torch.Generator(device="cpu").manual_seed(seed)