1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Add VisualCloze (#11377)

* VisualCloze

* style quality

* add docs

* add docs

* typo

* Update docs/source/en/api/pipelines/visualcloze.md

* delete einops

* style quality

* Update src/diffusers/pipelines/visualcloze/pipeline_visualcloze.py

* reorg

* refine doc

* style quality

* typo

* typo

* Update src/diffusers/image_processor.py

* add comment

* test

* style

* Modified based on review

* style

* restore image_processor

* update example url

* style

* fix-copies

* VisualClozeGenerationPipeline

* combine

* tests docs

* remove VisualClozeUpsamplingPipeline

* style

* quality

* test examples

* quality style

* typo

* make fix-copies

* fix test_callback_cfg and test_save_load_dduf in VisualClozePipelineFastTests

* add EXAMPLE_DOC_STRING to VisualClozeGenerationPipeline

* delete maybe_free_model_hooks from pipeline_visualcloze_combined

* Apply suggestions from code review

* fix test_save_load_local test; add reason for skipping cfg test

* more save_load test fixes

* fix tests in generation pipeline tests
This commit is contained in:
Zhong-Yu Li
2025-05-13 05:16:51 +08:00
committed by GitHub
parent 98cc6d05e4
commit 4f438de35a
13 changed files with 2694 additions and 0 deletions

View File

@@ -575,6 +575,8 @@
title: UniDiffuser
- local: api/pipelines/value_guided_sampling
title: Value-guided sampling
- local: api/pipelines/visualcloze
title: VisualCloze
- local: api/pipelines/wan
title: Wan
- local: api/pipelines/wuerstchen

View File

@@ -89,6 +89,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
| [UniDiffuser](unidiffuser) | text2image, image2text, image variation, text variation, unconditional image generation, unconditional audio generation |
| [Value-guided planning](value_guided_sampling) | value guided sampling |
| [Wuerstchen](wuerstchen) | text2image |
| [VisualCloze](visualcloze) | text2image, image2image, subject driven generation, inpainting, style transfer, image restoration, image editing, [depth,normal,edge,pose]2image, [depth,normal,edge,pose]-estimation, virtual try-on, image relighting |
## DiffusionPipeline

View File

@@ -0,0 +1,300 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-->
# VisualCloze
[VisualCloze: A Universal Image Generation Framework via Visual In-Context Learning](https://arxiv.org/abs/2504.07960) is an innovative in-context learning based universal image generation framework that offers key capabilities:
1. Support for various in-domain tasks
2. Generalization to unseen tasks through in-context learning
3. Unify multiple tasks into one step and generate both target image and intermediate results
4. Support reverse-engineering conditions from target images
## Overview
The abstract from the paper is:
*Recent progress in diffusion models significantly advances various image generation tasks. However, the current mainstream approach remains focused on building task-specific models, which have limited efficiency when supporting a wide range of different needs. While universal models attempt to address this limitation, they face critical challenges, including generalizable task instruction, appropriate task distributions, and unified architectural design. To tackle these challenges, we propose VisualCloze, a universal image generation framework, which supports a wide range of in-domain tasks, generalization to unseen ones, unseen unification of multiple tasks, and reverse generation. Unlike existing methods that rely on language-based task instruction, leading to task ambiguity and weak generalization, we integrate visual in-context learning, allowing models to identify tasks from visual demonstrations. Meanwhile, the inherent sparsity of visual task distributions hampers the learning of transferable knowledge across tasks. To this end, we introduce Graph200K, a graph-structured dataset that establishes various interrelated tasks, enhancing task density and transferable knowledge. Furthermore, we uncover that our unified image generation formulation shared a consistent objective with image infilling, enabling us to leverage the strong generative priors of pre-trained infilling models without modifying the architectures. The codes, dataset, and models are available at https://visualcloze.github.io.*
## Inference
### Model loading
VisualCloze is a two-stage cascade pipeline, containing `VisualClozeGenerationPipeline` and `VisualClozeUpsamplingPipeline`.
- In `VisualClozeGenerationPipeline`, each image is downsampled before concatenating images into a grid layout, avoiding excessively high resolutions. VisualCloze releases two models suitable for diffusers, i.e., [VisualClozePipeline-384](https://huggingface.co/VisualCloze/VisualClozePipeline-384) and [VisualClozePipeline-512](https://huggingface.co/VisualCloze/VisualClozePipeline-384), which downsample images to resolutions of 384 and 512, respectively.
- `VisualClozeUpsamplingPipeline` uses [SDEdit](https://arxiv.org/abs/2108.01073) to enable high-resolution image synthesis.
The `VisualClozePipeline` integrates both stages to support convenient end-to-end sampling, while also allowing users to utilize each pipeline independently as needed.
### Input Specifications
#### Task and Content Prompts
- Task prompt: Required to describe the generation task intention
- Content prompt: Optional description or caption of the target image
- When content prompt is not needed, pass `None`
- For batch inference, pass `List[str|None]`
#### Image Input Format
- Format: `List[List[Image|None]]`
- Structure:
- All rows except the last represent in-context examples
- Last row represents the current query (target image set to `None`)
- For batch inference, pass `List[List[List[Image|None]]]`
#### Resolution Control
- Default behavior:
- Initial generation in the first stage: area of ${pipe.resolution}^2$
- Upsampling in the second stage: 3x factor
- Custom resolution: Adjust using `upsampling_height` and `upsampling_width` parameters
### Examples
For comprehensive examples covering a wide range of tasks, please refer to the [Online Demo](https://huggingface.co/spaces/VisualCloze/VisualCloze) and [GitHub Repository](https://github.com/lzyhha/VisualCloze). Below are simple examples for three cases: mask-to-image conversion, edge detection, and subject-driven generation.
#### Example for mask2image
```python
import torch
from diffusers import VisualClozePipeline
from diffusers.utils import load_image
pipe = VisualClozePipeline.from_pretrained("VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Load in-context images (make sure the paths are correct and accessible)
image_paths = [
# in-context examples
[
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_mask.jpg'),
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_image.jpg'),
],
# query with the target image
[
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_query_mask.jpg'),
None, # No image needed for the target image
],
]
# Task and content prompt
task_prompt = "In each row, a logical task is demonstrated to achieve [IMAGE2] an aesthetically pleasing photograph based on [IMAGE1] sam 2-generated masks with rich color coding."
content_prompt = """Majestic photo of a golden eagle perched on a rocky outcrop in a mountainous landscape.
The eagle is positioned in the right foreground, facing left, with its sharp beak and keen eyes prominently visible.
Its plumage is a mix of dark brown and golden hues, with intricate feather details.
The background features a soft-focus view of snow-capped mountains under a cloudy sky, creating a serene and grandiose atmosphere.
The foreground includes rugged rocks and patches of green moss. Photorealistic, medium depth of field,
soft natural lighting, cool color palette, high contrast, sharp focus on the eagle, blurred background,
tranquil, majestic, wildlife photography."""
# Run the pipeline
image_result = pipe(
task_prompt=task_prompt,
content_prompt=content_prompt,
image=image_paths,
upsampling_width=1344,
upsampling_height=768,
upsampling_strength=0.4,
guidance_scale=30,
num_inference_steps=30,
max_sequence_length=512,
generator=torch.Generator("cpu").manual_seed(0)
).images[0][0]
# Save the resulting image
image_result.save("visualcloze.png")
```
#### Example for edge-detection
```python
import torch
from diffusers import VisualClozePipeline
from diffusers.utils import load_image
pipe = VisualClozePipeline.from_pretrained("VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Load in-context images (make sure the paths are correct and accessible)
image_paths = [
# in-context examples
[
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_edgedetection_incontext-example-1_image.jpg'),
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_edgedetection_incontext-example-1_edge.jpg'),
],
[
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_edgedetection_incontext-example-2_image.jpg'),
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_edgedetection_incontext-example-2_edge.jpg'),
],
# query with the target image
[
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_edgedetection_query_image.jpg'),
None, # No image needed for the target image
],
]
# Task and content prompt
task_prompt = "Each row illustrates a pathway from [IMAGE1] a sharp and beautifully composed photograph to [IMAGE2] edge map with natural well-connected outlines using a clear logical task."
content_prompt = ""
# Run the pipeline
image_result = pipe(
task_prompt=task_prompt,
content_prompt=content_prompt,
image=image_paths,
upsampling_width=864,
upsampling_height=1152,
upsampling_strength=0.4,
guidance_scale=30,
num_inference_steps=30,
max_sequence_length=512,
generator=torch.Generator("cpu").manual_seed(0)
).images[0][0]
# Save the resulting image
image_result.save("visualcloze.png")
```
#### Example for subject-driven generation
```python
import torch
from diffusers import VisualClozePipeline
from diffusers.utils import load_image
pipe = VisualClozePipeline.from_pretrained("VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Load in-context images (make sure the paths are correct and accessible)
image_paths = [
# in-context examples
[
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-1_reference.jpg'),
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-1_depth.jpg'),
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-1_image.jpg'),
],
[
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-2_reference.jpg'),
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-2_depth.jpg'),
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-2_image.jpg'),
],
# query with the target image
[
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_query_reference.jpg'),
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_query_depth.jpg'),
None, # No image needed for the target image
],
]
# Task and content prompt
task_prompt = """Each row describes a process that begins with [IMAGE1] an image containing the key object,
[IMAGE2] depth map revealing gray-toned spatial layers and results in
[IMAGE3] an image with artistic qualitya high-quality image with exceptional detail."""
content_prompt = """A vintage porcelain collector's item. Beneath a blossoming cherry tree in early spring,
this treasure is photographed up close, with soft pink petals drifting through the air and vibrant blossoms framing the scene."""
# Run the pipeline
image_result = pipe(
task_prompt=task_prompt,
content_prompt=content_prompt,
image=image_paths,
upsampling_width=1024,
upsampling_height=1024,
upsampling_strength=0.2,
guidance_scale=30,
num_inference_steps=30,
max_sequence_length=512,
generator=torch.Generator("cpu").manual_seed(0)
).images[0][0]
# Save the resulting image
image_result.save("visualcloze.png")
```
#### Utilize each pipeline independently
```python
import torch
from diffusers import VisualClozeGenerationPipeline, FluxFillPipeline as VisualClozeUpsamplingPipeline
from diffusers.utils import load_image
from PIL import Image
pipe = VisualClozeGenerationPipeline.from_pretrained(
"VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16
)
pipe.to("cuda")
image_paths = [
# in-context examples
[
load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_mask.jpg"
),
load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_image.jpg"
),
],
# query with the target image
[
load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_query_mask.jpg"
),
None, # No image needed for the target image
],
]
task_prompt = "In each row, a logical task is demonstrated to achieve [IMAGE2] an aesthetically pleasing photograph based on [IMAGE1] sam 2-generated masks with rich color coding."
content_prompt = "Majestic photo of a golden eagle perched on a rocky outcrop in a mountainous landscape. The eagle is positioned in the right foreground, facing left, with its sharp beak and keen eyes prominently visible. Its plumage is a mix of dark brown and golden hues, with intricate feather details. The background features a soft-focus view of snow-capped mountains under a cloudy sky, creating a serene and grandiose atmosphere. The foreground includes rugged rocks and patches of green moss. Photorealistic, medium depth of field, soft natural lighting, cool color palette, high contrast, sharp focus on the eagle, blurred background, tranquil, majestic, wildlife photography."
# Stage 1: Generate initial image
image = pipe(
task_prompt=task_prompt,
content_prompt=content_prompt,
image=image_paths,
guidance_scale=30,
num_inference_steps=30,
max_sequence_length=512,
generator=torch.Generator("cpu").manual_seed(0),
).images[0][0]
# Stage 2 (optional): Upsample the generated image
pipe_upsample = VisualClozeUpsamplingPipeline.from_pipe(pipe)
pipe_upsample.to("cuda")
mask_image = Image.new("RGB", image.size, (255, 255, 255))
image = pipe_upsample(
image=image,
mask_image=mask_image,
prompt=content_prompt,
width=1344,
height=768,
strength=0.4,
guidance_scale=30,
num_inference_steps=30,
max_sequence_length=512,
generator=torch.Generator("cpu").manual_seed(0),
).images[0]
image.save("visualcloze.png")
```
## VisualClozePipeline
[[autodoc]] VisualClozePipeline
- all
- __call__
## VisualClozeGenerationPipeline
[[autodoc]] VisualClozeGenerationPipeline
- all
- __call__

View File

@@ -520,6 +520,8 @@ else:
"VersatileDiffusionPipeline",
"VersatileDiffusionTextToImagePipeline",
"VideoToVideoSDPipeline",
"VisualClozeGenerationPipeline",
"VisualClozePipeline",
"VQDiffusionPipeline",
"WanImageToVideoPipeline",
"WanPipeline",
@@ -1100,6 +1102,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
VersatileDiffusionPipeline,
VersatileDiffusionTextToImagePipeline,
VideoToVideoSDPipeline,
VisualClozeGenerationPipeline,
VisualClozePipeline,
VQDiffusionPipeline,
WanImageToVideoPipeline,
WanPipeline,

View File

@@ -281,6 +281,7 @@ else:
_import_structure["mochi"] = ["MochiPipeline"]
_import_structure["musicldm"] = ["MusicLDMPipeline"]
_import_structure["omnigen"] = ["OmniGenPipeline"]
_import_structure["visualcloze"] = ["VisualClozePipeline", "VisualClozeGenerationPipeline"]
_import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
_import_structure["pia"] = ["PIAPipeline"]
_import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"]
@@ -727,6 +728,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
UniDiffuserPipeline,
UniDiffuserTextDecoder,
)
from .visualcloze import VisualClozeGenerationPipeline, VisualClozePipeline
from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline
from .wuerstchen import (
WuerstchenCombinedPipeline,

View File

@@ -0,0 +1,52 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_visualcloze_combined"] = ["VisualClozePipeline"]
_import_structure["pipeline_visualcloze_generation"] = ["VisualClozeGenerationPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_visualcloze_combined import VisualClozePipeline
from .pipeline_visualcloze_generation import VisualClozeGenerationPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -0,0 +1,444 @@
# Copyright 2025 VisualCloze team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
is_torch_xla_available,
logging,
replace_example_docstring,
)
from ..flux.pipeline_flux_fill import FluxFillPipeline as VisualClozeUpsamplingPipeline
from ..flux.pipeline_output import FluxPipelineOutput
from ..pipeline_utils import DiffusionPipeline
from .pipeline_visualcloze_generation import VisualClozeGenerationPipeline
if is_torch_xla_available():
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers import VisualClozePipeline
>>> from diffusers.utils import load_image
>>> image_paths = [
... # in-context examples
... [
... load_image(
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_mask.jpg"
... ),
... load_image(
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_image.jpg"
... ),
... ],
... # query with the target image
... [
... load_image(
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_query_mask.jpg"
... ),
... None, # No image needed for the target image
... ],
... ]
>>> task_prompt = "In each row, a logical task is demonstrated to achieve [IMAGE2] an aesthetically pleasing photograph based on [IMAGE1] sam 2-generated masks with rich color coding."
>>> content_prompt = "Majestic photo of a golden eagle perched on a rocky outcrop in a mountainous landscape. The eagle is positioned in the right foreground, facing left, with its sharp beak and keen eyes prominently visible. Its plumage is a mix of dark brown and golden hues, with intricate feather details. The background features a soft-focus view of snow-capped mountains under a cloudy sky, creating a serene and grandiose atmosphere. The foreground includes rugged rocks and patches of green moss. Photorealistic, medium depth of field, soft natural lighting, cool color palette, high contrast, sharp focus on the eagle, blurred background, tranquil, majestic, wildlife photography."
>>> pipe = VisualClozePipeline.from_pretrained(
... "VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16
... )
>>> pipe.to("cuda")
>>> image = pipe(
... task_prompt=task_prompt,
... content_prompt=content_prompt,
... image=image_paths,
... upsampling_width=1344,
... upsampling_height=768,
... upsampling_strength=0.4,
... guidance_scale=30,
... num_inference_steps=30,
... max_sequence_length=512,
... generator=torch.Generator("cpu").manual_seed(0),
... ).images[0][0]
>>> image.save("visualcloze.png")
```
"""
class VisualClozePipeline(
DiffusionPipeline,
FluxLoraLoaderMixin,
FromSingleFileMixin,
TextualInversionLoaderMixin,
):
r"""
The VisualCloze pipeline for image generation with visual context. Reference:
https://github.com/lzyhha/VisualCloze/tree/main. This pipeline is designed to generate images based on visual
in-context examples.
Args:
transformer ([`FluxTransformer2DModel`]):
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
text_encoder_2 ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
tokenizer_2 (`T5TokenizerFast`):
Second Tokenizer of class
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
resolution (`int`, *optional*, defaults to 384):
The resolution of each image when concatenating images from the query and in-context examples.
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
_optional_components = []
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
text_encoder_2: T5EncoderModel,
tokenizer_2: T5TokenizerFast,
transformer: FluxTransformer2DModel,
resolution: int = 384,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
transformer=transformer,
scheduler=scheduler,
)
self.generation_pipe = VisualClozeGenerationPipeline(
vae=vae,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
transformer=transformer,
scheduler=scheduler,
resolution=resolution,
)
self.upsampling_pipe = VisualClozeUpsamplingPipeline(
vae=vae,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
transformer=transformer,
scheduler=scheduler,
)
def check_inputs(
self,
image,
task_prompt,
content_prompt,
upsampling_height,
upsampling_width,
strength,
prompt_embeds=None,
pooled_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if upsampling_height is not None and upsampling_height % (self.vae_scale_factor * 2) != 0:
logger.warning(
f"`upsampling_height`has to be divisible by {self.vae_scale_factor * 2} but are {upsampling_height}. Dimensions will be resized accordingly"
)
if upsampling_width is not None and upsampling_width % (self.vae_scale_factor * 2) != 0:
logger.warning(
f"`upsampling_width` have to be divisible by {self.vae_scale_factor * 2} but are {upsampling_width}. Dimensions will be resized accordingly"
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
# Validate prompt inputs
if (task_prompt is not None or content_prompt is not None) and prompt_embeds is not None:
raise ValueError("Cannot provide both text `task_prompt` + `content_prompt` and `prompt_embeds`. ")
if task_prompt is None and content_prompt is None and prompt_embeds is None:
raise ValueError("Must provide either `task_prompt` + `content_prompt` or pre-computed `prompt_embeds`. ")
# Validate prompt types and consistency
if task_prompt is None:
raise ValueError("`task_prompt` is missing.")
if task_prompt is not None and not isinstance(task_prompt, (str, list)):
raise ValueError(f"`task_prompt` must be str or list, got {type(task_prompt)}")
if content_prompt is not None and not isinstance(content_prompt, (str, list)):
raise ValueError(f"`content_prompt` must be str or list, got {type(content_prompt)}")
if isinstance(task_prompt, list) or isinstance(content_prompt, list):
if not isinstance(task_prompt, list) or not isinstance(content_prompt, list):
raise ValueError(
f"`task_prompt` and `content_prompt` must both be lists, or both be of type str or None, "
f"got {type(task_prompt)} and {type(content_prompt)}"
)
if len(content_prompt) != len(task_prompt):
raise ValueError("`task_prompt` and `content_prompt` must have the same length whe they are lists.")
for sample in image:
if not isinstance(sample, list) or not isinstance(sample[0], list):
raise ValueError("Each sample in the batch must have a 2D list of images.")
if len({len(row) for row in sample}) != 1:
raise ValueError("Each in-context example and query should contain the same number of images.")
if not any(img is None for img in sample[-1]):
raise ValueError("There are no targets in the query, which should be represented as None.")
for row in sample[:-1]:
if any(img is None for img in row):
raise ValueError("Images are missing in in-context examples.")
# Validate embeddings
if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError(
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
)
# Validate sequence length
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"max_sequence_length cannot exceed 512, got {max_sequence_length}")
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
task_prompt: Union[str, List[str]] = None,
content_prompt: Union[str, List[str]] = None,
image: Optional[torch.FloatTensor] = None,
upsampling_height: Optional[int] = None,
upsampling_width: Optional[int] = None,
num_inference_steps: int = 50,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 30.0,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
upsampling_strength: float = 1.0,
):
r"""
Function invoked when calling the VisualCloze pipeline for generation.
Args:
task_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to define the task intention.
content_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to define the content or caption of the target image to be generated.
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`.
upsampling_height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image (i.e., output image) after upsampling via SDEdit. By
default, the image is upsampled by a factor of three, and the base resolution is determined by the
resolution parameter of the pipeline. When only one of `upsampling_height` or `upsampling_width` is
specified, the other will be automatically set based on the aspect ratio.
upsampling_width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image (i.e., output image) after upsampling via SDEdit. By
default, the image is upsampled by a factor of three, and the base resolution is determined by the
resolution parameter of the pipeline. When only one of `upsampling_height` or `upsampling_width` is
specified, the other will be automatically set based on the aspect ratio.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 30.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
upsampling_strength (`float`, *optional*, defaults to 1.0):
Indicates extent to transform the reference `image` when upsampling the results. Must be between 0 and
1. The generated image is used as a starting point and more noise is added the higher the
`upsampling_strength`. The number of denoising steps depends on the amount of noise initially added.
When `upsampling_strength` is 1, added noise is maximum and the denoising process runs for the full
number of iterations specified in `num_inference_steps`. A value of 0 skips the upsampling step and
output the results at the resolution of `self.resolution`.
Examples:
Returns:
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
images.
"""
generation_output = self.generation_pipe(
task_prompt=task_prompt,
content_prompt=content_prompt,
image=image,
num_inference_steps=num_inference_steps,
sigmas=sigmas,
guidance_scale=guidance_scale,
num_images_per_prompt=num_images_per_prompt,
generator=generator,
latents=latents,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
joint_attention_kwargs=joint_attention_kwargs,
callback_on_step_end=callback_on_step_end,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
output_type=output_type if upsampling_strength == 0 else "pil",
)
if upsampling_strength == 0:
if not return_dict:
return (generation_output,)
return FluxPipelineOutput(images=generation_output)
# Upsampling the generated images
# 1. Prepare the input images and prompts
if not isinstance(content_prompt, (list)):
content_prompt = [content_prompt]
n_target_per_sample = []
upsampling_image = []
upsampling_mask = []
upsampling_prompt = []
upsampling_generator = generator if isinstance(generator, (torch.Generator,)) else []
for i in range(len(generation_output.images)):
n_target_per_sample.append(len(generation_output.images[i]))
for image in generation_output.images[i]:
upsampling_image.append(image)
upsampling_mask.append(Image.new("RGB", image.size, (255, 255, 255)))
upsampling_prompt.append(
content_prompt[i % len(content_prompt)] if content_prompt[i % len(content_prompt)] else ""
)
if not isinstance(generator, (torch.Generator,)):
upsampling_generator.append(generator[i % len(content_prompt)])
# 2. Apply the denosing loop
upsampling_output = self.upsampling_pipe(
prompt=upsampling_prompt,
image=upsampling_image,
mask_image=upsampling_mask,
height=upsampling_height,
width=upsampling_width,
strength=upsampling_strength,
num_inference_steps=num_inference_steps,
sigmas=sigmas,
guidance_scale=guidance_scale,
generator=upsampling_generator,
output_type=output_type,
joint_attention_kwargs=joint_attention_kwargs,
callback_on_step_end=callback_on_step_end,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
image = upsampling_output.images
output = []
if output_type == "pil":
# Each sample in the batch may have multiple output images. When returning as PIL images,
# these images cannot be concatenated. Therefore, for each sample,
# a list is used to represent all the output images.
output = []
start = 0
for n in n_target_per_sample:
output.append(image[start : start + n])
start += n
else:
output = image
if not return_dict:
return (output,)
return FluxPipelineOutput(images=output)

View File

@@ -0,0 +1,952 @@
# Copyright 2025 VisualCloze team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
USE_PEFT_BACKEND,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..flux.pipeline_flux_fill import calculate_shift, retrieve_latents, retrieve_timesteps
from ..flux.pipeline_output import FluxPipelineOutput
from ..pipeline_utils import DiffusionPipeline
from .visualcloze_utils import VisualClozeProcessor
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers import VisualClozeGenerationPipeline, FluxFillPipeline as VisualClozeUpsamplingPipeline
>>> from diffusers.utils import load_image
>>> from PIL import Image
>>> image_paths = [
... # in-context examples
... [
... load_image(
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_mask.jpg"
... ),
... load_image(
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_image.jpg"
... ),
... ],
... # query with the target image
... [
... load_image(
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_query_mask.jpg"
... ),
... None, # No image needed for the target image
... ],
... ]
>>> task_prompt = "In each row, a logical task is demonstrated to achieve [IMAGE2] an aesthetically pleasing photograph based on [IMAGE1] sam 2-generated masks with rich color coding."
>>> content_prompt = "Majestic photo of a golden eagle perched on a rocky outcrop in a mountainous landscape. The eagle is positioned in the right foreground, facing left, with its sharp beak and keen eyes prominently visible. Its plumage is a mix of dark brown and golden hues, with intricate feather details. The background features a soft-focus view of snow-capped mountains under a cloudy sky, creating a serene and grandiose atmosphere. The foreground includes rugged rocks and patches of green moss. Photorealistic, medium depth of field, soft natural lighting, cool color palette, high contrast, sharp focus on the eagle, blurred background, tranquil, majestic, wildlife photography."
>>> pipe = VisualClozeGenerationPipeline.from_pretrained(
... "VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16
... )
>>> pipe.to("cuda")
>>> image = pipe(
... task_prompt=task_prompt,
... content_prompt=content_prompt,
... image=image_paths,
... guidance_scale=30,
... num_inference_steps=30,
... max_sequence_length=512,
... generator=torch.Generator("cpu").manual_seed(0),
... ).images[0][0]
>>> # optional, upsampling the generated image
>>> pipe_upsample = VisualClozeUpsamplingPipeline.from_pipe(pipe)
>>> pipe_upsample.to("cuda")
>>> mask_image = Image.new("RGB", image.size, (255, 255, 255))
>>> image = pipe_upsample(
... image=image,
... mask_image=mask_image,
... prompt=content_prompt,
... width=1344,
... height=768,
... strength=0.4,
... guidance_scale=30,
... num_inference_steps=30,
... max_sequence_length=512,
... generator=torch.Generator("cpu").manual_seed(0),
... ).images[0]
>>> image.save("visualcloze.png")
```
"""
class VisualClozeGenerationPipeline(
DiffusionPipeline,
FluxLoraLoaderMixin,
FromSingleFileMixin,
TextualInversionLoaderMixin,
):
r"""
The VisualCloze pipeline for image generation with visual context. Reference:
https://github.com/lzyhha/VisualCloze/tree/main This pipeline is designed to generate images based on visual
in-context examples.
Args:
transformer ([`FluxTransformer2DModel`]):
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
text_encoder_2 ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
tokenizer_2 (`T5TokenizerFast`):
Second Tokenizer of class
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
resolution (`int`, *optional*, defaults to 384):
The resolution of each image when concatenating images from the query and in-context examples.
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
_optional_components = []
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
text_encoder_2: T5EncoderModel,
tokenizer_2: T5TokenizerFast,
transformer: FluxTransformer2DModel,
resolution: int = 384,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
transformer=transformer,
scheduler=scheduler,
)
self.resolution = resolution
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
self.image_processor = VisualClozeProcessor(
vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels, resolution=resolution
)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
)
self.default_sample_size = 128
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_images_per_prompt: int = 1,
max_sequence_length: int = 512,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
text_inputs = self.tokenizer_2(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_length=False,
return_overflowing_tokens=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
dtype = self.text_encoder_2.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
_, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
def _get_clip_prompt_embeds(
self,
prompt: Union[str, List[str]],
num_images_per_prompt: int = 1,
device: Optional[torch.device] = None,
):
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer_max_length,
truncation=True,
return_overflowing_tokens=False,
return_length=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer_max_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds.pooler_output
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds
# Modified from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
def encode_prompt(
self,
layout_prompt: Union[str, List[str]],
task_prompt: Union[str, List[str]],
content_prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 512,
lora_scale: Optional[float] = None,
):
r"""
Args:
layout_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to define the number of in-context examples and the number of images involved in
the task.
task_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to define the task intention.
content_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to define the content or caption of the target image to be generated.
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
lora_scale (`float`, *optional*):
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
device = device or self._execution_device
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if self.text_encoder is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder, lora_scale)
if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder_2, lora_scale)
if isinstance(layout_prompt, str):
layout_prompt = [layout_prompt]
task_prompt = [task_prompt]
content_prompt = [content_prompt]
def _preprocess(prompt, content=False):
if prompt is not None:
return f"The last image of the last row depicts: {prompt}" if content else prompt
else:
return ""
prompt = [
f"{_preprocess(layout_prompt[i])} {_preprocess(task_prompt[i])} {_preprocess(content_prompt[i], content=True)}".strip()
for i in range(len(layout_prompt))
]
pooled_prompt_embeds = self._get_clip_prompt_embeds(
prompt=prompt,
device=device,
num_images_per_prompt=num_images_per_prompt,
)
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)
if self.text_encoder is not None:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
if self.text_encoder_2 is not None:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder_2, lora_scale)
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
return prompt_embeds, pooled_prompt_embeds, text_ids
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
if isinstance(generator, list):
image_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
return image_latents
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
init_timestep = min(num_inference_steps * strength, num_inference_steps)
t_start = int(max(num_inference_steps - init_timestep, 0))
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start
def check_inputs(
self,
image,
task_prompt,
content_prompt,
prompt_embeds=None,
pooled_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
# Validate prompt inputs
if (task_prompt is not None or content_prompt is not None) and prompt_embeds is not None:
raise ValueError("Cannot provide both text `task_prompt` + `content_prompt` and `prompt_embeds`. ")
if task_prompt is None and content_prompt is None and prompt_embeds is None:
raise ValueError("Must provide either `task_prompt` + `content_prompt` or pre-computed `prompt_embeds`. ")
# Validate prompt types and consistency
if task_prompt is None:
raise ValueError("`task_prompt` is missing.")
if task_prompt is not None and not isinstance(task_prompt, (str, list)):
raise ValueError(f"`task_prompt` must be str or list, got {type(task_prompt)}")
if content_prompt is not None and not isinstance(content_prompt, (str, list)):
raise ValueError(f"`content_prompt` must be str or list, got {type(content_prompt)}")
if isinstance(task_prompt, list) or isinstance(content_prompt, list):
if not isinstance(task_prompt, list) or not isinstance(content_prompt, list):
raise ValueError(
f"`task_prompt` and `content_prompt` must both be lists, or both be of type str or None, "
f"got {type(task_prompt)} and {type(content_prompt)}"
)
if len(content_prompt) != len(task_prompt):
raise ValueError("`task_prompt` and `content_prompt` must have the same length whe they are lists.")
for sample in image:
if not isinstance(sample, list) or not isinstance(sample[0], list):
raise ValueError("Each sample in the batch must have a 2D list of images.")
if len({len(row) for row in sample}) != 1:
raise ValueError("Each in-context example and query should contain the same number of images.")
if not any(img is None for img in sample[-1]):
raise ValueError("There are no targets in the query, which should be represented as None.")
for row in sample[:-1]:
if any(img is None for img in row):
raise ValueError("Images are missing in in-context examples.")
# Validate embeddings
if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError(
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
)
# Validate sequence length
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"max_sequence_length cannot exceed 512, got {max_sequence_length}")
@staticmethod
def _prepare_latent_image_ids(image, vae_scale_factor, device, dtype):
latent_image_ids = []
for idx, img in enumerate(image, start=1):
img = img.squeeze(0)
channels, height, width = img.shape
num_patches_h = height // vae_scale_factor // 2
num_patches_w = width // vae_scale_factor // 2
patch_ids = torch.zeros(num_patches_h, num_patches_w, 3, device=device, dtype=dtype)
patch_ids[..., 0] = idx
patch_ids[..., 1] = torch.arange(num_patches_h, device=device, dtype=dtype)[:, None]
patch_ids[..., 2] = torch.arange(num_patches_w, device=device, dtype=dtype)[None, :]
patch_ids = patch_ids.reshape(-1, 3)
latent_image_ids.append(patch_ids)
return torch.cat(latent_image_ids, dim=0)
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
@staticmethod
def _unpack_latents(latents, sizes, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
start = 0
unpacked_latents = []
for i in range(len(sizes)):
cur_size = sizes[i]
height = cur_size[0][0] // vae_scale_factor
width = sum([size[1] for size in cur_size]) // vae_scale_factor
end = start + (height * width) // 4
cur_latents = latents[:, start:end]
cur_latents = cur_latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
cur_latents = cur_latents.permute(0, 3, 1, 4, 2, 5)
cur_latents = cur_latents.reshape(batch_size, channels // (2 * 2), height, width)
unpacked_latents.append(cur_latents)
start = end
return unpacked_latents
def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()
def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()
def enable_vae_tiling(self):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.vae.enable_tiling()
def disable_vae_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_tiling()
def _prepare_latents(self, image, mask, gen, vae_scale_factor, device, dtype):
"""Helper function to prepare latents for a single batch."""
# Concatenate images and masks along width dimension
image = [torch.cat(img, dim=3).to(device=device, dtype=dtype) for img in image]
mask = [torch.cat(m, dim=3).to(device=device, dtype=dtype) for m in mask]
# Generate latent image IDs
latent_image_ids = self._prepare_latent_image_ids(image, vae_scale_factor, device, dtype)
# For initial encoding, use actual images
image_latent = [self._encode_vae_image(img, gen) for img in image]
masked_image_latent = [img.clone() for img in image_latent]
for i in range(len(image_latent)):
# Rearrange latents and masks for patch processing
num_channels_latents, height, width = image_latent[i].shape[1:]
image_latent[i] = self._pack_latents(image_latent[i], 1, num_channels_latents, height, width)
masked_image_latent[i] = self._pack_latents(masked_image_latent[i], 1, num_channels_latents, height, width)
# Rearrange masks for patch processing
num_channels_latents, height, width = mask[i].shape[1:]
mask[i] = mask[i].view(
1,
num_channels_latents,
height // vae_scale_factor,
vae_scale_factor,
width // vae_scale_factor,
vae_scale_factor,
)
mask[i] = mask[i].permute(0, 1, 3, 5, 2, 4)
mask[i] = mask[i].reshape(
1,
num_channels_latents * (vae_scale_factor**2),
height // vae_scale_factor,
width // vae_scale_factor,
)
mask[i] = self._pack_latents(
mask[i],
1,
num_channels_latents * (vae_scale_factor**2),
height // vae_scale_factor,
width // vae_scale_factor,
)
# Concatenate along batch dimension
image_latent = torch.cat(image_latent, dim=1)
masked_image_latent = torch.cat(masked_image_latent, dim=1)
mask = torch.cat(mask, dim=1)
return image_latent, masked_image_latent, mask, latent_image_ids
def prepare_latents(
self,
input_image,
input_mask,
timestep,
batch_size,
dtype,
device,
generator,
vae_scale_factor,
):
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
# Process each batch
masked_image_latents = []
image_latents = []
masks = []
latent_image_ids = []
for i in range(len(input_image)):
_image_latent, _masked_image_latent, _mask, _latent_image_ids = self._prepare_latents(
input_image[i],
input_mask[i],
generator if isinstance(generator, torch.Generator) else generator[i],
vae_scale_factor,
device,
dtype,
)
masked_image_latents.append(_masked_image_latent)
image_latents.append(_image_latent)
masks.append(_mask)
latent_image_ids.append(_latent_image_ids)
# Concatenate all batches
masked_image_latents = torch.cat(masked_image_latents, dim=0)
image_latents = torch.cat(image_latents, dim=0)
masks = torch.cat(masks, dim=0)
# Handle batch size expansion
if batch_size > masked_image_latents.shape[0]:
if batch_size % masked_image_latents.shape[0] == 0:
# Expand batches by repeating
additional_image_per_prompt = batch_size // masked_image_latents.shape[0]
masked_image_latents = torch.cat([masked_image_latents] * additional_image_per_prompt, dim=0)
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
masks = torch.cat([masks] * additional_image_per_prompt, dim=0)
else:
raise ValueError(
f"Cannot expand batch size from {masked_image_latents.shape[0]} to {batch_size}. "
"Batch sizes must be multiples of each other."
)
# Add noise to latents
noises = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=dtype)
latents = self.scheduler.scale_noise(image_latents, timestep, noises).to(dtype=dtype)
# Combine masked latents with masks
masked_image_latents = torch.cat((masked_image_latents, masks), dim=-1).to(dtype=dtype)
return latents, masked_image_latents, latent_image_ids[0]
@property
def guidance_scale(self):
return self._guidance_scale
@property
def joint_attention_kwargs(self):
return self._joint_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
task_prompt: Union[str, List[str]] = None,
content_prompt: Union[str, List[str]] = None,
image: Optional[torch.FloatTensor] = None,
num_inference_steps: int = 50,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 30.0,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
):
r"""
Function invoked when calling the VisualCloze pipeline for generation.
Args:
task_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to define the task intention.
content_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to define the content or caption of the target image to be generated.
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 30.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
images.
"""
# 1. Check inputs. Raise error if not correct
self.check_inputs(
image,
task_prompt,
content_prompt,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
processor_output = self.image_processor.preprocess(
task_prompt, content_prompt, image, vae_scale_factor=self.vae_scale_factor
)
# 2. Define call parameters
if processor_output["task_prompt"] is not None and isinstance(processor_output["task_prompt"], str):
batch_size = 1
elif processor_output["task_prompt"] is not None and isinstance(processor_output["task_prompt"], list):
batch_size = len(processor_output["task_prompt"])
device = self._execution_device
# 3. Prepare prompt embeddings
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
layout_prompt=processor_output["layout_prompt"],
task_prompt=processor_output["task_prompt"],
content_prompt=processor_output["content_prompt"],
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
# 4. Prepare timesteps
# Calculate sequence length and shift factor
image_seq_len = sum(
(size[0] // self.vae_scale_factor // 2) * (size[1] // self.vae_scale_factor // 2)
for sample in processor_output["image_size"][0]
for size in sample
)
# Calculate noise schedule parameters
mu = calculate_shift(
image_seq_len,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)
# Get timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
sigmas=sigmas,
mu=mu,
)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, 1.0, device)
# 5. Prepare latent variables
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
latents, masked_image_latents, latent_image_ids = self.prepare_latents(
processor_output["init_image"],
processor_output["mask"],
latent_timestep,
batch_size * num_images_per_prompt,
prompt_embeds.dtype,
device,
generator,
vae_scale_factor=self.vae_scale_factor,
)
# Calculate warmup steps
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# Prepare guidance
if self.transformer.config.guidance_embeds:
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0])
else:
guidance = None
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# Broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
latent_model_input = torch.cat((latents, masked_image_latents), dim=2)
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
# Compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# Some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# Call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
# XLA optimization
if XLA_AVAILABLE:
xm.mark_step()
# 7. Post-process the image
# Crop the target image
# Since the generated image is a concatenation of the conditional and target regions,
# we need to extract only the target regions based on their positions
image = []
if output_type == "latent":
image = latents
else:
for b in range(len(latents)):
cur_image_size = processor_output["image_size"][b % batch_size]
cur_target_position = processor_output["target_position"][b % batch_size]
cur_latent = self._unpack_latents(latents[b].unsqueeze(0), cur_image_size, self.vae_scale_factor)[-1]
cur_latent = (cur_latent / self.vae.config.scaling_factor) + self.vae.config.shift_factor
cur_image = self.vae.decode(cur_latent, return_dict=False)[0]
cur_image = self.image_processor.postprocess(cur_image, output_type=output_type)[0]
start = 0
cropped = []
for i, size in enumerate(cur_image_size[-1]):
if cur_target_position[i]:
if output_type == "pil":
cropped.append(cur_image.crop((start, 0, start + size[1], size[0])))
else:
cropped.append(cur_image[0 : size[0], start : start + size[1]])
start += size[1]
image.append(cropped)
if output_type != "pil":
image = np.concatenate([arr[None] for sub_image in image for arr in sub_image], axis=0)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return FluxPipelineOutput(images=image)

View File

@@ -0,0 +1,251 @@
# Copyright 2025 VisualCloze team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Tuple, Union
import torch
from PIL import Image
from ...image_processor import VaeImageProcessor
class VisualClozeProcessor(VaeImageProcessor):
"""
Image processor for the VisualCloze pipeline.
This processor handles the preprocessing of images for visual cloze tasks, including resizing, normalization, and
mask generation.
Args:
resolution (int, optional):
Target resolution for processing images. Each image will be resized to this resolution before being
concatenated to avoid the out-of-memory error. Defaults to 384.
*args: Additional arguments passed to [~image_processor.VaeImageProcessor]
**kwargs: Additional keyword arguments passed to [~image_processor.VaeImageProcessor]
"""
def __init__(self, *args, resolution: int = 384, **kwargs):
super().__init__(*args, **kwargs)
self.resolution = resolution
def preprocess_image(
self, input_images: List[List[Optional[Image.Image]]], vae_scale_factor: int
) -> Tuple[List[List[torch.Tensor]], List[List[List[int]]], List[int]]:
"""
Preprocesses input images for the VisualCloze pipeline.
This function handles the preprocessing of input images by:
1. Resizing and cropping images to maintain consistent dimensions
2. Converting images to the Tensor format for the VAE
3. Normalizing pixel values
4. Tracking image sizes and positions of target images
Args:
input_images (List[List[Optional[Image.Image]]]):
A nested list of PIL Images where:
- Outer list represents different samples, including in-context examples and the query
- Inner list contains images for the task
- In the last row, condition images are provided and the target images are placed as None
vae_scale_factor (int):
The scale factor used by the VAE for resizing images
Returns:
Tuple containing:
- List[List[torch.Tensor]]: Preprocessed images in tensor format
- List[List[List[int]]]: Dimensions of each processed image [height, width]
- List[int]: Target positions indicating which images are to be generated
"""
n_samples, n_task_images = len(input_images), len(input_images[0])
divisible = 2 * vae_scale_factor
processed_images: List[List[Image.Image]] = [[] for _ in range(n_samples)]
resize_size: List[Optional[Tuple[int, int]]] = [None for _ in range(n_samples)]
target_position: List[int] = []
# Process each sample
for i in range(n_samples):
# Determine size from first non-None image
for j in range(n_task_images):
if input_images[i][j] is not None:
aspect_ratio = input_images[i][j].width / input_images[i][j].height
target_area = self.resolution * self.resolution
new_h = int((target_area / aspect_ratio) ** 0.5)
new_w = int(new_h * aspect_ratio)
new_w = max(new_w // divisible, 1) * divisible
new_h = max(new_h // divisible, 1) * divisible
resize_size[i] = (new_w, new_h)
break
# Process all images in the sample
for j in range(n_task_images):
if input_images[i][j] is not None:
target = self._resize_and_crop(input_images[i][j], resize_size[i][0], resize_size[i][1])
processed_images[i].append(target)
if i == n_samples - 1:
target_position.append(0)
else:
blank = Image.new("RGB", resize_size[i] or (self.resolution, self.resolution), (0, 0, 0))
processed_images[i].append(blank)
if i == n_samples - 1:
target_position.append(1)
# Ensure consistent width for multiple target images when there are multiple target images
if len(target_position) > 1 and sum(target_position) > 1:
new_w = resize_size[n_samples - 1][0] or 384
for i in range(len(processed_images)):
for j in range(len(processed_images[i])):
if processed_images[i][j] is not None:
new_h = int(processed_images[i][j].height * (new_w / processed_images[i][j].width))
new_w = int(new_w / 16) * 16
new_h = int(new_h / 16) * 16
processed_images[i][j] = self.height(processed_images[i][j], new_h, new_w)
# Convert to tensors and normalize
image_sizes = []
for i in range(len(processed_images)):
image_sizes.append([[img.height, img.width] for img in processed_images[i]])
for j, image in enumerate(processed_images[i]):
image = self.pil_to_numpy(image)
image = self.numpy_to_pt(image)
image = self.normalize(image)
processed_images[i][j] = image
return processed_images, image_sizes, target_position
def preprocess_mask(
self, input_images: List[List[Image.Image]], target_position: List[int]
) -> List[List[torch.Tensor]]:
"""
Generate masks for the VisualCloze pipeline.
Args:
input_images (List[List[Image.Image]]):
Processed images from preprocess_image
target_position (List[int]):
Binary list marking the positions of target images (1 for target, 0 for condition)
Returns:
List[List[torch.Tensor]]:
A nested list of mask tensors (1 for target positions, 0 for condition images)
"""
mask = []
for i, row in enumerate(input_images):
if i == len(input_images) - 1: # Query row
row_masks = [
torch.full((1, 1, row[0].shape[2], row[0].shape[3]), fill_value=m) for m in target_position
]
else: # In-context examples
row_masks = [
torch.full((1, 1, row[0].shape[2], row[0].shape[3]), fill_value=0) for _ in target_position
]
mask.append(row_masks)
return mask
def preprocess_image_upsampling(
self,
input_images: List[List[Image.Image]],
height: int,
width: int,
) -> Tuple[List[List[Image.Image]], List[List[List[int]]]]:
"""Process images for the upsampling stage in the VisualCloze pipeline.
Args:
input_images: Input image to process
height: Target height
width: Target width
Returns:
Tuple of processed image and its size
"""
image = self.resize(input_images[0][0], height, width)
image = self.pil_to_numpy(image) # to np
image = self.numpy_to_pt(image) # to pt
image = self.normalize(image)
input_images[0][0] = image
image_sizes = [[[height, width]]]
return input_images, image_sizes
def preprocess_mask_upsampling(self, input_images: List[List[Image.Image]]) -> List[List[torch.Tensor]]:
return [[torch.ones((1, 1, input_images[0][0].shape[2], input_images[0][0].shape[3]))]]
def get_layout_prompt(self, size: Tuple[int, int]) -> str:
layout_instruction = (
f"A grid layout with {size[0]} rows and {size[1]} columns, displaying {size[0] * size[1]} images arranged side by side.",
)
return layout_instruction
def preprocess(
self,
task_prompt: Union[str, List[str]],
content_prompt: Union[str, List[str]],
input_images: Optional[List[List[List[Optional[str]]]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
upsampling: bool = False,
vae_scale_factor: int = 16,
) -> Dict:
"""Process visual cloze inputs.
Args:
task_prompt: Task description(s)
content_prompt: Content description(s)
input_images: List of images or None for the target images
height: Optional target height for upsampling stage
width: Optional target width for upsampling stage
upsampling: Whether this is in the upsampling processing stage
Returns:
Dictionary containing processed images, masks, prompts and metadata
"""
if isinstance(task_prompt, str):
task_prompt = [task_prompt]
content_prompt = [content_prompt]
input_images = [input_images]
output = {
"init_image": [],
"mask": [],
"task_prompt": task_prompt if not upsampling else [None for _ in range(len(task_prompt))],
"content_prompt": content_prompt,
"layout_prompt": [],
"target_position": [],
"image_size": [],
}
for i in range(len(task_prompt)):
if upsampling:
layout_prompt = None
else:
layout_prompt = self.get_layout_prompt((len(input_images[i]), len(input_images[i][0])))
if upsampling:
cur_processed_images, cur_image_size = self.preprocess_image_upsampling(
input_images[i], height=height, width=width
)
cur_mask = self.preprocess_mask_upsampling(cur_processed_images)
else:
cur_processed_images, cur_image_size, cur_target_position = self.preprocess_image(
input_images[i], vae_scale_factor=vae_scale_factor
)
cur_mask = self.preprocess_mask(cur_processed_images, cur_target_position)
output["target_position"].append(cur_target_position)
output["image_size"].append(cur_image_size)
output["init_image"].append(cur_processed_images)
output["mask"].append(cur_mask)
output["layout_prompt"].append(layout_prompt)
return output

View File

@@ -2792,6 +2792,36 @@ class VideoToVideoSDPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class VisualClozeGenerationPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class VisualClozePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class VQDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

View File

@@ -0,0 +1,344 @@
import random
import tempfile
import unittest
import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
import diffusers
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxTransformer2DModel, VisualClozePipeline
from diffusers.utils import logging
from diffusers.utils.testing_utils import (
CaptureLogger,
enable_full_determinism,
floats_tensor,
require_accelerator,
torch_device,
)
from ..test_pipelines_common import PipelineTesterMixin, to_np
enable_full_determinism()
class VisualClozePipelineFastTests(unittest.TestCase, PipelineTesterMixin):
pipeline_class = VisualClozePipeline
params = frozenset(
[
"task_prompt",
"content_prompt",
"upsampling_height",
"upsampling_width",
"guidance_scale",
"prompt_embeds",
"pooled_prompt_embeds",
"upsampling_strength",
]
)
batch_params = frozenset(["task_prompt", "content_prompt", "image"])
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
supports_dduf = False
def get_dummy_components(self):
torch.manual_seed(0)
transformer = FluxTransformer2DModel(
patch_size=1,
in_channels=12,
out_channels=4,
num_layers=1,
num_single_layers=1,
attention_head_dim=6,
num_attention_heads=2,
joint_attention_dim=32,
pooled_projection_dim=32,
axes_dims_rope=[2, 2, 2],
)
clip_text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
hidden_act="gelu",
projection_dim=32,
)
torch.manual_seed(0)
text_encoder = CLIPTextModel(clip_text_encoder_config)
torch.manual_seed(0)
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
torch.manual_seed(0)
vae = AutoencoderKL(
sample_size=32,
in_channels=3,
out_channels=3,
block_out_channels=(4,),
layers_per_block=1,
latent_channels=1,
norm_num_groups=1,
use_quant_conv=False,
use_post_quant_conv=False,
shift_factor=0.0609,
scaling_factor=1.5035,
)
scheduler = FlowMatchEulerDiscreteScheduler()
return {
"scheduler": scheduler,
"text_encoder": text_encoder,
"text_encoder_2": text_encoder_2,
"tokenizer": tokenizer,
"tokenizer_2": tokenizer_2,
"transformer": transformer,
"vae": vae,
"resolution": 32,
}
def get_dummy_inputs(self, device, seed=0):
# Create example images to simulate the input format required by VisualCloze
context_image = [
Image.fromarray(floats_tensor((32, 32, 3), rng=random.Random(seed), scale=255).numpy().astype(np.uint8))
for _ in range(2)
]
query_image = [
Image.fromarray(
floats_tensor((32, 32, 3), rng=random.Random(seed + 1), scale=255).numpy().astype(np.uint8)
),
None,
]
# Create an image list that conforms to the VisualCloze input format
image = [
context_image, # In-Context example
query_image, # Query image
]
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
inputs = {
"task_prompt": "Each row outlines a logical process, starting from [IMAGE1] gray-based depth map with detailed object contours, to achieve [IMAGE2] an image with flawless clarity.",
"content_prompt": "A beautiful landscape with mountains and a lake",
"image": image,
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"upsampling_height": 32,
"upsampling_width": 32,
"max_sequence_length": 77,
"output_type": "np",
"upsampling_strength": 0.4,
}
return inputs
def test_visualcloze_different_prompts(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_same_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
inputs["task_prompt"] = "A different task to perform."
output_different_prompts = pipe(**inputs).images[0]
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
# Outputs should be different
assert max_diff > 1e-6
def test_visualcloze_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
height_width_pairs = [(32, 32), (72, 57)]
for height, width in height_width_pairs:
expected_height = height - height % (pipe.generation_pipe.vae_scale_factor * 2)
expected_width = width - width % (pipe.generation_pipe.vae_scale_factor * 2)
inputs.update({"upsampling_height": height, "upsampling_width": width})
image = pipe(**inputs).images[0]
output_height, output_width, _ = image.shape
assert (output_height, output_width) == (expected_height, expected_width)
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=1e-3)
def test_upsampling_strength(self, expected_min_diff=1e-1):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
# Test different upsampling strengths
inputs["upsampling_strength"] = 0.2
output_no_upsampling = pipe(**inputs).images[0]
inputs["upsampling_strength"] = 0.8
output_full_upsampling = pipe(**inputs).images[0]
# Different upsampling strengths should produce different outputs
max_diff = np.abs(output_no_upsampling - output_full_upsampling).max()
assert max_diff > expected_min_diff
def test_different_task_prompts(self, expected_min_diff=1e-1):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_original = pipe(**inputs).images[0]
inputs["task_prompt"] = "A different task description for image generation"
output_different_task = pipe(**inputs).images[0]
# Different task prompts should produce different outputs
max_diff = np.abs(output_original - output_different_task).max()
assert max_diff > expected_min_diff
@unittest.skip(
"Test not applicable because the pipeline being tested is a wrapper pipeline. CFG tests should be done on the inner pipelines."
)
def test_callback_cfg(self):
pass
def test_save_load_local(self, expected_max_difference=5e-4):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
output = pipe(**inputs)[0]
logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
logger.setLevel(diffusers.logging.INFO)
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir, safe_serialization=False)
with CaptureLogger(logger) as cap_logger:
# NOTE: Resolution must be set to 32 for loading otherwise will lead to OOM on CI hardware
# This attribute is not serialized in the config of the pipeline
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, resolution=32)
for component in pipe_loaded.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
for name in pipe_loaded.components.keys():
if name not in pipe_loaded._optional_components:
assert name in str(cap_logger)
pipe_loaded.to(torch_device)
pipe_loaded.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
output_loaded = pipe_loaded(**inputs)[0]
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
self.assertLess(max_diff, expected_max_difference)
def test_save_load_optional_components(self, expected_max_difference=1e-4):
if not hasattr(self.pipeline_class, "_optional_components"):
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
# set all optional components to None
for optional_component in pipe._optional_components:
setattr(pipe, optional_component, None)
generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
torch.manual_seed(0)
output = pipe(**inputs)[0]
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir, safe_serialization=False)
# NOTE: Resolution must be set to 32 for loading otherwise will lead to OOM on CI hardware
# This attribute is not serialized in the config of the pipeline
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, resolution=32)
for component in pipe_loaded.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe_loaded.to(torch_device)
pipe_loaded.set_progress_bar_config(disable=None)
for optional_component in pipe._optional_components:
self.assertTrue(
getattr(pipe_loaded, optional_component) is None,
f"`{optional_component}` did not stay set to None after loading.",
)
inputs = self.get_dummy_inputs(generator_device)
torch.manual_seed(0)
output_loaded = pipe_loaded(**inputs)[0]
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
self.assertLess(max_diff, expected_max_difference)
@unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
@require_accelerator
def test_save_load_float16(self, expected_max_diff=1e-2):
components = self.get_dummy_components()
for name, module in components.items():
if hasattr(module, "half"):
components[name] = module.to(torch_device).half()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
output = pipe(**inputs)[0]
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir)
# NOTE: Resolution must be set to 32 for loading otherwise will lead to OOM on CI hardware
# This attribute is not serialized in the config of the pipeline
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16, resolution=32)
for component in pipe_loaded.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe_loaded.to(torch_device)
pipe_loaded.set_progress_bar_config(disable=None)
for name, component in pipe_loaded.components.items():
if hasattr(component, "dtype"):
self.assertTrue(
component.dtype == torch.float16,
f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.",
)
inputs = self.get_dummy_inputs(torch_device)
output_loaded = pipe_loaded(**inputs)[0]
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
self.assertLess(
max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading."
)

View File

@@ -0,0 +1,312 @@
import random
import tempfile
import unittest
import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
import diffusers
from diffusers import (
AutoencoderKL,
FlowMatchEulerDiscreteScheduler,
FluxTransformer2DModel,
VisualClozeGenerationPipeline,
)
from diffusers.utils import logging
from diffusers.utils.testing_utils import (
CaptureLogger,
enable_full_determinism,
floats_tensor,
require_accelerator,
torch_device,
)
from ..test_pipelines_common import PipelineTesterMixin, to_np
enable_full_determinism()
class VisualClozeGenerationPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
pipeline_class = VisualClozeGenerationPipeline
params = frozenset(
[
"task_prompt",
"content_prompt",
"guidance_scale",
"prompt_embeds",
"pooled_prompt_embeds",
]
)
batch_params = frozenset(["task_prompt", "content_prompt", "image"])
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
supports_dduf = False
def get_dummy_components(self):
torch.manual_seed(0)
transformer = FluxTransformer2DModel(
patch_size=1,
in_channels=12,
out_channels=4,
num_layers=1,
num_single_layers=1,
attention_head_dim=6,
num_attention_heads=2,
joint_attention_dim=32,
pooled_projection_dim=32,
axes_dims_rope=[2, 2, 2],
)
clip_text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
hidden_act="gelu",
projection_dim=32,
)
torch.manual_seed(0)
text_encoder = CLIPTextModel(clip_text_encoder_config)
torch.manual_seed(0)
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
torch.manual_seed(0)
vae = AutoencoderKL(
sample_size=32,
in_channels=3,
out_channels=3,
block_out_channels=(4,),
layers_per_block=1,
latent_channels=1,
norm_num_groups=1,
use_quant_conv=False,
use_post_quant_conv=False,
shift_factor=0.0609,
scaling_factor=1.5035,
)
scheduler = FlowMatchEulerDiscreteScheduler()
return {
"scheduler": scheduler,
"text_encoder": text_encoder,
"text_encoder_2": text_encoder_2,
"tokenizer": tokenizer,
"tokenizer_2": tokenizer_2,
"transformer": transformer,
"vae": vae,
"resolution": 32,
}
def get_dummy_inputs(self, device, seed=0):
# Create example images to simulate the input format required by VisualCloze
context_image = [
Image.fromarray(floats_tensor((32, 32, 3), rng=random.Random(seed), scale=255).numpy().astype(np.uint8))
for _ in range(2)
]
query_image = [
Image.fromarray(
floats_tensor((32, 32, 3), rng=random.Random(seed + 1), scale=255).numpy().astype(np.uint8)
),
None,
]
# Create an image list that conforms to the VisualCloze input format
image = [
context_image, # In-Context example
query_image, # Query image
]
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
inputs = {
"task_prompt": "Each row outlines a logical process, starting from [IMAGE1] gray-based depth map with detailed object contours, to achieve [IMAGE2] an image with flawless clarity.",
"content_prompt": "A beautiful landscape with mountains and a lake",
"image": image,
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"max_sequence_length": 77,
"output_type": "np",
}
return inputs
def test_visualcloze_different_prompts(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_same_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
inputs["task_prompt"] = "A different task to perform."
output_different_prompts = pipe(**inputs).images[0]
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
# Outputs should be different
assert max_diff > 1e-6
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=1e-3)
def test_different_task_prompts(self, expected_min_diff=1e-1):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_original = pipe(**inputs).images[0]
inputs["task_prompt"] = "A different task description for image generation"
output_different_task = pipe(**inputs).images[0]
# Different task prompts should produce different outputs
max_diff = np.abs(output_original - output_different_task).max()
assert max_diff > expected_min_diff
def test_save_load_local(self, expected_max_difference=5e-4):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
output = pipe(**inputs)[0]
logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
logger.setLevel(diffusers.logging.INFO)
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir, safe_serialization=False)
with CaptureLogger(logger) as cap_logger:
# NOTE: Resolution must be set to 32 for loading otherwise will lead to OOM on CI hardware
# This attribute is not serialized in the config of the pipeline
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, resolution=32)
for component in pipe_loaded.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
for name in pipe_loaded.components.keys():
if name not in pipe_loaded._optional_components:
assert name in str(cap_logger)
pipe_loaded.to(torch_device)
pipe_loaded.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
output_loaded = pipe_loaded(**inputs)[0]
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
self.assertLess(max_diff, expected_max_difference)
def test_save_load_optional_components(self, expected_max_difference=1e-4):
if not hasattr(self.pipeline_class, "_optional_components"):
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
# set all optional components to None
for optional_component in pipe._optional_components:
setattr(pipe, optional_component, None)
generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
torch.manual_seed(0)
output = pipe(**inputs)[0]
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir, safe_serialization=False)
# NOTE: Resolution must be set to 32 for loading otherwise will lead to OOM on CI hardware
# This attribute is not serialized in the config of the pipeline
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, resolution=32)
for component in pipe_loaded.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe_loaded.to(torch_device)
pipe_loaded.set_progress_bar_config(disable=None)
for optional_component in pipe._optional_components:
self.assertTrue(
getattr(pipe_loaded, optional_component) is None,
f"`{optional_component}` did not stay set to None after loading.",
)
inputs = self.get_dummy_inputs(generator_device)
torch.manual_seed(0)
output_loaded = pipe_loaded(**inputs)[0]
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
self.assertLess(max_diff, expected_max_difference)
@unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
@require_accelerator
def test_save_load_float16(self, expected_max_diff=1e-2):
components = self.get_dummy_components()
for name, module in components.items():
if hasattr(module, "half"):
components[name] = module.to(torch_device).half()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
output = pipe(**inputs)[0]
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir)
# NOTE: Resolution must be set to 32 for loading otherwise will lead to OOM on CI hardware
# This attribute is not serialized in the config of the pipeline
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16, resolution=32)
for component in pipe_loaded.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe_loaded.to(torch_device)
pipe_loaded.set_progress_bar_config(disable=None)
for name, component in pipe_loaded.components.items():
if hasattr(component, "dtype"):
self.assertTrue(
component.dtype == torch.float16,
f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.",
)
inputs = self.get_dummy_inputs(torch_device)
output_loaded = pipe_loaded(**inputs)[0]
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
self.assertLess(
max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading."
)
@unittest.skip("Skipped due to missing layout_prompt. Needs further investigation.")
def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=0.0001, rtol=0.0001):
pass